Merge branch 'main' into patch-1

This commit is contained in:
Kristoffer Dalby 2021-11-15 23:00:45 +00:00 committed by GitHub
commit bd7b5e97cb
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
66 changed files with 2981 additions and 1869 deletions

View file

@ -18,12 +18,11 @@ jobs:
- name: Setup Go
uses: actions/setup-go@v2
with:
go-version: "1.16.3"
go-version: "1.17.3"
- name: Install dependencies
run: |
go version
go install golang.org/x/lint/golint@latest
sudo apt update
sudo apt install -y make

View file

@ -1,20 +1,37 @@
---
name: CI
on: [push, pull_request]
jobs:
# The "build" workflow
lint:
# The type of runner that the job will run on
golangci-lint:
runs-on: ubuntu-latest
# Steps represent a sequence of tasks that will be executed as part of the job
steps:
# Checks-out your repository under $GITHUB_WORKSPACE, so your job can access it
- uses: actions/checkout@v2
# Install and run golangci-lint as a separate step, it's much faster this
# way because this action has caching. It'll get run again in `make lint`
# below, but it's still much faster in the end than installing
# golangci-lint manually in the `Run lint` step.
- uses: golangci/golangci-lint-action@v2
- name: golangci-lint
uses: golangci/golangci-lint-action@v2
with:
version: latest
prettier-lint:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- name: Prettify code
uses: creyD/prettier_action@v4.0
with:
prettier_options: >-
--check **/*.{ts,js,md,yaml,yml,sass,css,scss,html}
only_changed: false
dry: true
proto-lint:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- uses: bufbuild/buf-setup-action@v0.7.0
- uses: bufbuild/buf-lint-action@v1
with:
input: "proto"

View file

@ -1,7 +1,53 @@
---
run:
timeout: 5m
timeout: 10m
issues:
skip-dirs:
- gen
linters:
enable-all: true
disable:
- exhaustivestruct
- revive
- lll
- interfacer
- scopelint
- maligned
- golint
- gofmt
- gochecknoglobals
- gochecknoinits
- gocognit
- funlen
- exhaustivestruct
- tagliatelle
- godox
- ireturn
# In progress
- gocritic
# We should strive to enable these:
- wrapcheck
- dupl
- makezero
# We might want to enable this, but it might be a lot of work
- cyclop
- nestif
- wsl # might be incompatible with gofumpt
- testpackage
- paralleltest
linters-settings:
varnamelen:
ignore-type-assert-ok: true
ignore-map-index-ok: true
ignore-names:
- err
- db
- id
- ip
- ok
- c

View file

@ -1,6 +1,14 @@
# Calculate version
version = $(shell ./scripts/version-at-commit.sh)
rwildcard=$(foreach d,$(wildcard $1*),$(call rwildcard,$d/,$2) $(filter $(subst *,%,$2),$d))
# GO_SOURCES = $(wildcard *.go)
# PROTO_SOURCES = $(wildcard **/*.proto)
GO_SOURCES = $(call rwildcard,,*.go)
PROTO_SOURCES = $(call rwildcard,,*.proto)
build:
go build -ldflags "-s -w -X github.com/juanfont/headscale/cmd/headscale/cli.Version=$(version)" cmd/headscale/headscale.go
@ -19,7 +27,12 @@ coverprofile_html:
go tool cover -html=coverage.out
lint:
golangci-lint run --fix
golangci-lint run --fix --timeout 10m
fmt:
prettier --write '**/**.{ts,js,md,yaml,yml,sass,css,scss,html}'
golines --max-len=88 --base-formatter=gofumpt -w $(GO_SOURCES)
clang-format -style="{BasedOnStyle: Google, IndentWidth: 4, AlignConsecutiveDeclarations: true, AlignConsecutiveAssignments: true, ColumnLimit: 0}" -i $(PROTO_SOURCES)
proto-lint:
cd proto/ && buf lint

View file

@ -54,7 +54,6 @@ Suggestions/PRs welcomed!
Please have a look at the documentation under [`docs/`](docs/).
## Disclaimer
1. We have nothing to do with Tailscale, or Tailscale Inc.
@ -64,6 +63,23 @@ Please have a look at the documentation under [`docs/`](docs/).
To contribute to Headscale you would need the lastest version of [Go](https://golang.org) and [Buf](https://buf.build)(Protobuf generator).
### Code style
To ensure we have some consistency with a growing number of contributes, this project has adopted linting and style/formatting rules:
The **Go** code is linted with [`golangci-lint`](https://golangci-lint.run) and
formatted with [`golines`](https://github.com/segmentio/golines) (width 88) and
[`gofumpt`](https://github.com/mvdan/gofumpt).
Please configure your editor to run the tools while developing and make sure to
run `make lint` and `make fmt` before committing any code.
The **Proto** code is linted with [`buf`](https://docs.buf.build/lint/overview) and
formatted with [`clang-format`](https://clang.llvm.org/docs/ClangFormat.html).
The **rest** (markdown, yaml, etc) is formatted with [`prettier`](https://prettier.io).
Check out the `.golangci.yaml` and `Makefile` to see the specific configuration.
### Install development tools
- Go
@ -81,6 +97,7 @@ Some parts of the project requires the generation of Go code from Protobuf (if c
```shell
make generate
```
**Note**: Please check in changes from `gen/` in a separate commit to make it easier to review.
To run the tests:
@ -261,5 +278,3 @@ make build
</td>
</tr>
</table>

149
acls.go
View file

@ -9,23 +9,30 @@ import (
"strings"
"github.com/rs/zerolog/log"
"github.com/tailscale/hujson"
"inet.af/netaddr"
"tailscale.com/tailcfg"
)
const (
errorEmptyPolicy = Error("empty policy")
errorInvalidAction = Error("invalid action")
errorInvalidUserSection = Error("invalid user section")
errorInvalidGroup = Error("invalid group")
errorInvalidTag = Error("invalid tag")
errorInvalidNamespace = Error("invalid namespace")
errorInvalidPortFormat = Error("invalid port format")
errEmptyPolicy = Error("empty policy")
errInvalidAction = Error("invalid action")
errInvalidUserSection = Error("invalid user section")
errInvalidGroup = Error("invalid group")
errInvalidTag = Error("invalid tag")
errInvalidNamespace = Error("invalid namespace")
errInvalidPortFormat = Error("invalid port format")
)
// LoadACLPolicy loads the ACL policy from the specify path, and generates the ACL rules
const (
Base10 = 10
BitSize16 = 16
portRangeBegin = 0
portRangeEnd = 65535
expectedTokenItems = 2
)
// LoadACLPolicy loads the ACL policy from the specify path, and generates the ACL rules.
func (h *Headscale) LoadACLPolicy(path string) error {
policyFile, err := os.Open(path)
if err != nil {
@ -34,23 +41,23 @@ func (h *Headscale) LoadACLPolicy(path string) error {
defer policyFile.Close()
var policy ACLPolicy
b, err := io.ReadAll(policyFile)
policyBytes, err := io.ReadAll(policyFile)
if err != nil {
return err
}
ast, err := hujson.Parse(b)
ast, err := hujson.Parse(policyBytes)
if err != nil {
return err
}
ast.Standardize()
b = ast.Pack()
err = json.Unmarshal(b, &policy)
policyBytes = ast.Pack()
err = json.Unmarshal(policyBytes, &policy)
if err != nil {
return err
}
if policy.IsZero() {
return errorEmptyPolicy
return errEmptyPolicy
}
h.aclPolicy = &policy
@ -59,37 +66,40 @@ func (h *Headscale) LoadACLPolicy(path string) error {
return err
}
h.aclRules = rules
return nil
}
func (h *Headscale) generateACLRules() ([]tailcfg.FilterRule, error) {
rules := []tailcfg.FilterRule{}
for i, a := range h.aclPolicy.ACLs {
if a.Action != "accept" {
return nil, errorInvalidAction
for index, acl := range h.aclPolicy.ACLs {
if acl.Action != "accept" {
return nil, errInvalidAction
}
r := tailcfg.FilterRule{}
filterRule := tailcfg.FilterRule{}
srcIPs := []string{}
for j, u := range a.Users {
srcs, err := h.generateACLPolicySrcIP(u)
for innerIndex, user := range acl.Users {
srcs, err := h.generateACLPolicySrcIP(user)
if err != nil {
log.Error().
Msgf("Error parsing ACL %d, User %d", i, j)
Msgf("Error parsing ACL %d, User %d", index, innerIndex)
return nil, err
}
srcIPs = append(srcIPs, srcs...)
}
r.SrcIPs = srcIPs
filterRule.SrcIPs = srcIPs
destPorts := []tailcfg.NetPortRange{}
for j, d := range a.Ports {
dests, err := h.generateACLPolicyDestPorts(d)
for innerIndex, ports := range acl.Ports {
dests, err := h.generateACLPolicyDestPorts(ports)
if err != nil {
log.Error().
Msgf("Error parsing ACL %d, Port %d", i, j)
Msgf("Error parsing ACL %d, Port %d", index, innerIndex)
return nil, err
}
destPorts = append(destPorts, dests...)
@ -108,10 +118,12 @@ func (h *Headscale) generateACLPolicySrcIP(u string) ([]string, error) {
return h.expandAlias(u)
}
func (h *Headscale) generateACLPolicyDestPorts(d string) ([]tailcfg.NetPortRange, error) {
func (h *Headscale) generateACLPolicyDestPorts(
d string,
) ([]tailcfg.NetPortRange, error) {
tokens := strings.Split(d, ":")
if len(tokens) < 2 || len(tokens) > 3 {
return nil, errorInvalidPortFormat
if len(tokens) < expectedTokenItems || len(tokens) > 3 {
return nil, errInvalidPortFormat
}
var alias string
@ -121,7 +133,7 @@ func (h *Headscale) generateACLPolicyDestPorts(d string) ([]tailcfg.NetPortRange
// tag:montreal-webserver:80,443
// tag:api-server:443
// example-host-1:*
if len(tokens) == 2 {
if len(tokens) == expectedTokenItems {
alias = tokens[0]
} else {
alias = fmt.Sprintf("%s:%s", tokens[0], tokens[1])
@ -146,34 +158,36 @@ func (h *Headscale) generateACLPolicyDestPorts(d string) ([]tailcfg.NetPortRange
dests = append(dests, pr)
}
}
return dests, nil
}
func (h *Headscale) expandAlias(s string) ([]string, error) {
if s == "*" {
func (h *Headscale) expandAlias(alias string) ([]string, error) {
if alias == "*" {
return []string{"*"}, nil
}
if strings.HasPrefix(s, "group:") {
if _, ok := h.aclPolicy.Groups[s]; !ok {
return nil, errorInvalidGroup
if strings.HasPrefix(alias, "group:") {
if _, ok := h.aclPolicy.Groups[alias]; !ok {
return nil, errInvalidGroup
}
ips := []string{}
for _, n := range h.aclPolicy.Groups[s] {
for _, n := range h.aclPolicy.Groups[alias] {
nodes, err := h.ListMachinesInNamespace(n)
if err != nil {
return nil, errorInvalidNamespace
return nil, errInvalidNamespace
}
for _, node := range nodes {
ips = append(ips, node.IPAddress)
}
}
return ips, nil
}
if strings.HasPrefix(s, "tag:") {
if _, ok := h.aclPolicy.TagOwners[s]; !ok {
return nil, errorInvalidTag
if strings.HasPrefix(alias, "tag:") {
if _, ok := h.aclPolicy.TagOwners[alias]; !ok {
return nil, errInvalidTag
}
// This will have HORRIBLE performance.
@ -183,10 +197,10 @@ func (h *Headscale) expandAlias(s string) ([]string, error) {
return nil, err
}
ips := []string{}
for _, m := range machines {
for _, machine := range machines {
hostinfo := tailcfg.Hostinfo{}
if len(m.HostInfo) != 0 {
hi, err := m.HostInfo.MarshalJSON()
if len(machine.HostInfo) != 0 {
hi, err := machine.HostInfo.MarshalJSON()
if err != nil {
return nil, err
}
@ -197,17 +211,19 @@ func (h *Headscale) expandAlias(s string) ([]string, error) {
// FIXME: Check TagOwners allows this
for _, t := range hostinfo.RequestTags {
if s[4:] == t {
ips = append(ips, m.IPAddress)
if alias[4:] == t {
ips = append(ips, machine.IPAddress)
break
}
}
}
}
return ips, nil
}
n, err := h.GetNamespace(s)
n, err := h.GetNamespace(alias)
if err == nil {
nodes, err := h.ListMachinesInNamespace(n.Name)
if err != nil {
@ -217,49 +233,54 @@ func (h *Headscale) expandAlias(s string) ([]string, error) {
for _, n := range nodes {
ips = append(ips, n.IPAddress)
}
return ips, nil
}
if h, ok := h.aclPolicy.Hosts[s]; ok {
if h, ok := h.aclPolicy.Hosts[alias]; ok {
return []string{h.String()}, nil
}
ip, err := netaddr.ParseIP(s)
ip, err := netaddr.ParseIP(alias)
if err == nil {
return []string{ip.String()}, nil
}
cidr, err := netaddr.ParseIPPrefix(s)
cidr, err := netaddr.ParseIPPrefix(alias)
if err == nil {
return []string{cidr.String()}, nil
}
return nil, errorInvalidUserSection
return nil, errInvalidUserSection
}
func (h *Headscale) expandPorts(s string) (*[]tailcfg.PortRange, error) {
if s == "*" {
return &[]tailcfg.PortRange{{First: 0, Last: 65535}}, nil
func (h *Headscale) expandPorts(portsStr string) (*[]tailcfg.PortRange, error) {
if portsStr == "*" {
return &[]tailcfg.PortRange{
{First: portRangeBegin, Last: portRangeEnd},
}, nil
}
ports := []tailcfg.PortRange{}
for _, p := range strings.Split(s, ",") {
rang := strings.Split(p, "-")
if len(rang) == 1 {
pi, err := strconv.ParseUint(rang[0], 10, 16)
for _, portStr := range strings.Split(portsStr, ",") {
rang := strings.Split(portStr, "-")
switch len(rang) {
case 1:
port, err := strconv.ParseUint(rang[0], Base10, BitSize16)
if err != nil {
return nil, err
}
ports = append(ports, tailcfg.PortRange{
First: uint16(pi),
Last: uint16(pi),
First: uint16(port),
Last: uint16(port),
})
} else if len(rang) == 2 {
start, err := strconv.ParseUint(rang[0], 10, 16)
case expectedTokenItems:
start, err := strconv.ParseUint(rang[0], Base10, BitSize16)
if err != nil {
return nil, err
}
last, err := strconv.ParseUint(rang[1], 10, 16)
last, err := strconv.ParseUint(rang[1], Base10, BitSize16)
if err != nil {
return nil, err
}
@ -267,9 +288,11 @@ func (h *Headscale) expandPorts(s string) (*[]tailcfg.PortRange, error) {
First: uint16(start),
Last: uint16(last),
})
} else {
return nil, errorInvalidPortFormat
default:
return nil, errInvalidPortFormat
}
}
return &ports, nil
}

View file

@ -5,54 +5,58 @@ import (
)
func (s *Suite) TestWrongPath(c *check.C) {
err := h.LoadACLPolicy("asdfg")
err := app.LoadACLPolicy("asdfg")
c.Assert(err, check.NotNil)
}
func (s *Suite) TestBrokenHuJson(c *check.C) {
err := h.LoadACLPolicy("./tests/acls/broken.hujson")
err := app.LoadACLPolicy("./tests/acls/broken.hujson")
c.Assert(err, check.NotNil)
}
func (s *Suite) TestInvalidPolicyHuson(c *check.C) {
err := h.LoadACLPolicy("./tests/acls/invalid.hujson")
err := app.LoadACLPolicy("./tests/acls/invalid.hujson")
c.Assert(err, check.NotNil)
c.Assert(err, check.Equals, errorEmptyPolicy)
c.Assert(err, check.Equals, errEmptyPolicy)
}
func (s *Suite) TestParseHosts(c *check.C) {
var hs Hosts
err := hs.UnmarshalJSON([]byte(`{"example-host-1": "100.100.100.100","example-host-2": "100.100.101.100/24"}`))
c.Assert(hs, check.NotNil)
var hosts Hosts
err := hosts.UnmarshalJSON(
[]byte(
`{"example-host-1": "100.100.100.100","example-host-2": "100.100.101.100/24"}`,
),
)
c.Assert(hosts, check.NotNil)
c.Assert(err, check.IsNil)
}
func (s *Suite) TestParseInvalidCIDR(c *check.C) {
var hs Hosts
err := hs.UnmarshalJSON([]byte(`{"example-host-1": "100.100.100.100/42"}`))
c.Assert(hs, check.IsNil)
var hosts Hosts
err := hosts.UnmarshalJSON([]byte(`{"example-host-1": "100.100.100.100/42"}`))
c.Assert(hosts, check.IsNil)
c.Assert(err, check.NotNil)
}
func (s *Suite) TestRuleInvalidGeneration(c *check.C) {
err := h.LoadACLPolicy("./tests/acls/acl_policy_invalid.hujson")
err := app.LoadACLPolicy("./tests/acls/acl_policy_invalid.hujson")
c.Assert(err, check.NotNil)
}
func (s *Suite) TestBasicRule(c *check.C) {
err := h.LoadACLPolicy("./tests/acls/acl_policy_basic_1.hujson")
err := app.LoadACLPolicy("./tests/acls/acl_policy_basic_1.hujson")
c.Assert(err, check.IsNil)
rules, err := h.generateACLRules()
rules, err := app.generateACLRules()
c.Assert(err, check.IsNil)
c.Assert(rules, check.NotNil)
}
func (s *Suite) TestPortRange(c *check.C) {
err := h.LoadACLPolicy("./tests/acls/acl_policy_basic_range.hujson")
err := app.LoadACLPolicy("./tests/acls/acl_policy_basic_range.hujson")
c.Assert(err, check.IsNil)
rules, err := h.generateACLRules()
rules, err := app.generateACLRules()
c.Assert(err, check.IsNil)
c.Assert(rules, check.NotNil)
@ -63,10 +67,10 @@ func (s *Suite) TestPortRange(c *check.C) {
}
func (s *Suite) TestPortWildcard(c *check.C) {
err := h.LoadACLPolicy("./tests/acls/acl_policy_basic_wildcards.hujson")
err := app.LoadACLPolicy("./tests/acls/acl_policy_basic_wildcards.hujson")
c.Assert(err, check.IsNil)
rules, err := h.generateACLRules()
rules, err := app.generateACLRules()
c.Assert(err, check.IsNil)
c.Assert(rules, check.NotNil)
@ -79,33 +83,35 @@ func (s *Suite) TestPortWildcard(c *check.C) {
}
func (s *Suite) TestPortNamespace(c *check.C) {
n, err := h.CreateNamespace("testnamespace")
namespace, err := app.CreateNamespace("testnamespace")
c.Assert(err, check.IsNil)
pak, err := h.CreatePreAuthKey(n.Name, false, false, nil)
pak, err := app.CreatePreAuthKey(namespace.Name, false, false, nil)
c.Assert(err, check.IsNil)
_, err = h.GetMachine("testnamespace", "testmachine")
_, err = app.GetMachine("testnamespace", "testmachine")
c.Assert(err, check.NotNil)
ip, _ := h.getAvailableIP()
m := Machine{
ip, _ := app.getAvailableIP()
machine := Machine{
ID: 0,
MachineKey: "foo",
NodeKey: "bar",
DiscoKey: "faa",
Name: "testmachine",
NamespaceID: n.ID,
NamespaceID: namespace.ID,
Registered: true,
RegisterMethod: "authKey",
IPAddress: ip.String(),
AuthKeyID: uint(pak.ID),
}
h.db.Save(&m)
app.db.Save(&machine)
err = h.LoadACLPolicy("./tests/acls/acl_policy_basic_namespace_as_user.hujson")
err = app.LoadACLPolicy(
"./tests/acls/acl_policy_basic_namespace_as_user.hujson",
)
c.Assert(err, check.IsNil)
rules, err := h.generateACLRules()
rules, err := app.generateACLRules()
c.Assert(err, check.IsNil)
c.Assert(rules, check.NotNil)
@ -119,33 +125,33 @@ func (s *Suite) TestPortNamespace(c *check.C) {
}
func (s *Suite) TestPortGroup(c *check.C) {
n, err := h.CreateNamespace("testnamespace")
namespace, err := app.CreateNamespace("testnamespace")
c.Assert(err, check.IsNil)
pak, err := h.CreatePreAuthKey(n.Name, false, false, nil)
pak, err := app.CreatePreAuthKey(namespace.Name, false, false, nil)
c.Assert(err, check.IsNil)
_, err = h.GetMachine("testnamespace", "testmachine")
_, err = app.GetMachine("testnamespace", "testmachine")
c.Assert(err, check.NotNil)
ip, _ := h.getAvailableIP()
m := Machine{
ip, _ := app.getAvailableIP()
machine := Machine{
ID: 0,
MachineKey: "foo",
NodeKey: "bar",
DiscoKey: "faa",
Name: "testmachine",
NamespaceID: n.ID,
NamespaceID: namespace.ID,
Registered: true,
RegisterMethod: "authKey",
IPAddress: ip.String(),
AuthKeyID: uint(pak.ID),
}
h.db.Save(&m)
app.db.Save(&machine)
err = h.LoadACLPolicy("./tests/acls/acl_policy_basic_groups.hujson")
err = app.LoadACLPolicy("./tests/acls/acl_policy_basic_groups.hujson")
c.Assert(err, check.IsNil)
rules, err := h.generateACLRules()
rules, err := app.generateACLRules()
c.Assert(err, check.IsNil)
c.Assert(rules, check.NotNil)

View file

@ -8,7 +8,7 @@ import (
"inet.af/netaddr"
)
// ACLPolicy represents a Tailscale ACL Policy
// ACLPolicy represents a Tailscale ACL Policy.
type ACLPolicy struct {
Groups Groups `json:"Groups"`
Hosts Hosts `json:"Hosts"`
@ -17,61 +17,63 @@ type ACLPolicy struct {
Tests []ACLTest `json:"Tests"`
}
// ACL is a basic rule for the ACL Policy
// ACL is a basic rule for the ACL Policy.
type ACL struct {
Action string `json:"Action"`
Users []string `json:"Users"`
Ports []string `json:"Ports"`
}
// Groups references a series of alias in the ACL rules
// Groups references a series of alias in the ACL rules.
type Groups map[string][]string
// Hosts are alias for IP addresses or subnets
// Hosts are alias for IP addresses or subnets.
type Hosts map[string]netaddr.IPPrefix
// TagOwners specify what users (namespaces?) are allow to use certain tags
// TagOwners specify what users (namespaces?) are allow to use certain tags.
type TagOwners map[string][]string
// ACLTest is not implemented, but should be use to check if a certain rule is allowed
// ACLTest is not implemented, but should be use to check if a certain rule is allowed.
type ACLTest struct {
User string `json:"User"`
Allow []string `json:"Allow"`
Deny []string `json:"Deny,omitempty"`
}
// UnmarshalJSON allows to parse the Hosts directly into netaddr objects
func (h *Hosts) UnmarshalJSON(data []byte) error {
hosts := Hosts{}
hs := make(map[string]string)
// UnmarshalJSON allows to parse the Hosts directly into netaddr objects.
func (hosts *Hosts) UnmarshalJSON(data []byte) error {
newHosts := Hosts{}
hostIPPrefixMap := make(map[string]string)
ast, err := hujson.Parse(data)
if err != nil {
return err
}
ast.Standardize()
data = ast.Pack()
err = json.Unmarshal(data, &hs)
err = json.Unmarshal(data, &hostIPPrefixMap)
if err != nil {
return err
}
for k, v := range hs {
if !strings.Contains(v, "/") {
v = v + "/32"
for host, prefixStr := range hostIPPrefixMap {
if !strings.Contains(prefixStr, "/") {
prefixStr += "/32"
}
prefix, err := netaddr.ParseIPPrefix(v)
prefix, err := netaddr.ParseIPPrefix(prefixStr)
if err != nil {
return err
}
hosts[k] = prefix
newHosts[host] = prefix
}
*h = hosts
*hosts = newHosts
return nil
}
// IsZero is perhaps a bit naive here
func (p ACLPolicy) IsZero() bool {
if len(p.Groups) == 0 && len(p.Hosts) == 0 && len(p.ACLs) == 0 {
// IsZero is perhaps a bit naive here.
func (policy ACLPolicy) IsZero() bool {
if len(policy.Groups) == 0 && len(policy.Hosts) == 0 && len(policy.ACLs) == 0 {
return true
}
return false
}

297
api.go
View file

@ -10,31 +10,37 @@ import (
"strings"
"time"
"github.com/rs/zerolog/log"
"github.com/gin-gonic/gin"
"github.com/klauspost/compress/zstd"
"github.com/rs/zerolog/log"
"gorm.io/gorm"
"tailscale.com/tailcfg"
"tailscale.com/types/wgkey"
)
const reservedResponseHeaderSize = 4
// KeyHandler provides the Headscale pub key
// Listens in /key
func (h *Headscale) KeyHandler(c *gin.Context) {
c.Data(200, "text/plain; charset=utf-8", []byte(h.publicKey.HexString()))
// Listens in /key.
func (h *Headscale) KeyHandler(ctx *gin.Context) {
ctx.Data(
http.StatusOK,
"text/plain; charset=utf-8",
[]byte(h.publicKey.HexString()),
)
}
// RegisterWebAPI shows a simple message in the browser to point to the CLI
// Listens in /register
func (h *Headscale) RegisterWebAPI(c *gin.Context) {
mKeyStr := c.Query("key")
if mKeyStr == "" {
c.String(http.StatusBadRequest, "Wrong params")
// Listens in /register.
func (h *Headscale) RegisterWebAPI(ctx *gin.Context) {
machineKeyStr := ctx.Query("key")
if machineKeyStr == "" {
ctx.String(http.StatusBadRequest, "Wrong params")
return
}
c.Data(http.StatusOK, "text/html; charset=utf-8", []byte(fmt.Sprintf(`
ctx.Data(http.StatusOK, "text/html; charset=utf-8", []byte(fmt.Sprintf(`
<html>
<body>
<h1>headscale</h1>
@ -51,43 +57,45 @@ func (h *Headscale) RegisterWebAPI(c *gin.Context) {
</body>
</html>
`, mKeyStr)))
`, machineKeyStr)))
}
// RegistrationHandler handles the actual registration process of a machine
// Endpoint /machine/:id
func (h *Headscale) RegistrationHandler(c *gin.Context) {
body, _ := io.ReadAll(c.Request.Body)
mKeyStr := c.Param("id")
mKey, err := wgkey.ParseHex(mKeyStr)
// Endpoint /machine/:id.
func (h *Headscale) RegistrationHandler(ctx *gin.Context) {
body, _ := io.ReadAll(ctx.Request.Body)
machineKeyStr := ctx.Param("id")
machineKey, err := wgkey.ParseHex(machineKeyStr)
if err != nil {
log.Error().
Str("handler", "Registration").
Err(err).
Msg("Cannot parse machine key")
machineRegistrations.WithLabelValues("unknown", "web", "error", "unknown").Inc()
c.String(http.StatusInternalServerError, "Sad!")
ctx.String(http.StatusInternalServerError, "Sad!")
return
}
req := tailcfg.RegisterRequest{}
err = decode(body, &req, &mKey, h.privateKey)
err = decode(body, &req, &machineKey, h.privateKey)
if err != nil {
log.Error().
Str("handler", "Registration").
Err(err).
Msg("Cannot decode message")
machineRegistrations.WithLabelValues("unknown", "web", "error", "unknown").Inc()
c.String(http.StatusInternalServerError, "Very sad!")
ctx.String(http.StatusInternalServerError, "Very sad!")
return
}
now := time.Now().UTC()
m, err := h.GetMachineByMachineKey(mKey.HexString())
machine, err := h.GetMachineByMachineKey(machineKey.HexString())
if errors.Is(err, gorm.ErrRecordNotFound) {
log.Info().Str("machine", req.Hostinfo.Hostname).Msg("New machine")
newMachine := Machine{
Expiry: &time.Time{},
MachineKey: mKey.HexString(),
MachineKey: machineKey.HexString(),
Name: req.Hostinfo.Hostname,
}
if err := h.db.Create(&newMachine).Error; err != nil {
@ -95,88 +103,96 @@ func (h *Headscale) RegistrationHandler(c *gin.Context) {
Str("handler", "Registration").
Err(err).
Msg("Could not create row")
machineRegistrations.WithLabelValues("unknown", "web", "error", m.Namespace.Name).Inc()
machineRegistrations.WithLabelValues("unknown", "web", "error", machine.Namespace.Name).
Inc()
return
}
m = &newMachine
machine = &newMachine
}
if !m.Registered && req.Auth.AuthKey != "" {
h.handleAuthKey(c, h.db, mKey, req, *m)
if !machine.Registered && req.Auth.AuthKey != "" {
h.handleAuthKey(ctx, h.db, machineKey, req, *machine)
return
}
resp := tailcfg.RegisterResponse{}
// We have the updated key!
if m.NodeKey == wgkey.Key(req.NodeKey).HexString() {
if machine.NodeKey == wgkey.Key(req.NodeKey).HexString() {
// The client sends an Expiry in the past if the client is requesting to expire the key (aka logout)
// https://github.com/tailscale/tailscale/blob/main/tailcfg/tailcfg.go#L648
if !req.Expiry.IsZero() && req.Expiry.UTC().Before(now) {
log.Info().
Str("handler", "Registration").
Str("machine", m.Name).
Str("machine", machine.Name).
Msg("Client requested logout")
m.Expiry = &req.Expiry // save the expiry so that the machine is marked as expired
h.db.Save(&m)
machine.Expiry = &req.Expiry // save the expiry so that the machine is marked as expired
h.db.Save(&machine)
resp.AuthURL = ""
resp.MachineAuthorized = false
resp.User = *m.Namespace.toUser()
respBody, err := encode(resp, &mKey, h.privateKey)
resp.User = *machine.Namespace.toUser()
respBody, err := encode(resp, &machineKey, h.privateKey)
if err != nil {
log.Error().
Str("handler", "Registration").
Err(err).
Msg("Cannot encode message")
c.String(http.StatusInternalServerError, "")
ctx.String(http.StatusInternalServerError, "")
return
}
c.Data(200, "application/json; charset=utf-8", respBody)
ctx.Data(http.StatusOK, "application/json; charset=utf-8", respBody)
return
}
if m.Registered && m.Expiry.UTC().After(now) {
if machine.Registered && machine.Expiry.UTC().After(now) {
// The machine registration is valid, respond with redirect to /map
log.Debug().
Str("handler", "Registration").
Str("machine", m.Name).
Str("machine", machine.Name).
Msg("Client is registered and we have the current NodeKey. All clear to /map")
resp.AuthURL = ""
resp.MachineAuthorized = true
resp.User = *m.Namespace.toUser()
resp.Login = *m.Namespace.toLogin()
resp.User = *machine.Namespace.toUser()
resp.Login = *machine.Namespace.toLogin()
respBody, err := encode(resp, &mKey, h.privateKey)
respBody, err := encode(resp, &machineKey, h.privateKey)
if err != nil {
log.Error().
Str("handler", "Registration").
Err(err).
Msg("Cannot encode message")
machineRegistrations.WithLabelValues("update", "web", "error", m.Namespace.Name).Inc()
c.String(http.StatusInternalServerError, "")
machineRegistrations.WithLabelValues("update", "web", "error", machine.Namespace.Name).
Inc()
ctx.String(http.StatusInternalServerError, "")
return
}
machineRegistrations.WithLabelValues("update", "web", "success", m.Namespace.Name).Inc()
c.Data(200, "application/json; charset=utf-8", respBody)
machineRegistrations.WithLabelValues("update", "web", "success", machine.Namespace.Name).
Inc()
ctx.Data(http.StatusOK, "application/json; charset=utf-8", respBody)
return
}
// The client has registered before, but has expired
log.Debug().
Str("handler", "Registration").
Str("machine", m.Name).
Str("machine", machine.Name).
Msg("Machine registration has expired. Sending a authurl to register")
if h.cfg.OIDC.Issuer != "" {
resp.AuthURL = fmt.Sprintf("%s/oidc/register/%s",
strings.TrimSuffix(h.cfg.ServerURL, "/"), mKey.HexString())
strings.TrimSuffix(h.cfg.ServerURL, "/"), machineKey.HexString())
} else {
resp.AuthURL = fmt.Sprintf("%s/register?key=%s",
strings.TrimSuffix(h.cfg.ServerURL, "/"), mKey.HexString())
strings.TrimSuffix(h.cfg.ServerURL, "/"), machineKey.HexString())
}
// When a client connects, it may request a specific expiry time in its
@ -185,102 +201,120 @@ func (h *Headscale) RegistrationHandler(c *gin.Context) {
// into two steps (which cant pass arbitrary data between them easily) and needs to be
// retrieved again after the user has authenticated. After the authentication flow
// completes, RequestedExpiry is copied into Expiry.
m.RequestedExpiry = &req.Expiry
machine.RequestedExpiry = &req.Expiry
h.db.Save(&m)
h.db.Save(&machine)
respBody, err := encode(resp, &mKey, h.privateKey)
respBody, err := encode(resp, &machineKey, h.privateKey)
if err != nil {
log.Error().
Str("handler", "Registration").
Err(err).
Msg("Cannot encode message")
machineRegistrations.WithLabelValues("new", "web", "error", m.Namespace.Name).Inc()
c.String(http.StatusInternalServerError, "")
machineRegistrations.WithLabelValues("new", "web", "error", machine.Namespace.Name).
Inc()
ctx.String(http.StatusInternalServerError, "")
return
}
machineRegistrations.WithLabelValues("new", "web", "success", m.Namespace.Name).Inc()
c.Data(200, "application/json; charset=utf-8", respBody)
machineRegistrations.WithLabelValues("new", "web", "success", machine.Namespace.Name).
Inc()
ctx.Data(http.StatusOK, "application/json; charset=utf-8", respBody)
return
}
// The NodeKey we have matches OldNodeKey, which means this is a refresh after a key expiration
if m.NodeKey == wgkey.Key(req.OldNodeKey).HexString() && m.Expiry.UTC().After(now) {
if machine.NodeKey == wgkey.Key(req.OldNodeKey).HexString() &&
machine.Expiry.UTC().After(now) {
log.Debug().
Str("handler", "Registration").
Str("machine", m.Name).
Str("machine", machine.Name).
Msg("We have the OldNodeKey in the database. This is a key refresh")
m.NodeKey = wgkey.Key(req.NodeKey).HexString()
h.db.Save(&m)
machine.NodeKey = wgkey.Key(req.NodeKey).HexString()
h.db.Save(&machine)
resp.AuthURL = ""
resp.User = *m.Namespace.toUser()
respBody, err := encode(resp, &mKey, h.privateKey)
resp.User = *machine.Namespace.toUser()
respBody, err := encode(resp, &machineKey, h.privateKey)
if err != nil {
log.Error().
Str("handler", "Registration").
Err(err).
Msg("Cannot encode message")
c.String(http.StatusInternalServerError, "Extremely sad!")
ctx.String(http.StatusInternalServerError, "Extremely sad!")
return
}
c.Data(200, "application/json; charset=utf-8", respBody)
ctx.Data(http.StatusOK, "application/json; charset=utf-8", respBody)
return
}
// The machine registration is new, redirect the client to the registration URL
log.Debug().
Str("handler", "Registration").
Str("machine", m.Name).
Str("machine", machine.Name).
Msg("The node is sending us a new NodeKey, sending auth url")
if h.cfg.OIDC.Issuer != "" {
resp.AuthURL = fmt.Sprintf("%s/oidc/register/%s", strings.TrimSuffix(h.cfg.ServerURL, "/"), mKey.HexString())
resp.AuthURL = fmt.Sprintf(
"%s/oidc/register/%s",
strings.TrimSuffix(h.cfg.ServerURL, "/"),
machineKey.HexString(),
)
} else {
resp.AuthURL = fmt.Sprintf("%s/register?key=%s",
strings.TrimSuffix(h.cfg.ServerURL, "/"), mKey.HexString())
strings.TrimSuffix(h.cfg.ServerURL, "/"), machineKey.HexString())
}
// save the requested expiry time for retrieval later in the authentication flow
m.RequestedExpiry = &req.Expiry
m.NodeKey = wgkey.Key(req.NodeKey).HexString() // save the NodeKey
h.db.Save(&m)
machine.RequestedExpiry = &req.Expiry
machine.NodeKey = wgkey.Key(req.NodeKey).HexString() // save the NodeKey
h.db.Save(&machine)
respBody, err := encode(resp, &mKey, h.privateKey)
respBody, err := encode(resp, &machineKey, h.privateKey)
if err != nil {
log.Error().
Str("handler", "Registration").
Err(err).
Msg("Cannot encode message")
c.String(http.StatusInternalServerError, "")
ctx.String(http.StatusInternalServerError, "")
return
}
c.Data(200, "application/json; charset=utf-8", respBody)
ctx.Data(http.StatusOK, "application/json; charset=utf-8", respBody)
}
func (h *Headscale) getMapResponse(mKey wgkey.Key, req tailcfg.MapRequest, m *Machine) ([]byte, error) {
func (h *Headscale) getMapResponse(
machineKey wgkey.Key,
req tailcfg.MapRequest,
machine *Machine,
) ([]byte, error) {
log.Trace().
Str("func", "getMapResponse").
Str("machine", req.Hostinfo.Hostname).
Msg("Creating Map response")
node, err := m.toNode(h.cfg.BaseDomain, h.cfg.DNSConfig, true)
node, err := machine.toNode(h.cfg.BaseDomain, h.cfg.DNSConfig, true)
if err != nil {
log.Error().
Str("func", "getMapResponse").
Err(err).
Msg("Cannot convert to node")
return nil, err
}
peers, err := h.getPeers(m)
peers, err := h.getPeers(machine)
if err != nil {
log.Error().
Str("func", "getMapResponse").
Err(err).
Msg("Cannot fetch peers")
return nil, err
}
profiles := getMapResponseUserProfiles(*m, peers)
profiles := getMapResponseUserProfiles(*machine, peers)
nodePeers, err := peers.toNodes(h.cfg.BaseDomain, h.cfg.DNSConfig, true)
if err != nil {
@ -288,17 +322,16 @@ func (h *Headscale) getMapResponse(mKey wgkey.Key, req tailcfg.MapRequest, m *Ma
Str("func", "getMapResponse").
Err(err).
Msg("Failed to convert peers to Tailscale nodes")
return nil, err
}
dnsConfig, err := getMapResponseDNSConfig(h.cfg.DNSConfig, h.cfg.BaseDomain, *m, peers)
if err != nil {
log.Error().
Str("func", "getMapResponse").
Err(err).
Msg("Failed generate the DNSConfig")
return nil, err
}
dnsConfig := getMapResponseDNSConfig(
h.cfg.DNSConfig,
h.cfg.BaseDomain,
*machine,
peers,
)
resp := tailcfg.MapResponse{
KeepAlive: false,
@ -323,66 +356,71 @@ func (h *Headscale) getMapResponse(mKey wgkey.Key, req tailcfg.MapRequest, m *Ma
encoder, _ := zstd.NewWriter(nil)
srcCompressed := encoder.EncodeAll(src, nil)
respBody, err = encodeMsg(srcCompressed, &mKey, h.privateKey)
respBody, err = encodeMsg(srcCompressed, &machineKey, h.privateKey)
if err != nil {
return nil, err
}
} else {
respBody, err = encode(resp, &mKey, h.privateKey)
respBody, err = encode(resp, &machineKey, h.privateKey)
if err != nil {
return nil, err
}
}
// declare the incoming size on the first 4 bytes
data := make([]byte, 4)
data := make([]byte, reservedResponseHeaderSize)
binary.LittleEndian.PutUint32(data, uint32(len(respBody)))
data = append(data, respBody...)
return data, nil
}
func (h *Headscale) getMapKeepAliveResponse(mKey wgkey.Key, req tailcfg.MapRequest, m *Machine) ([]byte, error) {
resp := tailcfg.MapResponse{
func (h *Headscale) getMapKeepAliveResponse(
machineKey wgkey.Key,
mapRequest tailcfg.MapRequest,
) ([]byte, error) {
mapResponse := tailcfg.MapResponse{
KeepAlive: true,
}
var respBody []byte
var err error
if req.Compress == "zstd" {
src, _ := json.Marshal(resp)
if mapRequest.Compress == "zstd" {
src, _ := json.Marshal(mapResponse)
encoder, _ := zstd.NewWriter(nil)
srcCompressed := encoder.EncodeAll(src, nil)
respBody, err = encodeMsg(srcCompressed, &mKey, h.privateKey)
respBody, err = encodeMsg(srcCompressed, &machineKey, h.privateKey)
if err != nil {
return nil, err
}
} else {
respBody, err = encode(resp, &mKey, h.privateKey)
respBody, err = encode(mapResponse, &machineKey, h.privateKey)
if err != nil {
return nil, err
}
}
data := make([]byte, 4)
data := make([]byte, reservedResponseHeaderSize)
binary.LittleEndian.PutUint32(data, uint32(len(respBody)))
data = append(data, respBody...)
return data, nil
}
func (h *Headscale) handleAuthKey(
c *gin.Context,
ctx *gin.Context,
db *gorm.DB,
idKey wgkey.Key,
req tailcfg.RegisterRequest,
m Machine,
reqisterRequest tailcfg.RegisterRequest,
machine Machine,
) {
log.Debug().
Str("func", "handleAuthKey").
Str("machine", req.Hostinfo.Hostname).
Msgf("Processing auth key for %s", req.Hostinfo.Hostname)
Str("machine", reqisterRequest.Hostinfo.Hostname).
Msgf("Processing auth key for %s", reqisterRequest.Hostinfo.Hostname)
resp := tailcfg.RegisterResponse{}
pak, err := h.checkKeyValidity(req.Auth.AuthKey)
pak, err := h.checkKeyValidity(reqisterRequest.Auth.AuthKey)
if err != nil {
log.Error().
Str("func", "handleAuthKey").
Str("machine", m.Name).
Str("machine", machine.Name).
Err(err).
Msg("Failed authentication via AuthKey")
resp.MachineAuthorized = false
@ -390,48 +428,56 @@ func (h *Headscale) handleAuthKey(
if err != nil {
log.Error().
Str("func", "handleAuthKey").
Str("machine", m.Name).
Str("machine", machine.Name).
Err(err).
Msg("Cannot encode message")
c.String(http.StatusInternalServerError, "")
machineRegistrations.WithLabelValues("new", "authkey", "error", m.Namespace.Name).Inc()
ctx.String(http.StatusInternalServerError, "")
machineRegistrations.WithLabelValues("new", "authkey", "error", machine.Namespace.Name).
Inc()
return
}
c.Data(401, "application/json; charset=utf-8", respBody)
ctx.Data(http.StatusUnauthorized, "application/json; charset=utf-8", respBody)
log.Error().
Str("func", "handleAuthKey").
Str("machine", m.Name).
Str("machine", machine.Name).
Msg("Failed authentication via AuthKey")
machineRegistrations.WithLabelValues("new", "authkey", "error", m.Namespace.Name).Inc()
machineRegistrations.WithLabelValues("new", "authkey", "error", machine.Namespace.Name).
Inc()
return
}
log.Debug().
Str("func", "handleAuthKey").
Str("machine", m.Name).
Str("machine", machine.Name).
Msg("Authentication key was valid, proceeding to acquire an IP address")
ip, err := h.getAvailableIP()
if err != nil {
log.Error().
Str("func", "handleAuthKey").
Str("machine", m.Name).
Str("machine", machine.Name).
Msg("Failed to find an available IP")
machineRegistrations.WithLabelValues("new", "authkey", "error", m.Namespace.Name).Inc()
machineRegistrations.WithLabelValues("new", "authkey", "error", machine.Namespace.Name).
Inc()
return
}
log.Info().
Str("func", "handleAuthKey").
Str("machine", m.Name).
Str("machine", machine.Name).
Str("ip", ip.String()).
Msgf("Assigning %s to %s", ip, m.Name)
Msgf("Assigning %s to %s", ip, machine.Name)
m.AuthKeyID = uint(pak.ID)
m.IPAddress = ip.String()
m.NamespaceID = pak.NamespaceID
m.NodeKey = wgkey.Key(req.NodeKey).HexString() // we update it just in case
m.Registered = true
m.RegisterMethod = "authKey"
db.Save(&m)
machine.AuthKeyID = uint(pak.ID)
machine.IPAddress = ip.String()
machine.NamespaceID = pak.NamespaceID
machine.NodeKey = wgkey.Key(reqisterRequest.NodeKey).
HexString()
// we update it just in case
machine.Registered = true
machine.RegisterMethod = "authKey"
db.Save(&machine)
pak.Used = true
db.Save(&pak)
@ -442,18 +488,21 @@ func (h *Headscale) handleAuthKey(
if err != nil {
log.Error().
Str("func", "handleAuthKey").
Str("machine", m.Name).
Str("machine", machine.Name).
Err(err).
Msg("Cannot encode message")
machineRegistrations.WithLabelValues("new", "authkey", "error", m.Namespace.Name).Inc()
c.String(http.StatusInternalServerError, "Extremely sad!")
machineRegistrations.WithLabelValues("new", "authkey", "error", machine.Namespace.Name).
Inc()
ctx.String(http.StatusInternalServerError, "Extremely sad!")
return
}
machineRegistrations.WithLabelValues("new", "authkey", "success", m.Namespace.Name).Inc()
c.Data(200, "application/json; charset=utf-8", respBody)
machineRegistrations.WithLabelValues("new", "authkey", "success", machine.Namespace.Name).
Inc()
ctx.Data(http.StatusOK, "application/json; charset=utf-8", respBody)
log.Info().
Str("func", "handleAuthKey").
Str("machine", m.Name).
Str("machine", machine.Name).
Str("ip", ip.String()).
Msg("Successfully authenticated via AuthKey")
}

271
app.go
View file

@ -18,20 +18,19 @@ import (
"time"
"github.com/coreos/go-oidc/v3/oidc"
"github.com/patrickmn/go-cache"
"golang.org/x/oauth2"
"github.com/gin-gonic/gin"
"github.com/grpc-ecosystem/go-grpc-middleware"
grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware"
"github.com/grpc-ecosystem/grpc-gateway/v2/runtime"
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
"github.com/philip-bui/grpc-zerolog"
"github.com/patrickmn/go-cache"
zerolog "github.com/philip-bui/grpc-zerolog"
zl "github.com/rs/zerolog"
"github.com/rs/zerolog/log"
"github.com/soheilhy/cmux"
ginprometheus "github.com/zsais/go-gin-prometheus"
"golang.org/x/crypto/acme"
"golang.org/x/crypto/acme/autocert"
"golang.org/x/oauth2"
"golang.org/x/sync/errgroup"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
@ -48,7 +47,16 @@ import (
)
const (
AUTH_PREFIX = "Bearer "
AuthPrefix = "Bearer "
Postgres = "postgresql"
Sqlite = "sqlite3"
updateInterval = 5000
HTTPReadTimeout = 30 * time.Second
errUnsupportedDatabase = Error("unsupported DB")
errUnsupportedLetsEncryptChallengeType = Error(
"unknown value for Lets Encrypt challenge type",
)
)
// Config contains the initial Headscale configuration.
@ -151,16 +159,22 @@ func NewHeadscale(cfg Config) (*Headscale, error) {
var dbString string
switch cfg.DBtype {
case "postgres":
dbString = fmt.Sprintf("host=%s port=%d dbname=%s user=%s password=%s sslmode=disable", cfg.DBhost,
cfg.DBport, cfg.DBname, cfg.DBuser, cfg.DBpass)
case "sqlite3":
case Postgres:
dbString = fmt.Sprintf(
"host=%s port=%d dbname=%s user=%s password=%s sslmode=disable",
cfg.DBhost,
cfg.DBport,
cfg.DBname,
cfg.DBuser,
cfg.DBpass,
)
case Sqlite:
dbString = cfg.DBpath
default:
return nil, errors.New("unsupported DB")
return nil, errUnsupportedDatabase
}
h := Headscale{
app := Headscale{
cfg: cfg,
dbType: cfg.DBtype,
dbString: dbString,
@ -169,33 +183,32 @@ func NewHeadscale(cfg Config) (*Headscale, error) {
aclRules: tailcfg.FilterAllowAll, // default allowall
}
err = h.initDB()
err = app.initDB()
if err != nil {
return nil, err
}
if cfg.OIDC.Issuer != "" {
err = h.initOIDC()
err = app.initOIDC()
if err != nil {
return nil, err
}
}
if h.cfg.DNSConfig != nil && h.cfg.DNSConfig.Proxied { // if MagicDNS
magicDNSDomains, err := generateMagicDNSRootDomains(h.cfg.IPPrefix, h.cfg.BaseDomain)
if err != nil {
return nil, err
}
if app.cfg.DNSConfig != nil && app.cfg.DNSConfig.Proxied { // if MagicDNS
magicDNSDomains := generateMagicDNSRootDomains(
app.cfg.IPPrefix,
)
// we might have routes already from Split DNS
if h.cfg.DNSConfig.Routes == nil {
h.cfg.DNSConfig.Routes = make(map[string][]dnstype.Resolver)
if app.cfg.DNSConfig.Routes == nil {
app.cfg.DNSConfig.Routes = make(map[string][]dnstype.Resolver)
}
for _, d := range magicDNSDomains {
h.cfg.DNSConfig.Routes[d.WithoutTrailingDot()] = nil
app.cfg.DNSConfig.Routes[d.WithoutTrailingDot()] = nil
}
}
return &h, nil
return &app, nil
}
// Redirect to our TLS url.
@ -221,30 +234,37 @@ func (h *Headscale) expireEphemeralNodesWorker() {
return
}
for _, ns := range namespaces {
machines, err := h.ListMachinesInNamespace(ns.Name)
for _, namespace := range namespaces {
machines, err := h.ListMachinesInNamespace(namespace.Name)
if err != nil {
log.Error().Err(err).Str("namespace", ns.Name).Msg("Error listing machines in namespace")
log.Error().
Err(err).
Str("namespace", namespace.Name).
Msg("Error listing machines in namespace")
return
}
for _, m := range machines {
if m.AuthKey != nil && m.LastSeen != nil && m.AuthKey.Ephemeral &&
time.Now().After(m.LastSeen.Add(h.cfg.EphemeralNodeInactivityTimeout)) {
log.Info().Str("machine", m.Name).Msg("Ephemeral client removed from database")
for _, machine := range machines {
if machine.AuthKey != nil && machine.LastSeen != nil &&
machine.AuthKey.Ephemeral &&
time.Now().
After(machine.LastSeen.Add(h.cfg.EphemeralNodeInactivityTimeout)) {
log.Info().
Str("machine", machine.Name).
Msg("Ephemeral client removed from database")
err = h.db.Unscoped().Delete(m).Error
err = h.db.Unscoped().Delete(machine).Error
if err != nil {
log.Error().
Err(err).
Str("machine", m.Name).
Str("machine", machine.Name).
Msg("🤮 Cannot delete ephemeral machine from the database")
}
}
}
h.setLastStateChangeToNow(ns.Name)
h.setLastStateChangeToNow(namespace.Name)
}
}
@ -266,36 +286,56 @@ func (h *Headscale) grpcAuthenticationInterceptor(ctx context.Context,
req interface{},
info *grpc.UnaryServerInfo,
handler grpc.UnaryHandler) (interface{}, error) {
// Check if the request is coming from the on-server client.
// This is not secure, but it is to maintain maintainability
// with the "legacy" database-based client
// It is also neede for grpc-gateway to be able to connect to
// the server
p, _ := peer.FromContext(ctx)
client, _ := peer.FromContext(ctx)
log.Trace().Caller().Str("client_address", p.Addr.String()).Msg("Client is trying to authenticate")
log.Trace().
Caller().
Str("client_address", client.Addr.String()).
Msg("Client is trying to authenticate")
md, ok := metadata.FromIncomingContext(ctx)
meta, ok := metadata.FromIncomingContext(ctx)
if !ok {
log.Error().Caller().Str("client_address", p.Addr.String()).Msg("Retrieving metadata is failed")
return ctx, status.Errorf(codes.InvalidArgument, "Retrieving metadata is failed")
log.Error().
Caller().
Str("client_address", client.Addr.String()).
Msg("Retrieving metadata is failed")
return ctx, status.Errorf(
codes.InvalidArgument,
"Retrieving metadata is failed",
)
}
authHeader, ok := md["authorization"]
authHeader, ok := meta["authorization"]
if !ok {
log.Error().Caller().Str("client_address", p.Addr.String()).Msg("Authorization token is not supplied")
return ctx, status.Errorf(codes.Unauthenticated, "Authorization token is not supplied")
log.Error().
Caller().
Str("client_address", client.Addr.String()).
Msg("Authorization token is not supplied")
return ctx, status.Errorf(
codes.Unauthenticated,
"Authorization token is not supplied",
)
}
token := authHeader[0]
if !strings.HasPrefix(token, AUTH_PREFIX) {
if !strings.HasPrefix(token, AuthPrefix) {
log.Error().
Caller().
Str("client_address", p.Addr.String()).
Str("client_address", client.Addr.String()).
Msg(`missing "Bearer " prefix in "Authorization" header`)
return ctx, status.Error(codes.Unauthenticated, `missing "Bearer " prefix in "Authorization" header`)
return ctx, status.Error(
codes.Unauthenticated,
`missing "Bearer " prefix in "Authorization" header`,
)
}
// TODO(kradalby): Implement API key backend:
@ -307,7 +347,10 @@ func (h *Headscale) grpcAuthenticationInterceptor(ctx context.Context,
// Currently all other than localhost traffic is unauthorized, this is intentional to allow
// us to make use of gRPC for our CLI, but not having to implement any of the remote capabilities
// and API key auth
return ctx, status.Error(codes.Unauthenticated, "Authentication is not implemented yet")
return ctx, status.Error(
codes.Unauthenticated,
"Authentication is not implemented yet",
)
// if strings.TrimPrefix(token, AUTH_PREFIX) != a.Token {
// log.Error().Caller().Str("client_address", p.Addr.String()).Msg("invalid token")
@ -317,25 +360,25 @@ func (h *Headscale) grpcAuthenticationInterceptor(ctx context.Context,
// return handler(ctx, req)
}
func (h *Headscale) httpAuthenticationMiddleware(c *gin.Context) {
func (h *Headscale) httpAuthenticationMiddleware(ctx *gin.Context) {
log.Trace().
Caller().
Str("client_address", c.ClientIP()).
Str("client_address", ctx.ClientIP()).
Msg("HTTP authentication invoked")
authHeader := c.GetHeader("authorization")
authHeader := ctx.GetHeader("authorization")
if !strings.HasPrefix(authHeader, AUTH_PREFIX) {
if !strings.HasPrefix(authHeader, AuthPrefix) {
log.Error().
Caller().
Str("client_address", c.ClientIP()).
Str("client_address", ctx.ClientIP()).
Msg(`missing "Bearer " prefix in "Authorization" header`)
c.AbortWithStatus(http.StatusUnauthorized)
ctx.AbortWithStatus(http.StatusUnauthorized)
return
}
c.AbortWithStatus(http.StatusUnauthorized)
ctx.AbortWithStatus(http.StatusUnauthorized)
// TODO(kradalby): Implement API key backend
// Currently all traffic is unauthorized, this is intentional to allow
@ -359,6 +402,7 @@ func (h *Headscale) ensureUnixSocketIsAbsent() error {
if _, err := os.Stat(h.cfg.UnixSocket); errors.Is(err, os.ErrNotExist) {
return nil
}
return os.Remove(h.cfg.UnixSocket)
}
@ -401,14 +445,17 @@ func (h *Headscale) Serve() error {
// Create the cmux object that will multiplex 2 protocols on the same port.
// The two following listeners will be served on the same port below gracefully.
m := cmux.New(networkListener)
networkMutex := cmux.New(networkListener)
// Match gRPC requests here
grpcListener := m.MatchWithWriters(
grpcListener := networkMutex.MatchWithWriters(
cmux.HTTP2MatchHeaderFieldSendSettings("content-type", "application/grpc"),
cmux.HTTP2MatchHeaderFieldSendSettings("content-type", "application/grpc+proto"),
cmux.HTTP2MatchHeaderFieldSendSettings(
"content-type",
"application/grpc+proto",
),
)
// Otherwise match regular http requests.
httpListener := m.Match(cmux.Any())
httpListener := networkMutex.Match(cmux.Any())
grpcGatewayMux := runtime.NewServeMux()
@ -431,30 +478,33 @@ func (h *Headscale) Serve() error {
return err
}
r := gin.Default()
router := gin.Default()
p := ginprometheus.NewPrometheus("gin")
p.Use(r)
prometheus := ginprometheus.NewPrometheus("gin")
prometheus.Use(router)
r.GET("/health", func(c *gin.Context) { c.JSON(http.StatusOK, gin.H{"healthy": "ok"}) })
r.GET("/key", h.KeyHandler)
r.GET("/register", h.RegisterWebAPI)
r.POST("/machine/:id/map", h.PollNetMapHandler)
r.POST("/machine/:id", h.RegistrationHandler)
r.GET("/oidc/register/:mkey", h.RegisterOIDC)
r.GET("/oidc/callback", h.OIDCCallback)
r.GET("/apple", h.AppleMobileConfig)
r.GET("/apple/:platform", h.ApplePlatformConfig)
r.GET("/swagger", SwaggerUI)
r.GET("/swagger/v1/openapiv2.json", SwaggerAPIv1)
router.GET(
"/health",
func(c *gin.Context) { c.JSON(http.StatusOK, gin.H{"healthy": "ok"}) },
)
router.GET("/key", h.KeyHandler)
router.GET("/register", h.RegisterWebAPI)
router.POST("/machine/:id/map", h.PollNetMapHandler)
router.POST("/machine/:id", h.RegistrationHandler)
router.GET("/oidc/register/:mkey", h.RegisterOIDC)
router.GET("/oidc/callback", h.OIDCCallback)
router.GET("/apple", h.AppleMobileConfig)
router.GET("/apple/:platform", h.ApplePlatformConfig)
router.GET("/swagger", SwaggerUI)
router.GET("/swagger/v1/openapiv2.json", SwaggerAPIv1)
api := r.Group("/api")
api := router.Group("/api")
api.Use(h.httpAuthenticationMiddleware)
{
api.Any("/v1/*any", gin.WrapF(grpcGatewayMux.ServeHTTP))
}
r.NoRoute(stdoutHandler)
router.NoRoute(stdoutHandler)
// Fetch an initial DERP Map before we start serving
h.DERPMap = GetDERPMap(h.cfg.DERP)
@ -466,14 +516,13 @@ func (h *Headscale) Serve() error {
}
// I HATE THIS
updateMillisecondsWait := int64(5000)
go h.watchForKVUpdates(updateMillisecondsWait)
go h.expireEphemeralNodes(updateMillisecondsWait)
go h.watchForKVUpdates(updateInterval)
go h.expireEphemeralNodes(updateInterval)
httpServer := &http.Server{
Addr: h.cfg.Addr,
Handler: r,
ReadTimeout: 30 * time.Second,
Handler: router,
ReadTimeout: HTTPReadTimeout,
// Go does not handle timeouts in HTTP very well, and there is
// no good way to handle streaming timeouts, therefore we need to
// keep this at unlimited and be careful to clean up connections
@ -519,36 +568,40 @@ func (h *Headscale) Serve() error {
reflection.Register(grpcServer)
reflection.Register(grpcSocket)
g := new(errgroup.Group)
errorGroup := new(errgroup.Group)
g.Go(func() error { return grpcSocket.Serve(socketListener) })
errorGroup.Go(func() error { return grpcSocket.Serve(socketListener) })
// TODO(kradalby): Verify if we need the same TLS setup for gRPC as HTTP
g.Go(func() error { return grpcServer.Serve(grpcListener) })
errorGroup.Go(func() error { return grpcServer.Serve(grpcListener) })
if tlsConfig != nil {
g.Go(func() error {
errorGroup.Go(func() error {
tlsl := tls.NewListener(httpListener, tlsConfig)
return httpServer.Serve(tlsl)
})
} else {
g.Go(func() error { return httpServer.Serve(httpListener) })
errorGroup.Go(func() error { return httpServer.Serve(httpListener) })
}
g.Go(func() error { return m.Serve() })
errorGroup.Go(func() error { return networkMutex.Serve() })
log.Info().Msgf("listening and serving (multiplexed HTTP and gRPC) on: %s", h.cfg.Addr)
log.Info().
Msgf("listening and serving (multiplexed HTTP and gRPC) on: %s", h.cfg.Addr)
return g.Wait()
return errorGroup.Wait()
}
func (h *Headscale) getTLSSettings() (*tls.Config, error) {
var err error
if h.cfg.TLSLetsEncryptHostname != "" {
if !strings.HasPrefix(h.cfg.ServerURL, "https://") {
log.Warn().Msg("Listening with TLS but ServerURL does not start with https://")
log.Warn().
Msg("Listening with TLS but ServerURL does not start with https://")
}
m := autocert.Manager{
certManager := autocert.Manager{
Prompt: autocert.AcceptTOS,
HostPolicy: autocert.HostWhitelist(h.cfg.TLSLetsEncryptHostname),
Cache: autocert.DirCache(h.cfg.TLSLetsEncryptCacheDir),
@ -558,40 +611,44 @@ func (h *Headscale) getTLSSettings() (*tls.Config, error) {
Email: h.cfg.ACMEEmail,
}
if h.cfg.TLSLetsEncryptChallengeType == "TLS-ALPN-01" {
switch h.cfg.TLSLetsEncryptChallengeType {
case "TLS-ALPN-01":
// Configuration via autocert with TLS-ALPN-01 (https://tools.ietf.org/html/rfc8737)
// The RFC requires that the validation is done on port 443; in other words, headscale
// must be reachable on port 443.
return m.TLSConfig(), nil
} else if h.cfg.TLSLetsEncryptChallengeType == "HTTP-01" {
return certManager.TLSConfig(), nil
case "HTTP-01":
// Configuration via autocert with HTTP-01. This requires listening on
// port 80 for the certificate validation in addition to the headscale
// service, which can be configured to run on any other port.
go func() {
log.Fatal().
Err(http.ListenAndServe(h.cfg.TLSLetsEncryptListen, m.HTTPHandler(http.HandlerFunc(h.redirect)))).
Err(http.ListenAndServe(h.cfg.TLSLetsEncryptListen, certManager.HTTPHandler(http.HandlerFunc(h.redirect)))).
Msg("failed to set up a HTTP server")
}()
return m.TLSConfig(), nil
} else {
return nil, errors.New("unknown value for TLSLetsEncryptChallengeType")
return certManager.TLSConfig(), nil
default:
return nil, errUnsupportedLetsEncryptChallengeType
}
} else if h.cfg.TLSCertPath == "" {
if !strings.HasPrefix(h.cfg.ServerURL, "http://") {
log.Warn().Msg("Listening without TLS but ServerURL does not start with http://")
}
return nil, nil
return nil, err
} else {
if !strings.HasPrefix(h.cfg.ServerURL, "https://") {
log.Warn().Msg("Listening with TLS but ServerURL does not start with https://")
}
var err error
tlsConfig := &tls.Config{}
tlsConfig.ClientAuth = tls.RequireAnyClientCert
tlsConfig.NextProtos = []string{"http/1.1"}
tlsConfig.Certificates = make([]tls.Certificate, 1)
tlsConfig := &tls.Config{
ClientAuth: tls.RequireAnyClientCert,
NextProtos: []string{"http/1.1"},
Certificates: make([]tls.Certificate, 1),
MinVersion: tls.VersionTLS12,
}
tlsConfig.Certificates[0], err = tls.LoadX509KeyPair(h.cfg.TLSCertPath, h.cfg.TLSKeyPath)
return tlsConfig, err
@ -628,13 +685,13 @@ func (h *Headscale) getLastStateChange(namespaces ...string) time.Time {
}
}
func stdoutHandler(c *gin.Context) {
b, _ := io.ReadAll(c.Request.Body)
func stdoutHandler(ctx *gin.Context) {
body, _ := io.ReadAll(ctx.Request.Body)
log.Trace().
Interface("header", c.Request.Header).
Interface("proto", c.Request.Proto).
Interface("url", c.Request.URL).
Bytes("body", b).
Interface("header", ctx.Request.Header).
Interface("proto", ctx.Request.Proto).
Interface("url", ctx.Request.URL).
Bytes("body", body).
Msg("Request did not match")
}

View file

@ -17,8 +17,10 @@ var _ = check.Suite(&Suite{})
type Suite struct{}
var tmpDir string
var h Headscale
var (
tmpDir string
app Headscale
)
func (s *Suite) SetUpTest(c *check.C) {
s.ResetDB(c)
@ -41,18 +43,18 @@ func (s *Suite) ResetDB(c *check.C) {
IPPrefix: netaddr.MustParseIPPrefix("10.27.0.0/23"),
}
h = Headscale{
app = Headscale{
cfg: cfg,
dbType: "sqlite3",
dbString: tmpDir + "/headscale_test.db",
}
err = h.initDB()
err = app.initDB()
if err != nil {
c.Fatal(err)
}
db, err := h.openDB()
db, err := app.openDB()
if err != nil {
c.Fatal(err)
}
h.db = db
app.db = db
}

View file

@ -5,16 +5,15 @@ import (
"net/http"
"text/template"
"github.com/rs/zerolog/log"
"github.com/gin-gonic/gin"
"github.com/gofrs/uuid"
"github.com/rs/zerolog/log"
)
// AppleMobileConfig shows a simple message in the browser to point to the CLI
// Listens in /register
func (h *Headscale) AppleMobileConfig(c *gin.Context) {
t := template.Must(template.New("apple").Parse(`
// Listens in /register.
func (h *Headscale) AppleMobileConfig(ctx *gin.Context) {
appleTemplate := template.Must(template.New("apple").Parse(`
<html>
<body>
<h1>Apple configuration profiles</h1>
@ -56,7 +55,7 @@ func (h *Headscale) AppleMobileConfig(c *gin.Context) {
<p>Or</p>
<p>Use your terminal to configure the default setting for Tailscale by issuing:</p>
<code>defaults write io.tailscale.ipn.macos ControlURL {{.Url}}</code>
<code>defaults write io.tailscale.ipn.macos ControlURL {{.URL}}</code>
<p>Restart Tailscale.app and log in.</p>
@ -64,24 +63,29 @@ func (h *Headscale) AppleMobileConfig(c *gin.Context) {
</html>`))
config := map[string]interface{}{
"Url": h.cfg.ServerURL,
"URL": h.cfg.ServerURL,
}
var payload bytes.Buffer
if err := t.Execute(&payload, config); err != nil {
if err := appleTemplate.Execute(&payload, config); err != nil {
log.Error().
Str("handler", "AppleMobileConfig").
Err(err).
Msg("Could not render Apple index template")
c.Data(http.StatusInternalServerError, "text/html; charset=utf-8", []byte("Could not render Apple index template"))
ctx.Data(
http.StatusInternalServerError,
"text/html; charset=utf-8",
[]byte("Could not render Apple index template"),
)
return
}
c.Data(http.StatusOK, "text/html; charset=utf-8", payload.Bytes())
ctx.Data(http.StatusOK, "text/html; charset=utf-8", payload.Bytes())
}
func (h *Headscale) ApplePlatformConfig(c *gin.Context) {
platform := c.Param("platform")
func (h *Headscale) ApplePlatformConfig(ctx *gin.Context) {
platform := ctx.Param("platform")
id, err := uuid.NewV4()
if err != nil {
@ -89,23 +93,33 @@ func (h *Headscale) ApplePlatformConfig(c *gin.Context) {
Str("handler", "ApplePlatformConfig").
Err(err).
Msg("Failed not create UUID")
c.Data(http.StatusInternalServerError, "text/html; charset=utf-8", []byte("Failed to create UUID"))
ctx.Data(
http.StatusInternalServerError,
"text/html; charset=utf-8",
[]byte("Failed to create UUID"),
)
return
}
contentId, err := uuid.NewV4()
contentID, err := uuid.NewV4()
if err != nil {
log.Error().
Str("handler", "ApplePlatformConfig").
Err(err).
Msg("Failed not create UUID")
c.Data(http.StatusInternalServerError, "text/html; charset=utf-8", []byte("Failed to create UUID"))
ctx.Data(
http.StatusInternalServerError,
"text/html; charset=utf-8",
[]byte("Failed to create UUID"),
)
return
}
platformConfig := AppleMobilePlatformConfig{
UUID: contentId,
Url: h.cfg.ServerURL,
UUID: contentID,
URL: h.cfg.ServerURL,
}
var payload bytes.Buffer
@ -117,7 +131,12 @@ func (h *Headscale) ApplePlatformConfig(c *gin.Context) {
Str("handler", "ApplePlatformConfig").
Err(err).
Msg("Could not render Apple macOS template")
c.Data(http.StatusInternalServerError, "text/html; charset=utf-8", []byte("Could not render Apple macOS template"))
ctx.Data(
http.StatusInternalServerError,
"text/html; charset=utf-8",
[]byte("Could not render Apple macOS template"),
)
return
}
case "ios":
@ -126,17 +145,27 @@ func (h *Headscale) ApplePlatformConfig(c *gin.Context) {
Str("handler", "ApplePlatformConfig").
Err(err).
Msg("Could not render Apple iOS template")
c.Data(http.StatusInternalServerError, "text/html; charset=utf-8", []byte("Could not render Apple iOS template"))
ctx.Data(
http.StatusInternalServerError,
"text/html; charset=utf-8",
[]byte("Could not render Apple iOS template"),
)
return
}
default:
c.Data(http.StatusOK, "text/html; charset=utf-8", []byte("Invalid platform, only ios and macos is supported"))
ctx.Data(
http.StatusOK,
"text/html; charset=utf-8",
[]byte("Invalid platform, only ios and macos is supported"),
)
return
}
config := AppleMobileConfig{
UUID: id,
Url: h.cfg.ServerURL,
URL: h.cfg.ServerURL,
Payload: payload.String(),
}
@ -146,25 +175,35 @@ func (h *Headscale) ApplePlatformConfig(c *gin.Context) {
Str("handler", "ApplePlatformConfig").
Err(err).
Msg("Could not render Apple platform template")
c.Data(http.StatusInternalServerError, "text/html; charset=utf-8", []byte("Could not render Apple platform template"))
ctx.Data(
http.StatusInternalServerError,
"text/html; charset=utf-8",
[]byte("Could not render Apple platform template"),
)
return
}
c.Data(http.StatusOK, "application/x-apple-aspen-config; charset=utf-8", content.Bytes())
ctx.Data(
http.StatusOK,
"application/x-apple-aspen-config; charset=utf-8",
content.Bytes(),
)
}
type AppleMobileConfig struct {
UUID uuid.UUID
Url string
URL string
Payload string
}
type AppleMobilePlatformConfig struct {
UUID uuid.UUID
Url string
URL string
}
var commonTemplate = template.Must(template.New("mobileconfig").Parse(`<?xml version="1.0" encoding="UTF-8"?>
var commonTemplate = template.Must(
template.New("mobileconfig").Parse(`<?xml version="1.0" encoding="UTF-8"?>
<!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd">
<plist version="1.0">
<dict>
@ -173,7 +212,7 @@ var commonTemplate = template.Must(template.New("mobileconfig").Parse(`<?xml ver
<key>PayloadDisplayName</key>
<string>Headscale</string>
<key>PayloadDescription</key>
<string>Configure Tailscale login server to: {{.Url}}</string>
<string>Configure Tailscale login server to: {{.URL}}</string>
<key>PayloadIdentifier</key>
<string>com.github.juanfont.headscale</string>
<key>PayloadRemovalDisallowed</key>
@ -187,7 +226,8 @@ var commonTemplate = template.Must(template.New("mobileconfig").Parse(`<?xml ver
{{.Payload}}
</array>
</dict>
</plist>`))
</plist>`),
)
var iosTemplate = template.Must(template.New("iosTemplate").Parse(`
<dict>
@ -203,7 +243,7 @@ var iosTemplate = template.Must(template.New("iosTemplate").Parse(`
<true/>
<key>ControlURL</key>
<string>{{.Url}}</string>
<string>{{.URL}}</string>
</dict>
`))
@ -221,6 +261,6 @@ var macosTemplate = template.Must(template.New("macosTemplate").Parse(`
<true/>
<key>ControlURL</key>
<string>{{.Url}}</string>
<string>{{.URL}}</string>
</dict>
`))

View file

@ -7,31 +7,34 @@ import (
)
func (s *Suite) TestRegisterMachine(c *check.C) {
n, err := h.CreateNamespace("test")
namespace, err := app.CreateNamespace("test")
c.Assert(err, check.IsNil)
now := time.Now().UTC()
m := Machine{
machine := Machine{
ID: 0,
MachineKey: "8ce002a935f8c394e55e78fbbb410576575ff8ec5cfa2e627e4b807f1be15b0e",
NodeKey: "bar",
DiscoKey: "faa",
Name: "testmachine",
NamespaceID: n.ID,
NamespaceID: namespace.ID,
IPAddress: "10.0.0.1",
Expiry: &now,
RequestedExpiry: &now,
}
h.db.Save(&m)
app.db.Save(&machine)
_, err = h.GetMachine("test", "testmachine")
_, err = app.GetMachine("test", "testmachine")
c.Assert(err, check.IsNil)
m2, err := h.RegisterMachine("8ce002a935f8c394e55e78fbbb410576575ff8ec5cfa2e627e4b807f1be15b0e", n.Name)
machineAfterRegistering, err := app.RegisterMachine(
"8ce002a935f8c394e55e78fbbb410576575ff8ec5cfa2e627e4b807f1be15b0e",
namespace.Name,
)
c.Assert(err, check.IsNil)
c.Assert(m2.Registered, check.Equals, true)
c.Assert(machineAfterRegistering.Registered, check.Equals, true)
_, err = m2.GetHostInfo()
_, err = machineAfterRegistering.GetHostInfo()
c.Assert(err, check.IsNil)
}

View file

@ -27,7 +27,8 @@ func init() {
if err != nil {
log.Fatal().Err(err).Msg("")
}
createNodeCmd.Flags().StringSliceP("route", "r", []string{}, "List (or repeated flags) of routes to advertise")
createNodeCmd.Flags().
StringSliceP("route", "r", []string{}, "List (or repeated flags) of routes to advertise")
debugCmd.AddCommand(createNodeCmd)
}
@ -47,6 +48,7 @@ var createNodeCmd = &cobra.Command{
namespace, err := cmd.Flags().GetString("namespace")
if err != nil {
ErrorOutput(err, fmt.Sprintf("Error getting namespace: %s", err), output)
return
}
@ -56,19 +58,34 @@ var createNodeCmd = &cobra.Command{
name, err := cmd.Flags().GetString("name")
if err != nil {
ErrorOutput(err, fmt.Sprintf("Error getting node from flag: %s", err), output)
ErrorOutput(
err,
fmt.Sprintf("Error getting node from flag: %s", err),
output,
)
return
}
machineKey, err := cmd.Flags().GetString("key")
if err != nil {
ErrorOutput(err, fmt.Sprintf("Error getting key from flag: %s", err), output)
ErrorOutput(
err,
fmt.Sprintf("Error getting key from flag: %s", err),
output,
)
return
}
routes, err := cmd.Flags().GetStringSlice("route")
if err != nil {
ErrorOutput(err, fmt.Sprintf("Error getting routes from flag: %s", err), output)
ErrorOutput(
err,
fmt.Sprintf("Error getting routes from flag: %s", err),
output,
)
return
}
@ -81,7 +98,12 @@ var createNodeCmd = &cobra.Command{
response, err := client.DebugCreateMachine(ctx, request)
if err != nil {
ErrorOutput(err, fmt.Sprintf("Cannot create machine: %s", status.Convert(err).Message()), output)
ErrorOutput(
err,
fmt.Sprintf("Cannot create machine: %s", status.Convert(err).Message()),
output,
)
return
}

View file

@ -4,6 +4,7 @@ import (
"fmt"
survey "github.com/AlecAivazis/survey/v2"
"github.com/juanfont/headscale"
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
"github.com/pterm/pterm"
"github.com/rs/zerolog/log"
@ -19,6 +20,10 @@ func init() {
namespaceCmd.AddCommand(renameNamespaceCmd)
}
const (
errMissingParameter = headscale.Error("missing parameters")
)
var namespaceCmd = &cobra.Command{
Use: "namespaces",
Short: "Manage the namespaces of Headscale",
@ -29,8 +34,9 @@ var createNamespaceCmd = &cobra.Command{
Short: "Creates a new namespace",
Args: func(cmd *cobra.Command, args []string) error {
if len(args) < 1 {
return fmt.Errorf("Missing parameters")
return errMissingParameter
}
return nil
},
Run: func(cmd *cobra.Command, args []string) {
@ -49,7 +55,15 @@ var createNamespaceCmd = &cobra.Command{
log.Trace().Interface("request", request).Msg("Sending CreateNamespace request")
response, err := client.CreateNamespace(ctx, request)
if err != nil {
ErrorOutput(err, fmt.Sprintf("Cannot create namespace: %s", status.Convert(err).Message()), output)
ErrorOutput(
err,
fmt.Sprintf(
"Cannot create namespace: %s",
status.Convert(err).Message(),
),
output,
)
return
}
@ -62,8 +76,9 @@ var destroyNamespaceCmd = &cobra.Command{
Short: "Destroys a namespace",
Args: func(cmd *cobra.Command, args []string) error {
if len(args) < 1 {
return fmt.Errorf("Missing parameters")
return errMissingParameter
}
return nil
},
Run: func(cmd *cobra.Command, args []string) {
@ -81,7 +96,12 @@ var destroyNamespaceCmd = &cobra.Command{
_, err := client.GetNamespace(ctx, request)
if err != nil {
ErrorOutput(err, fmt.Sprintf("Error: %s", status.Convert(err).Message()), output)
ErrorOutput(
err,
fmt.Sprintf("Error: %s", status.Convert(err).Message()),
output,
)
return
}
@ -89,7 +109,10 @@ var destroyNamespaceCmd = &cobra.Command{
force, _ := cmd.Flags().GetBool("force")
if !force {
prompt := &survey.Confirm{
Message: fmt.Sprintf("Do you want to remove the namespace '%s' and any associated preauthkeys?", namespaceName),
Message: fmt.Sprintf(
"Do you want to remove the namespace '%s' and any associated preauthkeys?",
namespaceName,
),
}
err := survey.AskOne(prompt, &confirm)
if err != nil {
@ -102,7 +125,15 @@ var destroyNamespaceCmd = &cobra.Command{
response, err := client.DeleteNamespace(ctx, request)
if err != nil {
ErrorOutput(err, fmt.Sprintf("Cannot destroy namespace: %s", status.Convert(err).Message()), output)
ErrorOutput(
err,
fmt.Sprintf(
"Cannot destroy namespace: %s",
status.Convert(err).Message(),
),
output,
)
return
}
SuccessOutput(response, "Namespace destroyed", output)
@ -126,19 +157,25 @@ var listNamespacesCmd = &cobra.Command{
response, err := client.ListNamespaces(ctx, request)
if err != nil {
ErrorOutput(err, fmt.Sprintf("Cannot get namespaces: %s", status.Convert(err).Message()), output)
ErrorOutput(
err,
fmt.Sprintf("Cannot get namespaces: %s", status.Convert(err).Message()),
output,
)
return
}
if output != "" {
SuccessOutput(response.Namespaces, "", output)
return
}
d := pterm.TableData{{"ID", "Name", "Created"}}
tableData := pterm.TableData{{"ID", "Name", "Created"}}
for _, namespace := range response.GetNamespaces() {
d = append(
d,
tableData = append(
tableData,
[]string{
namespace.GetId(),
namespace.GetName(),
@ -146,9 +183,14 @@ var listNamespacesCmd = &cobra.Command{
},
)
}
err = pterm.DefaultTable.WithHasHeader().WithData(d).Render()
err = pterm.DefaultTable.WithHasHeader().WithData(tableData).Render()
if err != nil {
ErrorOutput(err, fmt.Sprintf("Failed to render pterm table: %s", err), output)
ErrorOutput(
err,
fmt.Sprintf("Failed to render pterm table: %s", err),
output,
)
return
}
},
@ -158,9 +200,11 @@ var renameNamespaceCmd = &cobra.Command{
Use: "rename OLD_NAME NEW_NAME",
Short: "Renames a namespace",
Args: func(cmd *cobra.Command, args []string) error {
if len(args) < 2 {
return fmt.Errorf("Missing parameters")
expectedArguments := 2
if len(args) < expectedArguments {
return errMissingParameter
}
return nil
},
Run: func(cmd *cobra.Command, args []string) {
@ -177,7 +221,15 @@ var renameNamespaceCmd = &cobra.Command{
response, err := client.RenameNamespace(ctx, request)
if err != nil {
ErrorOutput(err, fmt.Sprintf("Cannot rename namespace: %s", status.Convert(err).Message()), output)
ErrorOutput(
err,
fmt.Sprintf(
"Cannot rename namespace: %s",
status.Convert(err).Message(),
),
output,
)
return
}

View file

@ -7,6 +7,7 @@ import (
"time"
survey "github.com/AlecAivazis/survey/v2"
"github.com/juanfont/headscale"
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
"github.com/pterm/pterm"
"github.com/spf13/cobra"
@ -77,6 +78,7 @@ var registerNodeCmd = &cobra.Command{
namespace, err := cmd.Flags().GetString("namespace")
if err != nil {
ErrorOutput(err, fmt.Sprintf("Error getting namespace: %s", err), output)
return
}
@ -86,7 +88,12 @@ var registerNodeCmd = &cobra.Command{
machineKey, err := cmd.Flags().GetString("key")
if err != nil {
ErrorOutput(err, fmt.Sprintf("Error getting machine key from flag: %s", err), output)
ErrorOutput(
err,
fmt.Sprintf("Error getting machine key from flag: %s", err),
output,
)
return
}
@ -97,7 +104,15 @@ var registerNodeCmd = &cobra.Command{
response, err := client.RegisterMachine(ctx, request)
if err != nil {
ErrorOutput(err, fmt.Sprintf("Cannot register machine: %s\n", status.Convert(err).Message()), output)
ErrorOutput(
err,
fmt.Sprintf(
"Cannot register machine: %s\n",
status.Convert(err).Message(),
),
output,
)
return
}
@ -113,6 +128,7 @@ var listNodesCmd = &cobra.Command{
namespace, err := cmd.Flags().GetString("namespace")
if err != nil {
ErrorOutput(err, fmt.Sprintf("Error getting namespace: %s", err), output)
return
}
@ -126,24 +142,36 @@ var listNodesCmd = &cobra.Command{
response, err := client.ListMachines(ctx, request)
if err != nil {
ErrorOutput(err, fmt.Sprintf("Cannot get nodes: %s", status.Convert(err).Message()), output)
ErrorOutput(
err,
fmt.Sprintf("Cannot get nodes: %s", status.Convert(err).Message()),
output,
)
return
}
if output != "" {
SuccessOutput(response.Machines, "", output)
return
}
d, err := nodesToPtables(namespace, response.Machines)
tableData, err := nodesToPtables(namespace, response.Machines)
if err != nil {
ErrorOutput(err, fmt.Sprintf("Error converting to table: %s", err), output)
return
}
err = pterm.DefaultTable.WithHasHeader().WithData(d).Render()
err = pterm.DefaultTable.WithHasHeader().WithData(tableData).Render()
if err != nil {
ErrorOutput(err, fmt.Sprintf("Failed to render pterm table: %s", err), output)
ErrorOutput(
err,
fmt.Sprintf("Failed to render pterm table: %s", err),
output,
)
return
}
},
@ -155,9 +183,14 @@ var deleteNodeCmd = &cobra.Command{
Run: func(cmd *cobra.Command, args []string) {
output, _ := cmd.Flags().GetString("output")
id, err := cmd.Flags().GetInt("identifier")
identifier, err := cmd.Flags().GetInt("identifier")
if err != nil {
ErrorOutput(err, fmt.Sprintf("Error converting ID to integer: %s", err), output)
ErrorOutput(
err,
fmt.Sprintf("Error converting ID to integer: %s", err),
output,
)
return
}
@ -166,24 +199,35 @@ var deleteNodeCmd = &cobra.Command{
defer conn.Close()
getRequest := &v1.GetMachineRequest{
MachineId: uint64(id),
MachineId: uint64(identifier),
}
getResponse, err := client.GetMachine(ctx, getRequest)
if err != nil {
ErrorOutput(err, fmt.Sprintf("Error getting node node: %s", status.Convert(err).Message()), output)
ErrorOutput(
err,
fmt.Sprintf(
"Error getting node node: %s",
status.Convert(err).Message(),
),
output,
)
return
}
deleteRequest := &v1.DeleteMachineRequest{
MachineId: uint64(id),
MachineId: uint64(identifier),
}
confirm := false
force, _ := cmd.Flags().GetBool("force")
if !force {
prompt := &survey.Confirm{
Message: fmt.Sprintf("Do you want to remove the node %s?", getResponse.GetMachine().Name),
Message: fmt.Sprintf(
"Do you want to remove the node %s?",
getResponse.GetMachine().Name,
),
}
err = survey.AskOne(prompt, &confirm)
if err != nil {
@ -195,13 +239,26 @@ var deleteNodeCmd = &cobra.Command{
response, err := client.DeleteMachine(ctx, deleteRequest)
if output != "" {
SuccessOutput(response, "", output)
return
}
if err != nil {
ErrorOutput(err, fmt.Sprintf("Error deleting node: %s", status.Convert(err).Message()), output)
ErrorOutput(
err,
fmt.Sprintf(
"Error deleting node: %s",
status.Convert(err).Message(),
),
output,
)
return
}
SuccessOutput(map[string]string{"Result": "Node deleted"}, "Node deleted", output)
SuccessOutput(
map[string]string{"Result": "Node deleted"},
"Node deleted",
output,
)
} else {
SuccessOutput(map[string]string{"Result": "Node not deleted"}, "Node not deleted", output)
}
@ -210,12 +267,12 @@ var deleteNodeCmd = &cobra.Command{
func sharingWorker(
cmd *cobra.Command,
args []string,
) (string, *v1.Machine, *v1.Namespace, error) {
output, _ := cmd.Flags().GetString("output")
namespaceStr, err := cmd.Flags().GetString("namespace")
if err != nil {
ErrorOutput(err, fmt.Sprintf("Error getting namespace: %s", err), output)
return "", nil, nil, err
}
@ -223,19 +280,25 @@ func sharingWorker(
defer cancel()
defer conn.Close()
id, err := cmd.Flags().GetInt("identifier")
identifier, err := cmd.Flags().GetInt("identifier")
if err != nil {
ErrorOutput(err, fmt.Sprintf("Error converting ID to integer: %s", err), output)
return "", nil, nil, err
}
machineRequest := &v1.GetMachineRequest{
MachineId: uint64(id),
MachineId: uint64(identifier),
}
machineResponse, err := client.GetMachine(ctx, machineRequest)
if err != nil {
ErrorOutput(err, fmt.Sprintf("Error getting node node: %s", status.Convert(err).Message()), output)
ErrorOutput(
err,
fmt.Sprintf("Error getting node node: %s", status.Convert(err).Message()),
output,
)
return "", nil, nil, err
}
@ -245,7 +308,12 @@ func sharingWorker(
namespaceResponse, err := client.GetNamespace(ctx, namespaceRequest)
if err != nil {
ErrorOutput(err, fmt.Sprintf("Error getting node node: %s", status.Convert(err).Message()), output)
ErrorOutput(
err,
fmt.Sprintf("Error getting node node: %s", status.Convert(err).Message()),
output,
)
return "", nil, nil, err
}
@ -256,9 +324,14 @@ var shareMachineCmd = &cobra.Command{
Use: "share",
Short: "Shares a node from the current namespace to the specified one",
Run: func(cmd *cobra.Command, args []string) {
output, machine, namespace, err := sharingWorker(cmd, args)
output, machine, namespace, err := sharingWorker(cmd)
if err != nil {
ErrorOutput(err, fmt.Sprintf("Failed to fetch namespace or machine: %s", err), output)
ErrorOutput(
err,
fmt.Sprintf("Failed to fetch namespace or machine: %s", err),
output,
)
return
}
@ -273,7 +346,12 @@ var shareMachineCmd = &cobra.Command{
response, err := client.ShareMachine(ctx, request)
if err != nil {
ErrorOutput(err, fmt.Sprintf("Error sharing node: %s", status.Convert(err).Message()), output)
ErrorOutput(
err,
fmt.Sprintf("Error sharing node: %s", status.Convert(err).Message()),
output,
)
return
}
@ -285,9 +363,14 @@ var unshareMachineCmd = &cobra.Command{
Use: "unshare",
Short: "Unshares a node from the specified namespace",
Run: func(cmd *cobra.Command, args []string) {
output, machine, namespace, err := sharingWorker(cmd, args)
output, machine, namespace, err := sharingWorker(cmd)
if err != nil {
ErrorOutput(err, fmt.Sprintf("Failed to fetch namespace or machine: %s", err), output)
ErrorOutput(
err,
fmt.Sprintf("Failed to fetch namespace or machine: %s", err),
output,
)
return
}
@ -302,7 +385,12 @@ var unshareMachineCmd = &cobra.Command{
response, err := client.UnshareMachine(ctx, request)
if err != nil {
ErrorOutput(err, fmt.Sprintf("Error unsharing node: %s", status.Convert(err).Message()), output)
ErrorOutput(
err,
fmt.Sprintf("Error unsharing node: %s", status.Convert(err).Message()),
output,
)
return
}
@ -310,8 +398,22 @@ var unshareMachineCmd = &cobra.Command{
},
}
func nodesToPtables(currentNamespace string, machines []*v1.Machine) (pterm.TableData, error) {
d := pterm.TableData{{"ID", "Name", "NodeKey", "Namespace", "IP address", "Ephemeral", "Last seen", "Online"}}
func nodesToPtables(
currentNamespace string,
machines []*v1.Machine,
) (pterm.TableData, error) {
tableData := pterm.TableData{
{
"ID",
"Name",
"NodeKey",
"Namespace",
"IP address",
"Ephemeral",
"Last seen",
"Online",
},
}
for _, machine := range machines {
var ephemeral bool
@ -331,7 +433,9 @@ func nodesToPtables(currentNamespace string, machines []*v1.Machine) (pterm.Tabl
nodeKey := tailcfg.NodeKey(nKey)
var online string
if lastSeen.After(time.Now().Add(-5 * time.Minute)) { // TODO: Find a better way to reliably show if online
if lastSeen.After(
time.Now().Add(-5 * time.Minute),
) { // TODO: Find a better way to reliably show if online
online = pterm.LightGreen("true")
} else {
online = pterm.LightRed("false")
@ -344,10 +448,10 @@ func nodesToPtables(currentNamespace string, machines []*v1.Machine) (pterm.Tabl
// Shared into this namespace
namespace = pterm.LightYellow(machine.Namespace.Name)
}
d = append(
d,
tableData = append(
tableData,
[]string{
strconv.FormatUint(machine.Id, 10),
strconv.FormatUint(machine.Id, headscale.Base10),
machine.Name,
nodeKey.ShortString(),
namespace,
@ -358,5 +462,6 @@ func nodesToPtables(currentNamespace string, machines []*v1.Machine) (pterm.Tabl
},
)
}
return d, nil
return tableData, nil
}

View file

@ -12,6 +12,10 @@ import (
"google.golang.org/protobuf/types/known/timestamppb"
)
const (
DefaultPreAuthKeyExpiry = 24 * time.Hour
)
func init() {
rootCmd.AddCommand(preauthkeysCmd)
preauthkeysCmd.PersistentFlags().StringP("namespace", "n", "", "Namespace")
@ -22,10 +26,12 @@ func init() {
preauthkeysCmd.AddCommand(listPreAuthKeys)
preauthkeysCmd.AddCommand(createPreAuthKeyCmd)
preauthkeysCmd.AddCommand(expirePreAuthKeyCmd)
createPreAuthKeyCmd.PersistentFlags().Bool("reusable", false, "Make the preauthkey reusable")
createPreAuthKeyCmd.PersistentFlags().Bool("ephemeral", false, "Preauthkey for ephemeral nodes")
createPreAuthKeyCmd.PersistentFlags().
Bool("reusable", false, "Make the preauthkey reusable")
createPreAuthKeyCmd.PersistentFlags().
Bool("ephemeral", false, "Preauthkey for ephemeral nodes")
createPreAuthKeyCmd.Flags().
DurationP("expiration", "e", 24*time.Hour, "Human-readable expiration of the key (30m, 24h, 365d...)")
DurationP("expiration", "e", DefaultPreAuthKeyExpiry, "Human-readable expiration of the key (30m, 24h, 365d...)")
}
var preauthkeysCmd = &cobra.Command{
@ -39,9 +45,10 @@ var listPreAuthKeys = &cobra.Command{
Run: func(cmd *cobra.Command, args []string) {
output, _ := cmd.Flags().GetString("output")
n, err := cmd.Flags().GetString("namespace")
namespace, err := cmd.Flags().GetString("namespace")
if err != nil {
ErrorOutput(err, fmt.Sprintf("Error getting namespace: %s", err), output)
return
}
@ -50,48 +57,61 @@ var listPreAuthKeys = &cobra.Command{
defer conn.Close()
request := &v1.ListPreAuthKeysRequest{
Namespace: n,
Namespace: namespace,
}
response, err := client.ListPreAuthKeys(ctx, request)
if err != nil {
ErrorOutput(err, fmt.Sprintf("Error getting the list of keys: %s", err), output)
ErrorOutput(
err,
fmt.Sprintf("Error getting the list of keys: %s", err),
output,
)
return
}
if output != "" {
SuccessOutput(response.PreAuthKeys, "", output)
return
}
d := pterm.TableData{{"ID", "Key", "Reusable", "Ephemeral", "Used", "Expiration", "Created"}}
for _, k := range response.PreAuthKeys {
tableData := pterm.TableData{
{"ID", "Key", "Reusable", "Ephemeral", "Used", "Expiration", "Created"},
}
for _, key := range response.PreAuthKeys {
expiration := "-"
if k.GetExpiration() != nil {
expiration = k.Expiration.AsTime().Format("2006-01-02 15:04:05")
if key.GetExpiration() != nil {
expiration = key.Expiration.AsTime().Format("2006-01-02 15:04:05")
}
var reusable string
if k.GetEphemeral() {
if key.GetEphemeral() {
reusable = "N/A"
} else {
reusable = fmt.Sprintf("%v", k.GetReusable())
reusable = fmt.Sprintf("%v", key.GetReusable())
}
d = append(d, []string{
k.GetId(),
k.GetKey(),
tableData = append(tableData, []string{
key.GetId(),
key.GetKey(),
reusable,
strconv.FormatBool(k.GetEphemeral()),
strconv.FormatBool(k.GetUsed()),
strconv.FormatBool(key.GetEphemeral()),
strconv.FormatBool(key.GetUsed()),
expiration,
k.GetCreatedAt().AsTime().Format("2006-01-02 15:04:05"),
key.GetCreatedAt().AsTime().Format("2006-01-02 15:04:05"),
})
}
err = pterm.DefaultTable.WithHasHeader().WithData(d).Render()
err = pterm.DefaultTable.WithHasHeader().WithData(tableData).Render()
if err != nil {
ErrorOutput(err, fmt.Sprintf("Failed to render pterm table: %s", err), output)
ErrorOutput(
err,
fmt.Sprintf("Failed to render pterm table: %s", err),
output,
)
return
}
},
@ -106,6 +126,7 @@ var createPreAuthKeyCmd = &cobra.Command{
namespace, err := cmd.Flags().GetString("namespace")
if err != nil {
ErrorOutput(err, fmt.Sprintf("Error getting namespace: %s", err), output)
return
}
@ -139,7 +160,12 @@ var createPreAuthKeyCmd = &cobra.Command{
response, err := client.CreatePreAuthKey(ctx, request)
if err != nil {
ErrorOutput(err, fmt.Sprintf("Cannot create Pre Auth Key: %s\n", err), output)
ErrorOutput(
err,
fmt.Sprintf("Cannot create Pre Auth Key: %s\n", err),
output,
)
return
}
@ -152,8 +178,9 @@ var expirePreAuthKeyCmd = &cobra.Command{
Short: "Expire a preauthkey",
Args: func(cmd *cobra.Command, args []string) error {
if len(args) < 1 {
return fmt.Errorf("missing parameters")
return errMissingParameter
}
return nil
},
Run: func(cmd *cobra.Command, args []string) {
@ -161,6 +188,7 @@ var expirePreAuthKeyCmd = &cobra.Command{
namespace, err := cmd.Flags().GetString("namespace")
if err != nil {
ErrorOutput(err, fmt.Sprintf("Error getting namespace: %s", err), output)
return
}
@ -175,7 +203,12 @@ var expirePreAuthKeyCmd = &cobra.Command{
response, err := client.ExpirePreAuthKey(ctx, request)
if err != nil {
ErrorOutput(err, fmt.Sprintf("Cannot expire Pre Auth Key: %s\n", err), output)
ErrorOutput(
err,
fmt.Sprintf("Cannot expire Pre Auth Key: %s\n", err),
output,
)
return
}

View file

@ -10,7 +10,8 @@ import (
func init() {
rootCmd.PersistentFlags().
StringP("output", "o", "", "Output format. Empty for human-readable, 'json', 'json-line' or 'yaml'")
rootCmd.PersistentFlags().Bool("force", false, "Disable prompts and forces the execution")
rootCmd.PersistentFlags().
Bool("force", false, "Disable prompts and forces the execution")
}
var rootCmd = &cobra.Command{

View file

@ -21,7 +21,8 @@ func init() {
}
routesCmd.AddCommand(listRoutesCmd)
enableRouteCmd.Flags().StringSliceP("route", "r", []string{}, "List (or repeated flags) of routes to enable")
enableRouteCmd.Flags().
StringSliceP("route", "r", []string{}, "List (or repeated flags) of routes to enable")
enableRouteCmd.Flags().Uint64P("identifier", "i", 0, "Node identifier (ID)")
err = enableRouteCmd.MarkFlagRequired("identifier")
if err != nil {
@ -44,9 +45,14 @@ var listRoutesCmd = &cobra.Command{
Run: func(cmd *cobra.Command, args []string) {
output, _ := cmd.Flags().GetString("output")
machineId, err := cmd.Flags().GetUint64("identifier")
machineID, err := cmd.Flags().GetUint64("identifier")
if err != nil {
ErrorOutput(err, fmt.Sprintf("Error getting machine id from flag: %s", err), output)
ErrorOutput(
err,
fmt.Sprintf("Error getting machine id from flag: %s", err),
output,
)
return
}
@ -55,29 +61,41 @@ var listRoutesCmd = &cobra.Command{
defer conn.Close()
request := &v1.GetMachineRouteRequest{
MachineId: machineId,
MachineId: machineID,
}
response, err := client.GetMachineRoute(ctx, request)
if err != nil {
ErrorOutput(err, fmt.Sprintf("Cannot get nodes: %s", status.Convert(err).Message()), output)
ErrorOutput(
err,
fmt.Sprintf("Cannot get nodes: %s", status.Convert(err).Message()),
output,
)
return
}
if output != "" {
SuccessOutput(response.Routes, "", output)
return
}
d := routesToPtables(response.Routes)
tableData := routesToPtables(response.Routes)
if err != nil {
ErrorOutput(err, fmt.Sprintf("Error converting to table: %s", err), output)
return
}
err = pterm.DefaultTable.WithHasHeader().WithData(d).Render()
err = pterm.DefaultTable.WithHasHeader().WithData(tableData).Render()
if err != nil {
ErrorOutput(err, fmt.Sprintf("Failed to render pterm table: %s", err), output)
ErrorOutput(
err,
fmt.Sprintf("Failed to render pterm table: %s", err),
output,
)
return
}
},
@ -93,15 +111,26 @@ omit the route you do not want to enable.
`,
Run: func(cmd *cobra.Command, args []string) {
output, _ := cmd.Flags().GetString("output")
machineId, err := cmd.Flags().GetUint64("identifier")
machineID, err := cmd.Flags().GetUint64("identifier")
if err != nil {
ErrorOutput(err, fmt.Sprintf("Error getting machine id from flag: %s", err), output)
ErrorOutput(
err,
fmt.Sprintf("Error getting machine id from flag: %s", err),
output,
)
return
}
routes, err := cmd.Flags().GetStringSlice("route")
if err != nil {
ErrorOutput(err, fmt.Sprintf("Error getting routes from flag: %s", err), output)
ErrorOutput(
err,
fmt.Sprintf("Error getting routes from flag: %s", err),
output,
)
return
}
@ -110,45 +139,61 @@ omit the route you do not want to enable.
defer conn.Close()
request := &v1.EnableMachineRoutesRequest{
MachineId: machineId,
MachineId: machineID,
Routes: routes,
}
response, err := client.EnableMachineRoutes(ctx, request)
if err != nil {
ErrorOutput(err, fmt.Sprintf("Cannot register machine: %s\n", status.Convert(err).Message()), output)
ErrorOutput(
err,
fmt.Sprintf(
"Cannot register machine: %s\n",
status.Convert(err).Message(),
),
output,
)
return
}
if output != "" {
SuccessOutput(response.Routes, "", output)
return
}
d := routesToPtables(response.Routes)
tableData := routesToPtables(response.Routes)
if err != nil {
ErrorOutput(err, fmt.Sprintf("Error converting to table: %s", err), output)
return
}
err = pterm.DefaultTable.WithHasHeader().WithData(d).Render()
err = pterm.DefaultTable.WithHasHeader().WithData(tableData).Render()
if err != nil {
ErrorOutput(err, fmt.Sprintf("Failed to render pterm table: %s", err), output)
ErrorOutput(
err,
fmt.Sprintf("Failed to render pterm table: %s", err),
output,
)
return
}
},
}
// routesToPtables converts the list of routes to a nice table
// routesToPtables converts the list of routes to a nice table.
func routesToPtables(routes *v1.Routes) pterm.TableData {
d := pterm.TableData{{"Route", "Enabled"}}
tableData := pterm.TableData{{"Route", "Enabled"}}
for _, route := range routes.GetAdvertisedRoutes() {
enabled := isStringInSlice(routes.EnabledRoutes, route)
d = append(d, []string{route, strconv.FormatBool(enabled)})
tableData = append(tableData, []string{route, strconv.FormatBool(enabled)})
}
return d
return tableData
}
func isStringInSlice(strs []string, s string) bool {

View file

@ -52,9 +52,8 @@ func LoadConfig(path string) error {
viper.SetDefault("cli.insecure", false)
viper.SetDefault("cli.timeout", "5s")
err := viper.ReadInConfig()
if err != nil {
return fmt.Errorf("Fatal error reading config file: %s \n", err)
if err := viper.ReadInConfig(); err != nil {
return fmt.Errorf("fatal error reading config file: %w", err)
}
// Collect any validation errors and return them all at once
@ -82,6 +81,7 @@ func LoadConfig(path string) error {
errorText += "Fatal config error: server_url must start with https:// or http://\n"
}
if errorText != "" {
//nolint
return errors.New(strings.TrimSuffix(errorText, "\n"))
} else {
return nil
@ -149,9 +149,14 @@ func GetDNSConfig() (*tailcfg.DNSConfig, string) {
if viper.IsSet("dns_config.restricted_nameservers") {
if len(dnsConfig.Nameservers) > 0 {
dnsConfig.Routes = make(map[string][]dnstype.Resolver)
restrictedDNS := viper.GetStringMapStringSlice("dns_config.restricted_nameservers")
restrictedDNS := viper.GetStringMapStringSlice(
"dns_config.restricted_nameservers",
)
for domain, restrictedNameservers := range restrictedDNS {
restrictedResolvers := make([]dnstype.Resolver, len(restrictedNameservers))
restrictedResolvers := make(
[]dnstype.Resolver,
len(restrictedNameservers),
)
for index, nameserverStr := range restrictedNameservers {
nameserver, err := netaddr.ParseIP(nameserverStr)
if err != nil {
@ -208,6 +213,7 @@ func absPath(path string) string {
path = filepath.Join(dir, path)
}
}
return path
}
@ -219,7 +225,9 @@ func getHeadscaleConfig() headscale.Config {
"10h",
) // use 10h here because it is the length of a standard business day plus a small amount of leeway
if viper.GetDuration("max_machine_registration_duration") >= time.Second {
maxMachineRegistrationDuration = viper.GetDuration("max_machine_registration_duration")
maxMachineRegistrationDuration = viper.GetDuration(
"max_machine_registration_duration",
)
}
// defaultMachineRegistrationDuration is the default time assigned to a machine registration if one is not
@ -229,7 +237,9 @@ func getHeadscaleConfig() headscale.Config {
"8h",
) // use 8h here because it's the length of a standard business day
if viper.GetDuration("default_machine_registration_duration") >= time.Second {
defaultMachineRegistrationDuration = viper.GetDuration("default_machine_registration_duration")
defaultMachineRegistrationDuration = viper.GetDuration(
"default_machine_registration_duration",
)
}
dnsConfig, baseDomain := GetDNSConfig()
@ -244,7 +254,9 @@ func getHeadscaleConfig() headscale.Config {
DERP: derpConfig,
EphemeralNodeInactivityTimeout: viper.GetDuration("ephemeral_node_inactivity_timeout"),
EphemeralNodeInactivityTimeout: viper.GetDuration(
"ephemeral_node_inactivity_timeout",
),
DBtype: viper.GetString("db_type"),
DBpath: absPath(viper.GetString("db_path")),
@ -256,7 +268,9 @@ func getHeadscaleConfig() headscale.Config {
TLSLetsEncryptHostname: viper.GetString("tls_letsencrypt_hostname"),
TLSLetsEncryptListen: viper.GetString("tls_letsencrypt_listen"),
TLSLetsEncryptCacheDir: absPath(viper.GetString("tls_letsencrypt_cache_dir")),
TLSLetsEncryptCacheDir: absPath(
viper.GetString("tls_letsencrypt_cache_dir"),
),
TLSLetsEncryptChallengeType: viper.GetString("tls_letsencrypt_challenge_type"),
TLSCertPath: absPath(viper.GetString("tls_cert_path")),
@ -292,11 +306,14 @@ func getHeadscaleApp() (*headscale.Headscale, error) {
// to avoid races
minInactivityTimeout, _ := time.ParseDuration("65s")
if viper.GetDuration("ephemeral_node_inactivity_timeout") <= minInactivityTimeout {
// TODO: Find a better way to return this text
//nolint
err := fmt.Errorf(
"ephemeral_node_inactivity_timeout (%s) is set too low, must be more than %s\n",
"ephemeral_node_inactivity_timeout (%s) is set too low, must be more than %s",
viper.GetString("ephemeral_node_inactivity_timeout"),
minInactivityTimeout,
)
return nil, err
}
@ -304,7 +321,7 @@ func getHeadscaleApp() (*headscale.Headscale, error) {
cfg.OIDC.MatchMap = loadOIDCMatchMap()
h, err := headscale.NewHeadscale(cfg)
app, err := headscale.NewHeadscale(cfg)
if err != nil {
return nil, err
}
@ -313,7 +330,7 @@ func getHeadscaleApp() (*headscale.Headscale, error) {
if viper.GetString("acl_policy_path") != "" {
aclPath := absPath(viper.GetString("acl_policy_path"))
err = h.LoadACLPolicy(aclPath)
err = app.LoadACLPolicy(aclPath)
if err != nil {
log.Error().
Str("path", aclPath).
@ -322,7 +339,7 @@ func getHeadscaleApp() (*headscale.Headscale, error) {
}
}
return h, nil
return app, nil
}
func getHeadscaleCLIClient() (context.Context, v1.HeadscaleServiceClient, *grpc.ClientConn, context.CancelFunc) {
@ -342,7 +359,6 @@ func getHeadscaleCLIClient() (context.Context, v1.HeadscaleServiceClient, *grpc.
// If the address is not set, we assume that we are on the server hosting headscale.
if address == "" {
log.Debug().
Str("socket", cfg.UnixSocket).
Msgf("HEADSCALE_CLI_ADDRESS environment is not set, connecting to unix socket.")
@ -402,10 +418,13 @@ func SuccessOutput(result interface{}, override string, outputFormat string) {
log.Fatal().Err(err)
}
default:
//nolint
fmt.Println(override)
return
}
//nolint
fmt.Println(string(j))
}
@ -423,6 +442,7 @@ func HasMachineOutputFlag() bool {
return true
}
}
return false
}
@ -431,7 +451,10 @@ type tokenAuth struct {
}
// Return value is mapped to request headers.
func (t tokenAuth) GetRequestMetadata(ctx context.Context, in ...string) (map[string]string, error) {
func (t tokenAuth) GetRequestMetadata(
ctx context.Context,
in ...string,
) (map[string]string, error) {
return map[string]string{
"authorization": "Bearer " + t.token,
}, nil

View file

@ -23,6 +23,8 @@ func main() {
colors = true
case termcolor.LevelBasic:
colors = true
case termcolor.LevelNone:
colors = false
default:
// no color, return text as is.
colors = false
@ -41,8 +43,7 @@ func main() {
NoColor: !colors,
})
err := cli.LoadConfig("")
if err != nil {
if err := cli.LoadConfig(""); err != nil {
log.Fatal().Err(err)
}
@ -63,13 +64,15 @@ func main() {
}
if !viper.GetBool("disable_check_updates") && !machineOutput {
if (runtime.GOOS == "linux" || runtime.GOOS == "darwin") && cli.Version != "dev" {
if (runtime.GOOS == "linux" || runtime.GOOS == "darwin") &&
cli.Version != "dev" {
githubTag := &latest.GithubTag{
Owner: "juanfont",
Repository: "headscale",
}
res, err := latest.Check(githubTag, cli.Version)
if err == nil && res.Outdated {
//nolint
fmt.Printf(
"An updated version of Headscale has been found (%s vs. your current %s). Check it out https://github.com/juanfont/headscale/releases\n",
res.Current,

View file

@ -1,7 +1,6 @@
package main
import (
"fmt"
"io/ioutil"
"os"
"path/filepath"
@ -40,7 +39,10 @@ func (*Suite) TestConfigLoading(c *check.C) {
}
// Symlink the example config file
err = os.Symlink(filepath.Clean(path+"/../../config-example.yaml"), filepath.Join(tmpDir, "config.yaml"))
err = os.Symlink(
filepath.Clean(path+"/../../config-example.yaml"),
filepath.Join(tmpDir, "config.yaml"),
)
if err != nil {
c.Fatal(err)
}
@ -74,7 +76,10 @@ func (*Suite) TestDNSConfigLoading(c *check.C) {
}
// Symlink the example config file
err = os.Symlink(filepath.Clean(path+"/../../config-example.yaml"), filepath.Join(tmpDir, "config.yaml"))
err = os.Symlink(
filepath.Clean(path+"/../../config-example.yaml"),
filepath.Join(tmpDir, "config.yaml"),
)
if err != nil {
c.Fatal(err)
}
@ -94,7 +99,7 @@ func (*Suite) TestDNSConfigLoading(c *check.C) {
func writeConfig(c *check.C, tmpDir string, configYaml []byte) {
// Populate a custom config file
configFile := filepath.Join(tmpDir, "config.yaml")
err := ioutil.WriteFile(configFile, configYaml, 0o644)
err := ioutil.WriteFile(configFile, configYaml, 0o600)
if err != nil {
c.Fatalf("Couldn't write file %s", configFile)
}
@ -106,7 +111,6 @@ func (*Suite) TestTLSConfigValidation(c *check.C) {
c.Fatal(err)
}
// defer os.RemoveAll(tmpDir)
fmt.Println(tmpDir)
configYaml := []byte(
"---\ntls_letsencrypt_hostname: \"example.com\"\ntls_letsencrypt_challenge_type: \"\"\ntls_cert_path: \"abc.pem\"",
@ -128,8 +132,11 @@ func (*Suite) TestTLSConfigValidation(c *check.C) {
check.Matches,
".*Fatal config error: the only supported values for tls_letsencrypt_challenge_type are.*",
)
c.Assert(tmp, check.Matches, ".*Fatal config error: server_url must start with https:// or http://.*")
fmt.Println(tmp)
c.Assert(
tmp,
check.Matches,
".*Fatal config error: server_url must start with https:// or http://.*",
)
// Check configuration validation errors (2)
configYaml = []byte(

35
db.go
View file

@ -9,7 +9,10 @@ import (
"gorm.io/gorm/logger"
)
const dbVersion = "1"
const (
dbVersion = "1"
errValueNotFound = Error("not found")
)
// KV is a key-value store in a psql table. For future use...
type KV struct {
@ -24,7 +27,7 @@ func (h *Headscale) initDB() error {
}
h.db = db
if h.dbType == "postgres" {
if h.dbType == Postgres {
db.Exec("create extension if not exists \"uuid-ossp\";")
}
err = db.AutoMigrate(&Machine{})
@ -50,6 +53,7 @@ func (h *Headscale) initDB() error {
}
err = h.setValue("db_version", dbVersion)
return err
}
@ -65,12 +69,12 @@ func (h *Headscale) openDB() (*gorm.DB, error) {
}
switch h.dbType {
case "sqlite3":
case Sqlite:
db, err = gorm.Open(sqlite.Open(h.dbString), &gorm.Config{
DisableForeignKeyConstraintWhenMigrating: true,
Logger: log,
})
case "postgres":
case Postgres:
db, err = gorm.Open(postgres.Open(h.dbString), &gorm.Config{
DisableForeignKeyConstraintWhenMigrating: true,
Logger: log,
@ -84,28 +88,33 @@ func (h *Headscale) openDB() (*gorm.DB, error) {
return db, nil
}
// getValue returns the value for the given key in KV
// getValue returns the value for the given key in KV.
func (h *Headscale) getValue(key string) (string, error) {
var row KV
if result := h.db.First(&row, "key = ?", key); errors.Is(result.Error, gorm.ErrRecordNotFound) {
return "", errors.New("not found")
if result := h.db.First(&row, "key = ?", key); errors.Is(
result.Error,
gorm.ErrRecordNotFound,
) {
return "", errValueNotFound
}
return row.Value, nil
}
// setValue sets value for the given key in KV
// setValue sets value for the given key in KV.
func (h *Headscale) setValue(key string, value string) error {
kv := KV{
keyValue := KV{
Key: key,
Value: value,
}
_, err := h.getValue(key)
if err == nil {
h.db.Model(&kv).Where("key = ?", key).Update("value", value)
if _, err := h.getValue(key); err == nil {
h.db.Model(&keyValue).Where("key = ?", key).Update("value", value)
return nil
}
h.db.Create(kv)
h.db.Create(keyValue)
return nil
}

24
derp.go
View file

@ -1,6 +1,7 @@
package headscale
import (
"context"
"encoding/json"
"io"
"io/ioutil"
@ -10,9 +11,7 @@ import (
"time"
"github.com/rs/zerolog/log"
"gopkg.in/yaml.v2"
"tailscale.com/tailcfg"
)
@ -28,14 +27,24 @@ func loadDERPMapFromPath(path string) (*tailcfg.DERPMap, error) {
return nil, err
}
err = yaml.Unmarshal(b, &derpMap)
return &derpMap, err
}
func loadDERPMapFromURL(addr url.URL) (*tailcfg.DERPMap, error) {
client := http.Client{
Timeout: 10 * time.Second,
ctx, cancel := context.WithTimeout(context.Background(), HTTPReadTimeout)
defer cancel()
req, err := http.NewRequestWithContext(ctx, "GET", addr.String(), nil)
if err != nil {
return nil, err
}
resp, err := client.Get(addr.String())
client := http.Client{
Timeout: HTTPReadTimeout,
}
resp, err := client.Do(req)
if err != nil {
return nil, err
}
@ -48,6 +57,7 @@ func loadDERPMapFromURL(addr url.URL) (*tailcfg.DERPMap, error) {
var derpMap tailcfg.DERPMap
err = json.Unmarshal(body, &derpMap)
return &derpMap, err
}
@ -55,7 +65,7 @@ func loadDERPMapFromURL(addr url.URL) (*tailcfg.DERPMap, error) {
// DERPMap, it will _only_ look at the Regions, an integer.
// If a region exists in two of the given DERPMaps, the region
// form the _last_ DERPMap will be preserved.
// An empty DERPMap list will result in a DERPMap with no regions
// An empty DERPMap list will result in a DERPMap with no regions.
func mergeDERPMaps(derpMaps []*tailcfg.DERPMap) *tailcfg.DERPMap {
result := tailcfg.DERPMap{
OmitDefaultRegions: false,
@ -86,6 +96,7 @@ func GetDERPMap(cfg DERPConfig) *tailcfg.DERPMap {
Str("path", path).
Err(err).
Msg("Could not load DERP map from path")
break
}
@ -104,6 +115,7 @@ func GetDERPMap(cfg DERPConfig) *tailcfg.DERPMap {
Str("url", addr.String()).
Err(err).
Msg("Could not load DERP map from path")
break
}

34
dns.go
View file

@ -10,6 +10,10 @@ import (
"tailscale.com/util/dnsname"
)
const (
ByteSize = 8
)
// generateMagicDNSRootDomains generates a list of DNS entries to be included in `Routes` in `MapResponse`.
// This list of reverse DNS entries instructs the OS on what subnets and domains the Tailscale embedded DNS
// server (listening in 100.100.100.100 udp/53) should be used for.
@ -30,7 +34,9 @@ import (
// From the netmask we can find out the wildcard bits (the bits that are not set in the netmask).
// This allows us to then calculate the subnets included in the subsequent class block and generate the entries.
func generateMagicDNSRootDomains(ipPrefix netaddr.IPPrefix, baseDomain string) ([]dnsname.FQDN, error) {
func generateMagicDNSRootDomains(
ipPrefix netaddr.IPPrefix,
) []dnsname.FQDN {
// TODO(juanfont): we are not handing out IPv6 addresses yet
// and in fact this is Tailscale.com's range (note the fd7a:115c:a1e0: range in the fc00::/7 network)
ipv6base := dnsname.FQDN("0.e.1.a.c.5.1.1.a.7.d.f.ip6.arpa.")
@ -41,15 +47,15 @@ func generateMagicDNSRootDomains(ipPrefix netaddr.IPPrefix, baseDomain string) (
maskBits, _ := netRange.Mask.Size()
// lastOctet is the last IP byte covered by the mask
lastOctet := maskBits / 8
lastOctet := maskBits / ByteSize
// wildcardBits is the number of bits not under the mask in the lastOctet
wildcardBits := 8 - maskBits%8
wildcardBits := ByteSize - maskBits%ByteSize
// min is the value in the lastOctet byte of the IP
// max is basically 2^wildcardBits - i.e., the value when all the wildcardBits are set to 1
min := uint(netRange.IP[lastOctet])
max := uint((min + 1<<uint(wildcardBits)) - 1)
max := (min + 1<<uint(wildcardBits)) - 1
// here we generate the base domain (e.g., 100.in-addr.arpa., 16.172.in-addr.arpa., etc.)
rdnsSlice := []string{}
@ -66,18 +72,27 @@ func generateMagicDNSRootDomains(ipPrefix netaddr.IPPrefix, baseDomain string) (
}
fqdns = append(fqdns, fqdn)
}
return fqdns, nil
return fqdns
}
func getMapResponseDNSConfig(dnsConfigOrig *tailcfg.DNSConfig, baseDomain string, m Machine, peers Machines) (*tailcfg.DNSConfig, error) {
func getMapResponseDNSConfig(
dnsConfigOrig *tailcfg.DNSConfig,
baseDomain string,
machine Machine,
peers Machines,
) *tailcfg.DNSConfig {
var dnsConfig *tailcfg.DNSConfig
if dnsConfigOrig != nil && dnsConfigOrig.Proxied { // if MagicDNS is enabled
// Only inject the Search Domain of the current namespace - shared nodes should use their full FQDN
dnsConfig = dnsConfigOrig.Clone()
dnsConfig.Domains = append(dnsConfig.Domains, fmt.Sprintf("%s.%s", m.Namespace.Name, baseDomain))
dnsConfig.Domains = append(
dnsConfig.Domains,
fmt.Sprintf("%s.%s", machine.Namespace.Name, baseDomain),
)
namespaceSet := set.New(set.ThreadSafe)
namespaceSet.Add(m.Namespace)
namespaceSet.Add(machine.Namespace)
for _, p := range peers {
namespaceSet.Add(p.Namespace)
}
@ -88,5 +103,6 @@ func getMapResponseDNSConfig(dnsConfigOrig *tailcfg.DNSConfig, baseDomain string
} else {
dnsConfig = dnsConfigOrig
}
return dnsConfig, nil
return dnsConfig
}

View file

@ -11,13 +11,13 @@ import (
func (s *Suite) TestMagicDNSRootDomains100(c *check.C) {
prefix := netaddr.MustParseIPPrefix("100.64.0.0/10")
domains, err := generateMagicDNSRootDomains(prefix, "foobar.headscale.net")
c.Assert(err, check.IsNil)
domains := generateMagicDNSRootDomains(prefix)
found := false
for _, domain := range domains {
if domain == "64.100.in-addr.arpa." {
found = true
break
}
}
@ -27,6 +27,7 @@ func (s *Suite) TestMagicDNSRootDomains100(c *check.C) {
for _, domain := range domains {
if domain == "100.100.in-addr.arpa." {
found = true
break
}
}
@ -36,6 +37,7 @@ func (s *Suite) TestMagicDNSRootDomains100(c *check.C) {
for _, domain := range domains {
if domain == "127.100.in-addr.arpa." {
found = true
break
}
}
@ -44,13 +46,13 @@ func (s *Suite) TestMagicDNSRootDomains100(c *check.C) {
func (s *Suite) TestMagicDNSRootDomains172(c *check.C) {
prefix := netaddr.MustParseIPPrefix("172.16.0.0/16")
domains, err := generateMagicDNSRootDomains(prefix, "headscale.net")
c.Assert(err, check.IsNil)
domains := generateMagicDNSRootDomains(prefix)
found := false
for _, domain := range domains {
if domain == "0.16.172.in-addr.arpa." {
found = true
break
}
}
@ -60,6 +62,7 @@ func (s *Suite) TestMagicDNSRootDomains172(c *check.C) {
for _, domain := range domains {
if domain == "255.16.172.in-addr.arpa." {
found = true
break
}
}
@ -67,100 +70,120 @@ func (s *Suite) TestMagicDNSRootDomains172(c *check.C) {
}
func (s *Suite) TestDNSConfigMapResponseWithMagicDNS(c *check.C) {
n1, err := h.CreateNamespace("shared1")
namespaceShared1, err := app.CreateNamespace("shared1")
c.Assert(err, check.IsNil)
n2, err := h.CreateNamespace("shared2")
namespaceShared2, err := app.CreateNamespace("shared2")
c.Assert(err, check.IsNil)
n3, err := h.CreateNamespace("shared3")
namespaceShared3, err := app.CreateNamespace("shared3")
c.Assert(err, check.IsNil)
pak1n1, err := h.CreatePreAuthKey(n1.Name, false, false, nil)
preAuthKeyInShared1, err := app.CreatePreAuthKey(
namespaceShared1.Name,
false,
false,
nil,
)
c.Assert(err, check.IsNil)
pak2n2, err := h.CreatePreAuthKey(n2.Name, false, false, nil)
preAuthKeyInShared2, err := app.CreatePreAuthKey(
namespaceShared2.Name,
false,
false,
nil,
)
c.Assert(err, check.IsNil)
pak3n3, err := h.CreatePreAuthKey(n3.Name, false, false, nil)
preAuthKeyInShared3, err := app.CreatePreAuthKey(
namespaceShared3.Name,
false,
false,
nil,
)
c.Assert(err, check.IsNil)
pak4n1, err := h.CreatePreAuthKey(n1.Name, false, false, nil)
PreAuthKey2InShared1, err := app.CreatePreAuthKey(
namespaceShared1.Name,
false,
false,
nil,
)
c.Assert(err, check.IsNil)
_, err = h.GetMachine(n1.Name, "test_get_shared_nodes_1")
_, err = app.GetMachine(namespaceShared1.Name, "test_get_shared_nodes_1")
c.Assert(err, check.NotNil)
m1 := &Machine{
machineInShared1 := &Machine{
ID: 1,
MachineKey: "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66",
NodeKey: "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66",
DiscoKey: "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66",
Name: "test_get_shared_nodes_1",
NamespaceID: n1.ID,
Namespace: *n1,
NamespaceID: namespaceShared1.ID,
Namespace: *namespaceShared1,
Registered: true,
RegisterMethod: "authKey",
IPAddress: "100.64.0.1",
AuthKeyID: uint(pak1n1.ID),
AuthKeyID: uint(preAuthKeyInShared1.ID),
}
h.db.Save(m1)
app.db.Save(machineInShared1)
_, err = h.GetMachine(n1.Name, m1.Name)
_, err = app.GetMachine(namespaceShared1.Name, machineInShared1.Name)
c.Assert(err, check.IsNil)
m2 := &Machine{
machineInShared2 := &Machine{
ID: 2,
MachineKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
NodeKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
DiscoKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
Name: "test_get_shared_nodes_2",
NamespaceID: n2.ID,
Namespace: *n2,
NamespaceID: namespaceShared2.ID,
Namespace: *namespaceShared2,
Registered: true,
RegisterMethod: "authKey",
IPAddress: "100.64.0.2",
AuthKeyID: uint(pak2n2.ID),
AuthKeyID: uint(preAuthKeyInShared2.ID),
}
h.db.Save(m2)
app.db.Save(machineInShared2)
_, err = h.GetMachine(n2.Name, m2.Name)
_, err = app.GetMachine(namespaceShared2.Name, machineInShared2.Name)
c.Assert(err, check.IsNil)
m3 := &Machine{
machineInShared3 := &Machine{
ID: 3,
MachineKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
NodeKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
DiscoKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
Name: "test_get_shared_nodes_3",
NamespaceID: n3.ID,
Namespace: *n3,
NamespaceID: namespaceShared3.ID,
Namespace: *namespaceShared3,
Registered: true,
RegisterMethod: "authKey",
IPAddress: "100.64.0.3",
AuthKeyID: uint(pak3n3.ID),
AuthKeyID: uint(preAuthKeyInShared3.ID),
}
h.db.Save(m3)
app.db.Save(machineInShared3)
_, err = h.GetMachine(n3.Name, m3.Name)
_, err = app.GetMachine(namespaceShared3.Name, machineInShared3.Name)
c.Assert(err, check.IsNil)
m4 := &Machine{
machine2InShared1 := &Machine{
ID: 4,
MachineKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
NodeKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
DiscoKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
Name: "test_get_shared_nodes_4",
NamespaceID: n1.ID,
Namespace: *n1,
NamespaceID: namespaceShared1.ID,
Namespace: *namespaceShared1,
Registered: true,
RegisterMethod: "authKey",
IPAddress: "100.64.0.4",
AuthKeyID: uint(pak4n1.ID),
AuthKeyID: uint(PreAuthKey2InShared1.ID),
}
h.db.Save(m4)
app.db.Save(machine2InShared1)
err = h.AddSharedMachineToNamespace(m2, n1)
err = app.AddSharedMachineToNamespace(machineInShared2, namespaceShared1)
c.Assert(err, check.IsNil)
baseDomain := "foobar.headscale.net"
@ -170,122 +193,146 @@ func (s *Suite) TestDNSConfigMapResponseWithMagicDNS(c *check.C) {
Proxied: true,
}
m1peers, err := h.getPeers(m1)
peersOfMachineInShared1, err := app.getPeers(machineInShared1)
c.Assert(err, check.IsNil)
dnsConfig, err := getMapResponseDNSConfig(&dnsConfigOrig, baseDomain, *m1, m1peers)
c.Assert(err, check.IsNil)
dnsConfig := getMapResponseDNSConfig(
&dnsConfigOrig,
baseDomain,
*machineInShared1,
peersOfMachineInShared1,
)
c.Assert(dnsConfig, check.NotNil)
c.Assert(len(dnsConfig.Routes), check.Equals, 2)
routeN1 := fmt.Sprintf("%s.%s", n1.Name, baseDomain)
_, ok := dnsConfig.Routes[routeN1]
domainRouteShared1 := fmt.Sprintf("%s.%s", namespaceShared1.Name, baseDomain)
_, ok := dnsConfig.Routes[domainRouteShared1]
c.Assert(ok, check.Equals, true)
routeN2 := fmt.Sprintf("%s.%s", n2.Name, baseDomain)
_, ok = dnsConfig.Routes[routeN2]
domainRouteShared2 := fmt.Sprintf("%s.%s", namespaceShared2.Name, baseDomain)
_, ok = dnsConfig.Routes[domainRouteShared2]
c.Assert(ok, check.Equals, true)
routeN3 := fmt.Sprintf("%s.%s", n3.Name, baseDomain)
_, ok = dnsConfig.Routes[routeN3]
domainRouteShared3 := fmt.Sprintf("%s.%s", namespaceShared3.Name, baseDomain)
_, ok = dnsConfig.Routes[domainRouteShared3]
c.Assert(ok, check.Equals, false)
}
func (s *Suite) TestDNSConfigMapResponseWithoutMagicDNS(c *check.C) {
n1, err := h.CreateNamespace("shared1")
namespaceShared1, err := app.CreateNamespace("shared1")
c.Assert(err, check.IsNil)
n2, err := h.CreateNamespace("shared2")
namespaceShared2, err := app.CreateNamespace("shared2")
c.Assert(err, check.IsNil)
n3, err := h.CreateNamespace("shared3")
namespaceShared3, err := app.CreateNamespace("shared3")
c.Assert(err, check.IsNil)
pak1n1, err := h.CreatePreAuthKey(n1.Name, false, false, nil)
preAuthKeyInShared1, err := app.CreatePreAuthKey(
namespaceShared1.Name,
false,
false,
nil,
)
c.Assert(err, check.IsNil)
pak2n2, err := h.CreatePreAuthKey(n2.Name, false, false, nil)
preAuthKeyInShared2, err := app.CreatePreAuthKey(
namespaceShared2.Name,
false,
false,
nil,
)
c.Assert(err, check.IsNil)
pak3n3, err := h.CreatePreAuthKey(n3.Name, false, false, nil)
preAuthKeyInShared3, err := app.CreatePreAuthKey(
namespaceShared3.Name,
false,
false,
nil,
)
c.Assert(err, check.IsNil)
pak4n1, err := h.CreatePreAuthKey(n1.Name, false, false, nil)
preAuthKey2InShared1, err := app.CreatePreAuthKey(
namespaceShared1.Name,
false,
false,
nil,
)
c.Assert(err, check.IsNil)
_, err = h.GetMachine(n1.Name, "test_get_shared_nodes_1")
_, err = app.GetMachine(namespaceShared1.Name, "test_get_shared_nodes_1")
c.Assert(err, check.NotNil)
m1 := &Machine{
machineInShared1 := &Machine{
ID: 1,
MachineKey: "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66",
NodeKey: "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66",
DiscoKey: "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66",
Name: "test_get_shared_nodes_1",
NamespaceID: n1.ID,
Namespace: *n1,
NamespaceID: namespaceShared1.ID,
Namespace: *namespaceShared1,
Registered: true,
RegisterMethod: "authKey",
IPAddress: "100.64.0.1",
AuthKeyID: uint(pak1n1.ID),
AuthKeyID: uint(preAuthKeyInShared1.ID),
}
h.db.Save(m1)
app.db.Save(machineInShared1)
_, err = h.GetMachine(n1.Name, m1.Name)
_, err = app.GetMachine(namespaceShared1.Name, machineInShared1.Name)
c.Assert(err, check.IsNil)
m2 := &Machine{
machineInShared2 := &Machine{
ID: 2,
MachineKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
NodeKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
DiscoKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
Name: "test_get_shared_nodes_2",
NamespaceID: n2.ID,
Namespace: *n2,
NamespaceID: namespaceShared2.ID,
Namespace: *namespaceShared2,
Registered: true,
RegisterMethod: "authKey",
IPAddress: "100.64.0.2",
AuthKeyID: uint(pak2n2.ID),
AuthKeyID: uint(preAuthKeyInShared2.ID),
}
h.db.Save(m2)
app.db.Save(machineInShared2)
_, err = h.GetMachine(n2.Name, m2.Name)
_, err = app.GetMachine(namespaceShared2.Name, machineInShared2.Name)
c.Assert(err, check.IsNil)
m3 := &Machine{
machineInShared3 := &Machine{
ID: 3,
MachineKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
NodeKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
DiscoKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
Name: "test_get_shared_nodes_3",
NamespaceID: n3.ID,
Namespace: *n3,
NamespaceID: namespaceShared3.ID,
Namespace: *namespaceShared3,
Registered: true,
RegisterMethod: "authKey",
IPAddress: "100.64.0.3",
AuthKeyID: uint(pak3n3.ID),
AuthKeyID: uint(preAuthKeyInShared3.ID),
}
h.db.Save(m3)
app.db.Save(machineInShared3)
_, err = h.GetMachine(n3.Name, m3.Name)
_, err = app.GetMachine(namespaceShared3.Name, machineInShared3.Name)
c.Assert(err, check.IsNil)
m4 := &Machine{
machine2InShared1 := &Machine{
ID: 4,
MachineKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
NodeKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
DiscoKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
Name: "test_get_shared_nodes_4",
NamespaceID: n1.ID,
Namespace: *n1,
NamespaceID: namespaceShared1.ID,
Namespace: *namespaceShared1,
Registered: true,
RegisterMethod: "authKey",
IPAddress: "100.64.0.4",
AuthKeyID: uint(pak4n1.ID),
AuthKeyID: uint(preAuthKey2InShared1.ID),
}
h.db.Save(m4)
app.db.Save(machine2InShared1)
err = h.AddSharedMachineToNamespace(m2, n1)
err = app.AddSharedMachineToNamespace(machineInShared2, namespaceShared1)
c.Assert(err, check.IsNil)
baseDomain := "foobar.headscale.net"
@ -295,11 +342,15 @@ func (s *Suite) TestDNSConfigMapResponseWithoutMagicDNS(c *check.C) {
Proxied: false,
}
m1peers, err := h.getPeers(m1)
peersOfMachine1Shared1, err := app.getPeers(machineInShared1)
c.Assert(err, check.IsNil)
dnsConfig, err := getMapResponseDNSConfig(&dnsConfigOrig, baseDomain, *m1, m1peers)
c.Assert(err, check.IsNil)
dnsConfig := getMapResponseDNSConfig(
&dnsConfigOrig,
baseDomain,
*machineInShared1,
peersOfMachine1Shared1,
)
c.Assert(dnsConfig, check.NotNil)
c.Assert(len(dnsConfig.Routes), check.Equals, 0)
c.Assert(len(dnsConfig.Domains), check.Equals, 1)

View file

@ -1,6 +1,7 @@
# Running headscale
## Server configuration
1. Download the headscale binary https://github.com/juanfont/headscale/releases, and place it somewhere in your $PATH or use the docker container
```shell
@ -22,6 +23,7 @@
3. Get yourself a DB
a) Get a Postgres DB running in docker
```shell
docker run --name headscale \
-e POSTGRES_DB=headscale
@ -30,7 +32,9 @@
-p 5432:5432 \
-d postgres
```
or b) Prepare a SQLite DB file
```shell
touch config/db.sqlite
```
@ -81,6 +85,7 @@
-p 127.0.0.1:8080:8080 \
headscale/headscale:x.x.x headscale serve
```
## Nodes configuration
If you used tailscale.com before in your nodes, make sure you clear the tailscaled data folder
@ -90,7 +95,9 @@ If you used tailscale.com before in your nodes, make sure you clear the tailscal
rm -fr /var/lib/tailscale
systemctl start tailscaled
```
### Adding node based on MACHINEKEY
1. Add your first machine
```shell

View file

@ -155,7 +155,10 @@ func (api headscaleV1APIServer) RegisterMachine(
ctx context.Context,
request *v1.RegisterMachineRequest,
) (*v1.RegisterMachineResponse, error) {
log.Trace().Str("namespace", request.GetNamespace()).Str("machine_key", request.GetKey()).Msg("Registering machine")
log.Trace().
Str("namespace", request.GetNamespace()).
Str("machine_key", request.GetKey()).
Msg("Registering machine")
machine, err := api.h.RegisterMachine(
request.GetKey(),
request.GetNamespace(),
@ -208,7 +211,9 @@ func (api headscaleV1APIServer) ListMachines(
return nil, err
}
sharedMachines, err := api.h.ListSharedMachinesInNamespace(request.GetNamespace())
sharedMachines, err := api.h.ListSharedMachinesInNamespace(
request.GetNamespace(),
)
if err != nil {
return nil, err
}
@ -333,12 +338,16 @@ func (api headscaleV1APIServer) DebugCreateMachine(
return nil, err
}
routes, err := stringToIpPrefix(request.GetRoutes())
routes, err := stringToIPPrefix(request.GetRoutes())
if err != nil {
return nil, err
}
log.Trace().Caller().Interface("route-prefix", routes).Interface("route-str", request.GetRoutes()).Msg("")
log.Trace().
Caller().
Interface("route-prefix", routes).
Interface("route-str", request.GetRoutes()).
Msg("")
hostinfo := tailcfg.Hostinfo{
RoutableIPs: routes,

View file

@ -88,6 +88,7 @@ func (s *IntegrationCLITestSuite) SetupTest() {
if resp.StatusCode != http.StatusOK {
return fmt.Errorf("status code not OK")
}
return nil
}); err != nil {
// TODO(kradalby): If we cannot access headscale, or any other fatal error during
@ -109,7 +110,10 @@ func (s *IntegrationCLITestSuite) TearDownTest() {
}
}
func (s *IntegrationCLITestSuite) HandleStats(suiteName string, stats *suite.SuiteInformation) {
func (s *IntegrationCLITestSuite) HandleStats(
suiteName string,
stats *suite.SuiteInformation,
) {
s.stats = stats
}
@ -144,7 +148,6 @@ func (s *IntegrationCLITestSuite) TestNamespaceCommand() {
namespaces := make([]*v1.Namespace, len(names))
for index, namespaceName := range names {
namespace, err := s.createNamespace(namespaceName)
assert.Nil(s.T(), err)
@ -298,11 +301,26 @@ func (s *IntegrationCLITestSuite) TestPreAuthKeyCommand() {
assert.True(s.T(), listedPreAuthKeys[3].Expiration.AsTime().After(time.Now()))
assert.True(s.T(), listedPreAuthKeys[4].Expiration.AsTime().After(time.Now()))
assert.True(s.T(), listedPreAuthKeys[0].Expiration.AsTime().Before(time.Now().Add(time.Hour*26)))
assert.True(s.T(), listedPreAuthKeys[1].Expiration.AsTime().Before(time.Now().Add(time.Hour*26)))
assert.True(s.T(), listedPreAuthKeys[2].Expiration.AsTime().Before(time.Now().Add(time.Hour*26)))
assert.True(s.T(), listedPreAuthKeys[3].Expiration.AsTime().Before(time.Now().Add(time.Hour*26)))
assert.True(s.T(), listedPreAuthKeys[4].Expiration.AsTime().Before(time.Now().Add(time.Hour*26)))
assert.True(
s.T(),
listedPreAuthKeys[0].Expiration.AsTime().Before(time.Now().Add(time.Hour*26)),
)
assert.True(
s.T(),
listedPreAuthKeys[1].Expiration.AsTime().Before(time.Now().Add(time.Hour*26)),
)
assert.True(
s.T(),
listedPreAuthKeys[2].Expiration.AsTime().Before(time.Now().Add(time.Hour*26)),
)
assert.True(
s.T(),
listedPreAuthKeys[3].Expiration.AsTime().Before(time.Now().Add(time.Hour*26)),
)
assert.True(
s.T(),
listedPreAuthKeys[4].Expiration.AsTime().Before(time.Now().Add(time.Hour*26)),
)
// Expire three keys
for i := 0; i < 3; i++ {
@ -341,11 +359,26 @@ func (s *IntegrationCLITestSuite) TestPreAuthKeyCommand() {
err = json.Unmarshal([]byte(listAfterExpireResult), &listedAfterExpirePreAuthKeys)
assert.Nil(s.T(), err)
assert.True(s.T(), listedAfterExpirePreAuthKeys[0].Expiration.AsTime().Before(time.Now()))
assert.True(s.T(), listedAfterExpirePreAuthKeys[1].Expiration.AsTime().Before(time.Now()))
assert.True(s.T(), listedAfterExpirePreAuthKeys[2].Expiration.AsTime().Before(time.Now()))
assert.True(s.T(), listedAfterExpirePreAuthKeys[3].Expiration.AsTime().After(time.Now()))
assert.True(s.T(), listedAfterExpirePreAuthKeys[4].Expiration.AsTime().After(time.Now()))
assert.True(
s.T(),
listedAfterExpirePreAuthKeys[0].Expiration.AsTime().Before(time.Now()),
)
assert.True(
s.T(),
listedAfterExpirePreAuthKeys[1].Expiration.AsTime().Before(time.Now()),
)
assert.True(
s.T(),
listedAfterExpirePreAuthKeys[2].Expiration.AsTime().Before(time.Now()),
)
assert.True(
s.T(),
listedAfterExpirePreAuthKeys[3].Expiration.AsTime().After(time.Now()),
)
assert.True(
s.T(),
listedAfterExpirePreAuthKeys[4].Expiration.AsTime().After(time.Now()),
)
}
func (s *IntegrationCLITestSuite) TestPreAuthKeyCommandWithoutExpiry() {
@ -689,7 +722,10 @@ func (s *IntegrationCLITestSuite) TestNodeCommand() {
assert.Nil(s.T(), err)
var listOnlySharedMachineNamespace []v1.Machine
err = json.Unmarshal([]byte(listOnlySharedMachineNamespaceResult), &listOnlySharedMachineNamespace)
err = json.Unmarshal(
[]byte(listOnlySharedMachineNamespaceResult),
&listOnlySharedMachineNamespace,
)
assert.Nil(s.T(), err)
assert.Len(s.T(), listOnlySharedMachineNamespace, 2)
@ -738,7 +774,10 @@ func (s *IntegrationCLITestSuite) TestNodeCommand() {
assert.Nil(s.T(), err)
var listOnlyMachineNamespaceAfterDelete []v1.Machine
err = json.Unmarshal([]byte(listOnlyMachineNamespaceAfterDeleteResult), &listOnlyMachineNamespaceAfterDelete)
err = json.Unmarshal(
[]byte(listOnlyMachineNamespaceAfterDeleteResult),
&listOnlyMachineNamespaceAfterDelete,
)
assert.Nil(s.T(), err)
assert.Len(s.T(), listOnlyMachineNamespaceAfterDelete, 4)
@ -789,7 +828,10 @@ func (s *IntegrationCLITestSuite) TestNodeCommand() {
assert.Nil(s.T(), err)
var listOnlyMachineNamespaceAfterShare []v1.Machine
err = json.Unmarshal([]byte(listOnlyMachineNamespaceAfterShareResult), &listOnlyMachineNamespaceAfterShare)
err = json.Unmarshal(
[]byte(listOnlyMachineNamespaceAfterShareResult),
&listOnlyMachineNamespaceAfterShare,
)
assert.Nil(s.T(), err)
assert.Len(s.T(), listOnlyMachineNamespaceAfterShare, 5)
@ -846,7 +888,10 @@ func (s *IntegrationCLITestSuite) TestNodeCommand() {
assert.Nil(s.T(), err)
var listOnlyMachineNamespaceAfterUnshare []v1.Machine
err = json.Unmarshal([]byte(listOnlyMachineNamespaceAfterUnshareResult), &listOnlyMachineNamespaceAfterUnshare)
err = json.Unmarshal(
[]byte(listOnlyMachineNamespaceAfterUnshareResult),
&listOnlyMachineNamespaceAfterUnshare,
)
assert.Nil(s.T(), err)
assert.Len(s.T(), listOnlyMachineNamespaceAfterUnshare, 4)
@ -1010,5 +1055,9 @@ func (s *IntegrationCLITestSuite) TestRouteCommand() {
)
assert.Nil(s.T(), err)
assert.Contains(s.T(), string(failEnableNonAdvertisedRoute), "route (route-machine) is not available on node")
assert.Contains(
s.T(),
string(failEnableNonAdvertisedRoute),
"route (route-machine) is not available on node",
)
}

View file

@ -12,12 +12,18 @@ import (
"github.com/ory/dockertest/v3/docker"
)
func ExecuteCommand(resource *dockertest.Resource, cmd []string, env []string) (string, error) {
const DOCKER_EXECUTE_TIMEOUT = 10 * time.Second
func ExecuteCommand(
resource *dockertest.Resource,
cmd []string,
env []string,
) (string, error) {
var stdout bytes.Buffer
var stderr bytes.Buffer
// TODO(kradalby): Make configurable
timeout := 10 * time.Second
timeout := DOCKER_EXECUTE_TIMEOUT
type result struct {
exitCode int
@ -51,11 +57,13 @@ func ExecuteCommand(resource *dockertest.Resource, cmd []string, env []string) (
fmt.Println("Command: ", cmd)
fmt.Println("stdout: ", stdout.String())
fmt.Println("stderr: ", stderr.String())
return "", fmt.Errorf("command failed with: %s", stderr.String())
}
return stdout.String(), nil
case <-time.After(timeout):
return "", fmt.Errorf("command timed out after %s", timeout)
}
}

View file

@ -23,10 +23,9 @@ import (
"github.com/ory/dockertest/v3/docker"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/suite"
"inet.af/netaddr"
"tailscale.com/client/tailscale/apitype"
"tailscale.com/ipn/ipnstate"
"inet.af/netaddr"
)
var tailscaleVersions = []string{"1.16.2", "1.14.3", "1.12.3"}
@ -89,7 +88,10 @@ func TestIntegrationTestSuite(t *testing.T) {
}
}
func (s *IntegrationTestSuite) saveLog(resource *dockertest.Resource, basePath string) error {
func (s *IntegrationTestSuite) saveLog(
resource *dockertest.Resource,
basePath string,
) error {
err := os.MkdirAll(basePath, os.ModePerm)
if err != nil {
return err
@ -118,12 +120,20 @@ func (s *IntegrationTestSuite) saveLog(resource *dockertest.Resource, basePath s
fmt.Printf("Saving logs for %s to %s\n", resource.Container.Name, basePath)
err = ioutil.WriteFile(path.Join(basePath, resource.Container.Name+".stdout.log"), []byte(stdout.String()), 0o644)
err = ioutil.WriteFile(
path.Join(basePath, resource.Container.Name+".stdout.log"),
[]byte(stdout.String()),
0o644,
)
if err != nil {
return err
}
err = ioutil.WriteFile(path.Join(basePath, resource.Container.Name+".stderr.log"), []byte(stdout.String()), 0o644)
err = ioutil.WriteFile(
path.Join(basePath, resource.Container.Name+".stderr.log"),
[]byte(stdout.String()),
0o644,
)
if err != nil {
return err
}
@ -144,24 +154,38 @@ func (s *IntegrationTestSuite) tailscaleContainer(
},
},
}
hostname := fmt.Sprintf("%s-tailscale-%s-%s", namespace, strings.Replace(version, ".", "-", -1), identifier)
hostname := fmt.Sprintf(
"%s-tailscale-%s-%s",
namespace,
strings.Replace(version, ".", "-", -1),
identifier,
)
tailscaleOptions := &dockertest.RunOptions{
Name: hostname,
Networks: []*dockertest.Network{&s.network},
Cmd: []string{"tailscaled", "--tun=userspace-networking", "--socks5-server=localhost:1055"},
Cmd: []string{
"tailscaled",
"--tun=userspace-networking",
"--socks5-server=localhost:1055",
},
}
pts, err := s.pool.BuildAndRunWithBuildOptions(tailscaleBuildOptions, tailscaleOptions, DockerRestartPolicy)
pts, err := s.pool.BuildAndRunWithBuildOptions(
tailscaleBuildOptions,
tailscaleOptions,
DockerRestartPolicy,
)
if err != nil {
log.Fatalf("Could not start resource: %s", err)
}
fmt.Printf("Created %s container\n", hostname)
return hostname, pts
}
func (s *IntegrationTestSuite) SetupSuite() {
var err error
h = Headscale{
app = Headscale{
dbType: "sqlite3",
dbString: "integration_test_db.sqlite3",
}
@ -210,7 +234,11 @@ func (s *IntegrationTestSuite) SetupSuite() {
for i := 0; i < scales.count; i++ {
version := tailscaleVersions[i%len(tailscaleVersions)]
hostname, container := s.tailscaleContainer(namespace, fmt.Sprint(i), version)
hostname, container := s.tailscaleContainer(
namespace,
fmt.Sprint(i),
version,
)
scales.tailscales[hostname] = *container
}
}
@ -220,13 +248,16 @@ func (s *IntegrationTestSuite) SetupSuite() {
if err := s.pool.Retry(func() error {
url := fmt.Sprintf("http://%s/health", hostEndpoint)
resp, err := http.Get(url)
if err != nil {
return err
}
if resp.StatusCode != http.StatusOK {
return fmt.Errorf("status code not OK")
}
return nil
}); err != nil {
// TODO(kradalby): If we cannot access headscale, or any other fatal error during
@ -273,7 +304,10 @@ func (s *IntegrationTestSuite) SetupSuite() {
headscaleEndpoint := "http://headscale:8080"
fmt.Printf("Joining tailscale containers to headscale at %s\n", headscaleEndpoint)
fmt.Printf(
"Joining tailscale containers to headscale at %s\n",
headscaleEndpoint,
)
for hostname, tailscale := range scales.tailscales {
command := []string{
"tailscale",
@ -307,7 +341,10 @@ func (s *IntegrationTestSuite) SetupSuite() {
func (s *IntegrationTestSuite) TearDownSuite() {
}
func (s *IntegrationTestSuite) HandleStats(suiteName string, stats *suite.SuiteInformation) {
func (s *IntegrationTestSuite) HandleStats(
suiteName string,
stats *suite.SuiteInformation,
) {
s.stats = stats
}
@ -427,7 +464,13 @@ func (s *IntegrationTestSuite) TestPingAllPeers() {
ip.String(),
}
fmt.Printf("Pinging from %s (%s) to %s (%s)\n", hostname, ips[hostname], peername, ip)
fmt.Printf(
"Pinging from %s (%s) to %s (%s)\n",
hostname,
ips[hostname],
peername,
ip,
)
result, err := ExecuteCommand(
&tailscale,
command,
@ -449,7 +492,15 @@ func (s *IntegrationTestSuite) TestSharedNodes() {
result, err := ExecuteCommand(
&s.headscale,
[]string{"headscale", "nodes", "list", "--output", "json", "--namespace", "shared"},
[]string{
"headscale",
"nodes",
"list",
"--output",
"json",
"--namespace",
"shared",
},
[]string{},
)
assert.Nil(s.T(), err)
@ -459,7 +510,6 @@ func (s *IntegrationTestSuite) TestSharedNodes() {
assert.Nil(s.T(), err)
for _, machine := range machineList {
result, err := ExecuteCommand(
&s.headscale,
[]string{
@ -520,7 +570,13 @@ func (s *IntegrationTestSuite) TestSharedNodes() {
ip.String(),
}
fmt.Printf("Pinging from %s (%s) to %s (%s)\n", hostname, mainIps[hostname], peername, ip)
fmt.Printf(
"Pinging from %s (%s) to %s (%s)\n",
hostname,
mainIps[hostname],
peername,
ip,
)
result, err := ExecuteCommand(
&tailscale,
command,
@ -553,7 +609,6 @@ func (s *IntegrationTestSuite) TestTailDrop() {
for peername, ip := range ips {
s.T().Run(fmt.Sprintf("%s-%s", hostname, peername), func(t *testing.T) {
if peername != hostname {
// Under normal circumstances, we should be able to send a file
// using `tailscale file cp` - but not in userspace networking mode
// So curl!
@ -578,9 +633,19 @@ func (s *IntegrationTestSuite) TestTailDrop() {
"PUT",
"--upload-file",
fmt.Sprintf("/tmp/file_from_%s", hostname),
fmt.Sprintf("%s/v0/put/file_from_%s", peerAPI, hostname),
fmt.Sprintf(
"%s/v0/put/file_from_%s",
peerAPI,
hostname,
),
}
fmt.Printf("Sending file from %s (%s) to %s (%s)\n", hostname, ips[hostname], peername, ip)
fmt.Printf(
"Sending file from %s (%s) to %s (%s)\n",
hostname,
ips[hostname],
peername,
ip,
)
_, err = ExecuteCommand(
&tailscale,
command,
@ -621,7 +686,13 @@ func (s *IntegrationTestSuite) TestTailDrop() {
"ls",
fmt.Sprintf("/tmp/file_from_%s", peername),
}
fmt.Printf("Checking file in %s (%s) from %s (%s)\n", hostname, ips[hostname], peername, ip)
fmt.Printf(
"Checking file in %s (%s) from %s (%s)\n",
hostname,
ips[hostname],
peername,
ip,
)
result, err := ExecuteCommand(
&tailscale,
command,
@ -629,7 +700,11 @@ func (s *IntegrationTestSuite) TestTailDrop() {
)
assert.Nil(t, err)
fmt.Printf("Result for %s: %s\n", peername, result)
assert.Equal(t, result, fmt.Sprintf("/tmp/file_from_%s\n", peername))
assert.Equal(
t,
result,
fmt.Sprintf("/tmp/file_from_%s\n", peername),
)
}
})
}
@ -696,10 +771,13 @@ func getIPs(tailscales map[string]dockertest.Resource) (map[string]netaddr.IP, e
ips[hostname] = ip
}
return ips, nil
}
func getAPIURLs(tailscales map[string]dockertest.Resource) (map[netaddr.IP]string, error) {
func getAPIURLs(
tailscales map[string]dockertest.Resource,
) (map[netaddr.IP]string, error) {
fts := make(map[netaddr.IP]string)
for _, tailscale := range tailscales {
command := []string{
@ -733,5 +811,6 @@ func getAPIURLs(tailscales map[string]dockertest.Resource) (map[netaddr.IP]strin
}
}
}
return fts, nil
}

View file

@ -24,6 +24,7 @@ Configure DERP servers by editing `base/site/derp.yaml` if needed.
You'll somehow need to get `headscale:latest` into your cluster image registry.
An easy way to do this with k3s:
- Reconfigure k3s to use docker instead of containerd (`k3s server --docker`)
- `docker build -t headscale:latest ..` from here

View file

@ -10,10 +10,9 @@ import (
"time"
"github.com/fatih/set"
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
"github.com/rs/zerolog/log"
"google.golang.org/protobuf/types/known/timestamppb"
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
"gorm.io/datatypes"
"gorm.io/gorm"
"inet.af/netaddr"
@ -21,7 +20,13 @@ import (
"tailscale.com/types/wgkey"
)
// Machine is a Headscale client
const (
errMachineNotFound = Error("machine not found")
errMachineAlreadyRegistered = Error("machine already registered")
errMachineRouteIsNotAvailable = Error("route is not available on machine")
)
// Machine is a Headscale client.
type Machine struct {
ID uint64 `gorm:"primary_key"`
MachineKey string `gorm:"type:varchar(64);unique_index"`
@ -56,53 +61,58 @@ type (
MachinesP []*Machine
)
// For the time being this method is rather naive
func (m Machine) isAlreadyRegistered() bool {
return m.Registered
// For the time being this method is rather naive.
func (machine Machine) isAlreadyRegistered() bool {
return machine.Registered
}
// isExpired returns whether the machine registration has expired
func (m Machine) isExpired() bool {
return time.Now().UTC().After(*m.Expiry)
// isExpired returns whether the machine registration has expired.
func (machine Machine) isExpired() bool {
return time.Now().UTC().After(*machine.Expiry)
}
// If the Machine is expired, updateMachineExpiry updates the Machine Expiry time to the maximum allowed duration,
// or the default duration if no Expiry time was requested by the client. The expiry time here does not (yet) cause
// a client to be disconnected, however they will have to re-auth the machine if they attempt to reconnect after the
// expiry time.
func (h *Headscale) updateMachineExpiry(m *Machine) {
if m.isExpired() {
func (h *Headscale) updateMachineExpiry(machine *Machine) {
if machine.isExpired() {
now := time.Now().UTC()
maxExpiry := now.Add(h.cfg.MaxMachineRegistrationDuration) // calculate the maximum expiry
defaultExpiry := now.Add(h.cfg.DefaultMachineRegistrationDuration) // calculate the default expiry
maxExpiry := now.Add(
h.cfg.MaxMachineRegistrationDuration,
) // calculate the maximum expiry
defaultExpiry := now.Add(
h.cfg.DefaultMachineRegistrationDuration,
) // calculate the default expiry
// clamp the expiry time of the machine registration to the maximum allowed, or use the default if none supplied
if maxExpiry.Before(*m.RequestedExpiry) {
if maxExpiry.Before(*machine.RequestedExpiry) {
log.Debug().
Msgf("Clamping registration expiry time to maximum: %v (%v)", maxExpiry, h.cfg.MaxMachineRegistrationDuration)
m.Expiry = &maxExpiry
} else if m.RequestedExpiry.IsZero() {
machine.Expiry = &maxExpiry
} else if machine.RequestedExpiry.IsZero() {
log.Debug().Msgf("Using default machine registration expiry time: %v (%v)", defaultExpiry, h.cfg.DefaultMachineRegistrationDuration)
m.Expiry = &defaultExpiry
machine.Expiry = &defaultExpiry
} else {
log.Debug().Msgf("Using requested machine registration expiry time: %v", m.RequestedExpiry)
m.Expiry = m.RequestedExpiry
log.Debug().Msgf("Using requested machine registration expiry time: %v", machine.RequestedExpiry)
machine.Expiry = machine.RequestedExpiry
}
h.db.Save(&m)
h.db.Save(&machine)
}
}
func (h *Headscale) getDirectPeers(m *Machine) (Machines, error) {
func (h *Headscale) getDirectPeers(machine *Machine) (Machines, error) {
log.Trace().
Caller().
Str("machine", m.Name).
Str("machine", machine.Name).
Msg("Finding direct peers")
machines := Machines{}
if err := h.db.Preload("Namespace").Where("namespace_id = ? AND machine_key <> ? AND registered",
m.NamespaceID, m.MachineKey).Find(&machines).Error; err != nil {
machine.NamespaceID, machine.MachineKey).Find(&machines).Error; err != nil {
log.Error().Err(err).Msg("Error accessing db")
return Machines{}, err
}
@ -110,21 +120,22 @@ func (h *Headscale) getDirectPeers(m *Machine) (Machines, error) {
log.Trace().
Caller().
Str("machine", m.Name).
Str("machine", machine.Name).
Msgf("Found direct machines: %s", machines.String())
return machines, nil
}
// getShared fetches machines that are shared to the `Namespace` of the machine we are getting peers for
func (h *Headscale) getShared(m *Machine) (Machines, error) {
// getShared fetches machines that are shared to the `Namespace` of the machine we are getting peers for.
func (h *Headscale) getShared(machine *Machine) (Machines, error) {
log.Trace().
Caller().
Str("machine", m.Name).
Str("machine", machine.Name).
Msg("Finding shared peers")
sharedMachines := []SharedMachine{}
if err := h.db.Preload("Namespace").Preload("Machine").Preload("Machine.Namespace").Where("namespace_id = ?",
m.NamespaceID).Find(&sharedMachines).Error; err != nil {
machine.NamespaceID).Find(&sharedMachines).Error; err != nil {
return Machines{}, err
}
@ -137,27 +148,30 @@ func (h *Headscale) getShared(m *Machine) (Machines, error) {
log.Trace().
Caller().
Str("machine", m.Name).
Str("machine", machine.Name).
Msgf("Found shared peers: %s", peers.String())
return peers, nil
}
// getSharedTo fetches the machines of the namespaces this machine is shared in
func (h *Headscale) getSharedTo(m *Machine) (Machines, error) {
// getSharedTo fetches the machines of the namespaces this machine is shared in.
func (h *Headscale) getSharedTo(machine *Machine) (Machines, error) {
log.Trace().
Caller().
Str("machine", m.Name).
Str("machine", machine.Name).
Msg("Finding peers in namespaces this machine is shared with")
sharedMachines := []SharedMachine{}
if err := h.db.Preload("Namespace").Preload("Machine").Preload("Machine.Namespace").Where("machine_id = ?",
m.ID).Find(&sharedMachines).Error; err != nil {
machine.ID).Find(&sharedMachines).Error; err != nil {
return Machines{}, err
}
peers := make(Machines, 0)
for _, sharedMachine := range sharedMachines {
namespaceMachines, err := h.ListMachinesInNamespace(sharedMachine.Namespace.Name)
namespaceMachines, err := h.ListMachinesInNamespace(
sharedMachine.Namespace.Name,
)
if err != nil {
return Machines{}, err
}
@ -168,36 +182,40 @@ func (h *Headscale) getSharedTo(m *Machine) (Machines, error) {
log.Trace().
Caller().
Str("machine", m.Name).
Str("machine", machine.Name).
Msgf("Found peers we are shared with: %s", peers.String())
return peers, nil
}
func (h *Headscale) getPeers(m *Machine) (Machines, error) {
direct, err := h.getDirectPeers(m)
func (h *Headscale) getPeers(machine *Machine) (Machines, error) {
direct, err := h.getDirectPeers(machine)
if err != nil {
log.Error().
Caller().
Err(err).
Msg("Cannot fetch peers")
return Machines{}, err
}
shared, err := h.getShared(m)
shared, err := h.getShared(machine)
if err != nil {
log.Error().
Caller().
Err(err).
Msg("Cannot fetch peers")
return Machines{}, err
}
sharedTo, err := h.getSharedTo(m)
sharedTo, err := h.getSharedTo(machine)
if err != nil {
log.Error().
Caller().
Err(err).
Msg("Cannot fetch peers")
return Machines{}, err
}
@ -208,7 +226,7 @@ func (h *Headscale) getPeers(m *Machine) (Machines, error) {
log.Trace().
Caller().
Str("machine", m.Name).
Str("machine", machine.Name).
Msgf("Found total peers: %s", peers.String())
return peers, nil
@ -219,10 +237,11 @@ func (h *Headscale) ListMachines() ([]Machine, error) {
if err := h.db.Preload("AuthKey").Preload("AuthKey.Namespace").Preload("Namespace").Find(&machines).Error; err != nil {
return nil, err
}
return machines, nil
}
// GetMachine finds a Machine by name and namespace and returns the Machine struct
// GetMachine finds a Machine by name and namespace and returns the Machine struct.
func (h *Headscale) GetMachine(namespace string, name string) (*Machine, error) {
machines, err := h.ListMachinesInNamespace(namespace)
if err != nil {
@ -234,73 +253,77 @@ func (h *Headscale) GetMachine(namespace string, name string) (*Machine, error)
return &m, nil
}
}
return nil, fmt.Errorf("machine not found")
return nil, errMachineNotFound
}
// GetMachineByID finds a Machine by ID and returns the Machine struct
// GetMachineByID finds a Machine by ID and returns the Machine struct.
func (h *Headscale) GetMachineByID(id uint64) (*Machine, error) {
m := Machine{}
if result := h.db.Preload("Namespace").Find(&Machine{ID: id}).First(&m); result.Error != nil {
return nil, result.Error
}
return &m, nil
}
// GetMachineByMachineKey finds a Machine by ID and returns the Machine struct
func (h *Headscale) GetMachineByMachineKey(mKey string) (*Machine, error) {
// GetMachineByMachineKey finds a Machine by ID and returns the Machine struct.
func (h *Headscale) GetMachineByMachineKey(machineKey string) (*Machine, error) {
m := Machine{}
if result := h.db.Preload("Namespace").First(&m, "machine_key = ?", mKey); result.Error != nil {
if result := h.db.Preload("Namespace").First(&m, "machine_key = ?", machineKey); result.Error != nil {
return nil, result.Error
}
return &m, nil
}
// UpdateMachine takes a Machine struct pointer (typically already loaded from database
// and updates it with the latest data from the database.
func (h *Headscale) UpdateMachine(m *Machine) error {
if result := h.db.Find(m).First(&m); result.Error != nil {
func (h *Headscale) UpdateMachine(machine *Machine) error {
if result := h.db.Find(machine).First(&machine); result.Error != nil {
return result.Error
}
return nil
}
// DeleteMachine softs deletes a Machine from the database
func (h *Headscale) DeleteMachine(m *Machine) error {
err := h.RemoveSharedMachineFromAllNamespaces(m)
if err != nil && err != errorMachineNotShared {
// DeleteMachine softs deletes a Machine from the database.
func (h *Headscale) DeleteMachine(machine *Machine) error {
err := h.RemoveSharedMachineFromAllNamespaces(machine)
if err != nil && errors.Is(err, errMachineNotShared) {
return err
}
m.Registered = false
namespaceID := m.NamespaceID
h.db.Save(&m) // we mark it as unregistered, just in case
if err := h.db.Delete(&m).Error; err != nil {
machine.Registered = false
namespaceID := machine.NamespaceID
h.db.Save(&machine) // we mark it as unregistered, just in case
if err := h.db.Delete(&machine).Error; err != nil {
return err
}
return h.RequestMapUpdates(namespaceID)
}
// HardDeleteMachine hard deletes a Machine from the database
func (h *Headscale) HardDeleteMachine(m *Machine) error {
err := h.RemoveSharedMachineFromAllNamespaces(m)
if err != nil && err != errorMachineNotShared {
// HardDeleteMachine hard deletes a Machine from the database.
func (h *Headscale) HardDeleteMachine(machine *Machine) error {
err := h.RemoveSharedMachineFromAllNamespaces(machine)
if err != nil && errors.Is(err, errMachineNotShared) {
return err
}
namespaceID := m.NamespaceID
if err := h.db.Unscoped().Delete(&m).Error; err != nil {
namespaceID := machine.NamespaceID
if err := h.db.Unscoped().Delete(&machine).Error; err != nil {
return err
}
return h.RequestMapUpdates(namespaceID)
}
// GetHostInfo returns a Hostinfo struct for the machine
func (m *Machine) GetHostInfo() (*tailcfg.Hostinfo, error) {
// GetHostInfo returns a Hostinfo struct for the machine.
func (machine *Machine) GetHostInfo() (*tailcfg.Hostinfo, error) {
hostinfo := tailcfg.Hostinfo{}
if len(m.HostInfo) != 0 {
hi, err := m.HostInfo.MarshalJSON()
if len(machine.HostInfo) != 0 {
hi, err := machine.HostInfo.MarshalJSON()
if err != nil {
return nil, err
}
@ -309,21 +332,21 @@ func (m *Machine) GetHostInfo() (*tailcfg.Hostinfo, error) {
return nil, err
}
}
return &hostinfo, nil
}
func (h *Headscale) isOutdated(m *Machine) bool {
err := h.UpdateMachine(m)
if err != nil {
func (h *Headscale) isOutdated(machine *Machine) bool {
if err := h.UpdateMachine(machine); err != nil {
// It does not seem meaningful to propagate this error as the end result
// will have to be that the machine has to be considered outdated.
return true
}
sharedMachines, _ := h.getShared(m)
sharedMachines, _ := h.getShared(machine)
namespaceSet := set.New(set.ThreadSafe)
namespaceSet.Add(m.Namespace.Name)
namespaceSet.Add(machine.Namespace.Name)
// Check if any of our shared namespaces has updates that we have
// not propagated.
@ -333,27 +356,30 @@ func (h *Headscale) isOutdated(m *Machine) bool {
namespaces := make([]string, namespaceSet.Size())
for index, namespace := range namespaceSet.List() {
namespaces[index] = namespace.(string)
if name, ok := namespace.(string); ok {
namespaces[index] = name
}
}
lastChange := h.getLastStateChange(namespaces...)
log.Trace().
Caller().
Str("machine", m.Name).
Time("last_successful_update", *m.LastSuccessfulUpdate).
Str("machine", machine.Name).
Time("last_successful_update", *machine.LastSuccessfulUpdate).
Time("last_state_change", lastChange).
Msgf("Checking if %s is missing updates", m.Name)
return m.LastSuccessfulUpdate.Before(lastChange)
Msgf("Checking if %s is missing updates", machine.Name)
return machine.LastSuccessfulUpdate.Before(lastChange)
}
func (m Machine) String() string {
return m.Name
func (machine Machine) String() string {
return machine.Name
}
func (ms Machines) String() string {
temp := make([]string, len(ms))
func (machines Machines) String() string {
temp := make([]string, len(machines))
for index, machine := range ms {
for index, machine := range machines {
temp[index] = machine.Name
}
@ -361,24 +387,24 @@ func (ms Machines) String() string {
}
// TODO(kradalby): Remove when we have generics...
func (ms MachinesP) String() string {
temp := make([]string, len(ms))
func (machines MachinesP) String() string {
temp := make([]string, len(machines))
for index, machine := range ms {
for index, machine := range machines {
temp[index] = machine.Name
}
return fmt.Sprintf("[ %s ](%d)", strings.Join(temp, ", "), len(temp))
}
func (ms Machines) toNodes(
func (machines Machines) toNodes(
baseDomain string,
dnsConfig *tailcfg.DNSConfig,
includeRoutes bool,
) ([]*tailcfg.Node, error) {
nodes := make([]*tailcfg.Node, len(ms))
nodes := make([]*tailcfg.Node, len(machines))
for index, machine := range ms {
for index, machine := range machines {
node, err := machine.toNode(baseDomain, dnsConfig, includeRoutes)
if err != nil {
return nil, err
@ -391,20 +417,25 @@ func (ms Machines) toNodes(
}
// toNode converts a Machine into a Tailscale Node. includeRoutes is false for shared nodes
// as per the expected behaviour in the official SaaS
func (m Machine) toNode(baseDomain string, dnsConfig *tailcfg.DNSConfig, includeRoutes bool) (*tailcfg.Node, error) {
nKey, err := wgkey.ParseHex(m.NodeKey)
// as per the expected behaviour in the official SaaS.
func (machine Machine) toNode(
baseDomain string,
dnsConfig *tailcfg.DNSConfig,
includeRoutes bool,
) (*tailcfg.Node, error) {
nodeKey, err := wgkey.ParseHex(machine.NodeKey)
if err != nil {
return nil, err
}
mKey, err := wgkey.ParseHex(m.MachineKey)
machineKey, err := wgkey.ParseHex(machine.MachineKey)
if err != nil {
return nil, err
}
var discoKey tailcfg.DiscoKey
if m.DiscoKey != "" {
dKey, err := wgkey.ParseHex(m.DiscoKey)
if machine.DiscoKey != "" {
dKey, err := wgkey.ParseHex(machine.DiscoKey)
if err != nil {
return nil, err
}
@ -414,23 +445,27 @@ func (m Machine) toNode(baseDomain string, dnsConfig *tailcfg.DNSConfig, include
}
addrs := []netaddr.IPPrefix{}
ip, err := netaddr.ParseIPPrefix(fmt.Sprintf("%s/32", m.IPAddress))
ip, err := netaddr.ParseIPPrefix(fmt.Sprintf("%s/32", machine.IPAddress))
if err != nil {
log.Trace().
Caller().
Str("ip", m.IPAddress).
Msgf("Failed to parse IP Prefix from IP: %s", m.IPAddress)
Str("ip", machine.IPAddress).
Msgf("Failed to parse IP Prefix from IP: %s", machine.IPAddress)
return nil, err
}
addrs = append(addrs, ip) // missing the ipv6 ?
allowedIPs := []netaddr.IPPrefix{}
allowedIPs = append(allowedIPs, ip) // we append the node own IP, as it is required by the clients
allowedIPs = append(
allowedIPs,
ip,
) // we append the node own IP, as it is required by the clients
if includeRoutes {
routesStr := []string{}
if len(m.EnabledRoutes) != 0 {
allwIps, err := m.EnabledRoutes.MarshalJSON()
if len(machine.EnabledRoutes) != 0 {
allwIps, err := machine.EnabledRoutes.MarshalJSON()
if err != nil {
return nil, err
}
@ -450,8 +485,8 @@ func (m Machine) toNode(baseDomain string, dnsConfig *tailcfg.DNSConfig, include
}
endpoints := []string{}
if len(m.Endpoints) != 0 {
be, err := m.Endpoints.MarshalJSON()
if len(machine.Endpoints) != 0 {
be, err := machine.Endpoints.MarshalJSON()
if err != nil {
return nil, err
}
@ -462,8 +497,8 @@ func (m Machine) toNode(baseDomain string, dnsConfig *tailcfg.DNSConfig, include
}
hostinfo := tailcfg.Hostinfo{}
if len(m.HostInfo) != 0 {
hi, err := m.HostInfo.MarshalJSON()
if len(machine.HostInfo) != 0 {
hi, err := machine.HostInfo.MarshalJSON()
if err != nil {
return nil, err
}
@ -481,29 +516,34 @@ func (m Machine) toNode(baseDomain string, dnsConfig *tailcfg.DNSConfig, include
}
var keyExpiry time.Time
if m.Expiry != nil {
keyExpiry = *m.Expiry
if machine.Expiry != nil {
keyExpiry = *machine.Expiry
} else {
keyExpiry = time.Time{}
}
var hostname string
if dnsConfig != nil && dnsConfig.Proxied { // MagicDNS
hostname = fmt.Sprintf("%s.%s.%s", m.Name, m.Namespace.Name, baseDomain)
hostname = fmt.Sprintf(
"%s.%s.%s",
machine.Name,
machine.Namespace.Name,
baseDomain,
)
} else {
hostname = m.Name
hostname = machine.Name
}
n := tailcfg.Node{
ID: tailcfg.NodeID(m.ID), // this is the actual ID
node := tailcfg.Node{
ID: tailcfg.NodeID(machine.ID), // this is the actual ID
StableID: tailcfg.StableNodeID(
strconv.FormatUint(m.ID, 10),
strconv.FormatUint(machine.ID, Base10),
), // in headscale, unlike tailcontrol server, IDs are permanent
Name: hostname,
User: tailcfg.UserID(m.NamespaceID),
Key: tailcfg.NodeKey(nKey),
User: tailcfg.UserID(machine.NamespaceID),
Key: tailcfg.NodeKey(nodeKey),
KeyExpiry: keyExpiry,
Machine: tailcfg.MachineKey(mKey),
Machine: tailcfg.MachineKey(machineKey),
DiscoKey: discoKey,
Addresses: addrs,
AllowedIPs: allowedIPs,
@ -511,81 +551,90 @@ func (m Machine) toNode(baseDomain string, dnsConfig *tailcfg.DNSConfig, include
DERP: derp,
Hostinfo: hostinfo,
Created: m.CreatedAt,
LastSeen: m.LastSeen,
Created: machine.CreatedAt,
LastSeen: machine.LastSeen,
KeepAlive: true,
MachineAuthorized: m.Registered,
MachineAuthorized: machine.Registered,
Capabilities: []string{tailcfg.CapabilityFileSharing},
}
return &n, nil
return &node, nil
}
func (m *Machine) toProto() *v1.Machine {
machine := &v1.Machine{
Id: m.ID,
MachineKey: m.MachineKey,
func (machine *Machine) toProto() *v1.Machine {
machineProto := &v1.Machine{
Id: machine.ID,
MachineKey: machine.MachineKey,
NodeKey: m.NodeKey,
DiscoKey: m.DiscoKey,
IpAddress: m.IPAddress,
Name: m.Name,
Namespace: m.Namespace.toProto(),
NodeKey: machine.NodeKey,
DiscoKey: machine.DiscoKey,
IpAddress: machine.IPAddress,
Name: machine.Name,
Namespace: machine.Namespace.toProto(),
Registered: m.Registered,
Registered: machine.Registered,
// TODO(kradalby): Implement register method enum converter
// RegisterMethod: ,
CreatedAt: timestamppb.New(m.CreatedAt),
CreatedAt: timestamppb.New(machine.CreatedAt),
}
if m.AuthKey != nil {
machine.PreAuthKey = m.AuthKey.toProto()
if machine.AuthKey != nil {
machineProto.PreAuthKey = machine.AuthKey.toProto()
}
if m.LastSeen != nil {
machine.LastSeen = timestamppb.New(*m.LastSeen)
if machine.LastSeen != nil {
machineProto.LastSeen = timestamppb.New(*machine.LastSeen)
}
if m.LastSuccessfulUpdate != nil {
machine.LastSuccessfulUpdate = timestamppb.New(*m.LastSuccessfulUpdate)
if machine.LastSuccessfulUpdate != nil {
machineProto.LastSuccessfulUpdate = timestamppb.New(
*machine.LastSuccessfulUpdate,
)
}
if m.Expiry != nil {
machine.Expiry = timestamppb.New(*m.Expiry)
if machine.Expiry != nil {
machineProto.Expiry = timestamppb.New(*machine.Expiry)
}
return machine
return machineProto
}
// RegisterMachine is executed from the CLI to register a new Machine using its MachineKey
func (h *Headscale) RegisterMachine(key string, namespace string) (*Machine, error) {
ns, err := h.GetNamespace(namespace)
// RegisterMachine is executed from the CLI to register a new Machine using its MachineKey.
func (h *Headscale) RegisterMachine(
key string,
namespaceName string,
) (*Machine, error) {
namespace, err := h.GetNamespace(namespaceName)
if err != nil {
return nil, err
}
mKey, err := wgkey.ParseHex(key)
machineKey, err := wgkey.ParseHex(key)
if err != nil {
return nil, err
}
m := Machine{}
if result := h.db.First(&m, "machine_key = ?", mKey.HexString()); errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, errors.New("Machine not found")
machine := Machine{}
if result := h.db.First(&machine, "machine_key = ?", machineKey.HexString()); errors.Is(
result.Error,
gorm.ErrRecordNotFound,
) {
return nil, errMachineNotFound
}
log.Trace().
Caller().
Str("machine", m.Name).
Str("machine", machine.Name).
Msg("Attempting to register machine")
if m.isAlreadyRegistered() {
err := errors.New("Machine already registered")
if machine.isAlreadyRegistered() {
err := errMachineAlreadyRegistered
log.Error().
Caller().
Err(err).
Str("machine", m.Name).
Str("machine", machine.Name).
Msg("Attempting to register machine")
return nil, err
@ -596,42 +645,44 @@ func (h *Headscale) RegisterMachine(key string, namespace string) (*Machine, err
log.Error().
Caller().
Err(err).
Str("machine", m.Name).
Str("machine", machine.Name).
Msg("Could not find IP for the new machine")
return nil, err
}
log.Trace().
Caller().
Str("machine", m.Name).
Str("machine", machine.Name).
Str("ip", ip.String()).
Msg("Found IP for host")
m.IPAddress = ip.String()
m.NamespaceID = ns.ID
m.Registered = true
m.RegisterMethod = "cli"
h.db.Save(&m)
machine.IPAddress = ip.String()
machine.NamespaceID = namespace.ID
machine.Registered = true
machine.RegisterMethod = "cli"
h.db.Save(&machine)
log.Trace().
Caller().
Str("machine", m.Name).
Str("machine", machine.Name).
Str("ip", ip.String()).
Msg("Machine registered with the database")
return &m, nil
return &machine, nil
}
func (m *Machine) GetAdvertisedRoutes() ([]netaddr.IPPrefix, error) {
hostInfo, err := m.GetHostInfo()
func (machine *Machine) GetAdvertisedRoutes() ([]netaddr.IPPrefix, error) {
hostInfo, err := machine.GetHostInfo()
if err != nil {
return nil, err
}
return hostInfo.RoutableIPs, nil
}
func (m *Machine) GetEnabledRoutes() ([]netaddr.IPPrefix, error) {
data, err := m.EnabledRoutes.MarshalJSON()
func (machine *Machine) GetEnabledRoutes() ([]netaddr.IPPrefix, error) {
data, err := machine.EnabledRoutes.MarshalJSON()
if err != nil {
return nil, err
}
@ -654,13 +705,13 @@ func (m *Machine) GetEnabledRoutes() ([]netaddr.IPPrefix, error) {
return routes, nil
}
func (m *Machine) IsRoutesEnabled(routeStr string) bool {
func (machine *Machine) IsRoutesEnabled(routeStr string) bool {
route, err := netaddr.ParseIPPrefix(routeStr)
if err != nil {
return false
}
enabledRoutes, err := m.GetEnabledRoutes()
enabledRoutes, err := machine.GetEnabledRoutes()
if err != nil {
return false
}
@ -670,12 +721,13 @@ func (m *Machine) IsRoutesEnabled(routeStr string) bool {
return true
}
}
return false
}
// EnableNodeRoute enables new routes based on a list of new routes. It will _replace_ the
// previous list of routes.
func (h *Headscale) EnableRoutes(m *Machine, routeStrs ...string) error {
func (h *Headscale) EnableRoutes(machine *Machine, routeStrs ...string) error {
newRoutes := make([]netaddr.IPPrefix, len(routeStrs))
for index, routeStr := range routeStrs {
route, err := netaddr.ParseIPPrefix(routeStr)
@ -686,14 +738,18 @@ func (h *Headscale) EnableRoutes(m *Machine, routeStrs ...string) error {
newRoutes[index] = route
}
availableRoutes, err := m.GetAdvertisedRoutes()
availableRoutes, err := machine.GetAdvertisedRoutes()
if err != nil {
return err
}
for _, newRoute := range newRoutes {
if !containsIpPrefix(availableRoutes, newRoute) {
return fmt.Errorf("route (%s) is not available on node %s", m.Name, newRoute)
if !containsIPPrefix(availableRoutes, newRoute) {
return fmt.Errorf(
"route (%s) is not available on node %s: %w",
machine.Name,
newRoute, errMachineRouteIsNotAvailable,
)
}
}
@ -702,10 +758,10 @@ func (h *Headscale) EnableRoutes(m *Machine, routeStrs ...string) error {
return err
}
m.EnabledRoutes = datatypes.JSON(routes)
h.db.Save(&m)
machine.EnabledRoutes = datatypes.JSON(routes)
h.db.Save(&machine)
err = h.RequestMapUpdates(m.NamespaceID)
err = h.RequestMapUpdates(machine.NamespaceID)
if err != nil {
return err
}
@ -713,13 +769,13 @@ func (h *Headscale) EnableRoutes(m *Machine, routeStrs ...string) error {
return nil
}
func (m *Machine) RoutesToProto() (*v1.Routes, error) {
availableRoutes, err := m.GetAdvertisedRoutes()
func (machine *Machine) RoutesToProto() (*v1.Routes, error) {
availableRoutes, err := machine.GetAdvertisedRoutes()
if err != nil {
return nil, err
}
enabledRoutes, err := m.GetEnabledRoutes()
enabledRoutes, err := machine.GetEnabledRoutes()
if err != nil {
return nil, err
}

View file

@ -8,152 +8,159 @@ import (
)
func (s *Suite) TestGetMachine(c *check.C) {
n, err := h.CreateNamespace("test")
namespace, err := app.CreateNamespace("test")
c.Assert(err, check.IsNil)
pak, err := h.CreatePreAuthKey(n.Name, false, false, nil)
pak, err := app.CreatePreAuthKey(namespace.Name, false, false, nil)
c.Assert(err, check.IsNil)
_, err = h.GetMachine("test", "testmachine")
_, err = app.GetMachine("test", "testmachine")
c.Assert(err, check.NotNil)
m := &Machine{
machine := &Machine{
ID: 0,
MachineKey: "foo",
NodeKey: "bar",
DiscoKey: "faa",
Name: "testmachine",
NamespaceID: n.ID,
NamespaceID: namespace.ID,
Registered: true,
RegisterMethod: "authKey",
AuthKeyID: uint(pak.ID),
}
h.db.Save(m)
app.db.Save(machine)
m1, err := h.GetMachine("test", "testmachine")
machineFromDB, err := app.GetMachine("test", "testmachine")
c.Assert(err, check.IsNil)
_, err = m1.GetHostInfo()
_, err = machineFromDB.GetHostInfo()
c.Assert(err, check.IsNil)
}
func (s *Suite) TestGetMachineByID(c *check.C) {
n, err := h.CreateNamespace("test")
namespace, err := app.CreateNamespace("test")
c.Assert(err, check.IsNil)
pak, err := h.CreatePreAuthKey(n.Name, false, false, nil)
pak, err := app.CreatePreAuthKey(namespace.Name, false, false, nil)
c.Assert(err, check.IsNil)
_, err = h.GetMachineByID(0)
_, err = app.GetMachineByID(0)
c.Assert(err, check.NotNil)
m := Machine{
machine := Machine{
ID: 0,
MachineKey: "foo",
NodeKey: "bar",
DiscoKey: "faa",
Name: "testmachine",
NamespaceID: n.ID,
NamespaceID: namespace.ID,
Registered: true,
RegisterMethod: "authKey",
AuthKeyID: uint(pak.ID),
}
h.db.Save(&m)
app.db.Save(&machine)
m1, err := h.GetMachineByID(0)
machineByID, err := app.GetMachineByID(0)
c.Assert(err, check.IsNil)
_, err = m1.GetHostInfo()
_, err = machineByID.GetHostInfo()
c.Assert(err, check.IsNil)
}
func (s *Suite) TestDeleteMachine(c *check.C) {
n, err := h.CreateNamespace("test")
namespace, err := app.CreateNamespace("test")
c.Assert(err, check.IsNil)
m := Machine{
machine := Machine{
ID: 0,
MachineKey: "foo",
NodeKey: "bar",
DiscoKey: "faa",
Name: "testmachine",
NamespaceID: n.ID,
NamespaceID: namespace.ID,
Registered: true,
RegisterMethod: "authKey",
AuthKeyID: uint(1),
}
h.db.Save(&m)
err = h.DeleteMachine(&m)
app.db.Save(&machine)
err = app.DeleteMachine(&machine)
c.Assert(err, check.IsNil)
v, err := h.getValue("namespaces_pending_updates")
namespacesPendingUpdates, err := app.getValue("namespaces_pending_updates")
c.Assert(err, check.IsNil)
names := []string{}
err = json.Unmarshal([]byte(v), &names)
err = json.Unmarshal([]byte(namespacesPendingUpdates), &names)
c.Assert(err, check.IsNil)
c.Assert(names, check.DeepEquals, []string{n.Name})
h.checkForNamespacesPendingUpdates()
v, _ = h.getValue("namespaces_pending_updates")
c.Assert(v, check.Equals, "")
_, err = h.GetMachine(n.Name, "testmachine")
c.Assert(names, check.DeepEquals, []string{namespace.Name})
app.checkForNamespacesPendingUpdates()
namespacesPendingUpdates, _ = app.getValue("namespaces_pending_updates")
c.Assert(namespacesPendingUpdates, check.Equals, "")
_, err = app.GetMachine(namespace.Name, "testmachine")
c.Assert(err, check.NotNil)
}
func (s *Suite) TestHardDeleteMachine(c *check.C) {
n, err := h.CreateNamespace("test")
namespace, err := app.CreateNamespace("test")
c.Assert(err, check.IsNil)
m := Machine{
machine := Machine{
ID: 0,
MachineKey: "foo",
NodeKey: "bar",
DiscoKey: "faa",
Name: "testmachine3",
NamespaceID: n.ID,
NamespaceID: namespace.ID,
Registered: true,
RegisterMethod: "authKey",
AuthKeyID: uint(1),
}
h.db.Save(&m)
err = h.HardDeleteMachine(&m)
app.db.Save(&machine)
err = app.HardDeleteMachine(&machine)
c.Assert(err, check.IsNil)
_, err = h.GetMachine(n.Name, "testmachine3")
_, err = app.GetMachine(namespace.Name, "testmachine3")
c.Assert(err, check.NotNil)
}
func (s *Suite) TestGetDirectPeers(c *check.C) {
n, err := h.CreateNamespace("test")
namespace, err := app.CreateNamespace("test")
c.Assert(err, check.IsNil)
pak, err := h.CreatePreAuthKey(n.Name, false, false, nil)
pak, err := app.CreatePreAuthKey(namespace.Name, false, false, nil)
c.Assert(err, check.IsNil)
_, err = h.GetMachineByID(0)
_, err = app.GetMachineByID(0)
c.Assert(err, check.NotNil)
for i := 0; i <= 10; i++ {
m := Machine{
ID: uint64(i),
MachineKey: "foo" + strconv.Itoa(i),
NodeKey: "bar" + strconv.Itoa(i),
DiscoKey: "faa" + strconv.Itoa(i),
Name: "testmachine" + strconv.Itoa(i),
NamespaceID: n.ID,
for index := 0; index <= 10; index++ {
machine := Machine{
ID: uint64(index),
MachineKey: "foo" + strconv.Itoa(index),
NodeKey: "bar" + strconv.Itoa(index),
DiscoKey: "faa" + strconv.Itoa(index),
Name: "testmachine" + strconv.Itoa(index),
NamespaceID: namespace.ID,
Registered: true,
RegisterMethod: "authKey",
AuthKeyID: uint(pak.ID),
}
h.db.Save(&m)
app.db.Save(&machine)
}
m1, err := h.GetMachineByID(0)
machine0ByID, err := app.GetMachineByID(0)
c.Assert(err, check.IsNil)
_, err = m1.GetHostInfo()
_, err = machine0ByID.GetHostInfo()
c.Assert(err, check.IsNil)
peers, err := h.getDirectPeers(m1)
peersOfMachine0, err := app.getDirectPeers(machine0ByID)
c.Assert(err, check.IsNil)
c.Assert(len(peers), check.Equals, 9)
c.Assert(peers[0].Name, check.Equals, "testmachine2")
c.Assert(peers[5].Name, check.Equals, "testmachine7")
c.Assert(peers[8].Name, check.Equals, "testmachine10")
c.Assert(len(peersOfMachine0), check.Equals, 9)
c.Assert(peersOfMachine0[0].Name, check.Equals, "testmachine2")
c.Assert(peersOfMachine0[5].Name, check.Equals, "testmachine7")
c.Assert(peersOfMachine0[8].Name, check.Equals, "testmachine10")
}

View file

@ -15,9 +15,9 @@ import (
)
const (
errorNamespaceExists = Error("Namespace already exists")
errorNamespaceNotFound = Error("Namespace not found")
errorNamespaceNotEmptyOfNodes = Error("Namespace not empty: node(s) found")
errNamespaceExists = Error("Namespace already exists")
errNamespaceNotFound = Error("Namespace not found")
errNamespaceNotEmptyOfNodes = Error("Namespace not empty: node(s) found")
)
// Namespace is the way Headscale implements the concept of users in Tailscale
@ -30,51 +30,53 @@ type Namespace struct {
}
// CreateNamespace creates a new Namespace. Returns error if could not be created
// or another namespace already exists
// or another namespace already exists.
func (h *Headscale) CreateNamespace(name string) (*Namespace, error) {
n := Namespace{}
if err := h.db.Where("name = ?", name).First(&n).Error; err == nil {
return nil, errorNamespaceExists
namespace := Namespace{}
if err := h.db.Where("name = ?", name).First(&namespace).Error; err == nil {
return nil, errNamespaceExists
}
n.Name = name
if err := h.db.Create(&n).Error; err != nil {
namespace.Name = name
if err := h.db.Create(&namespace).Error; err != nil {
log.Error().
Str("func", "CreateNamespace").
Err(err).
Msg("Could not create row")
return nil, err
}
return &n, nil
return &namespace, nil
}
// DestroyNamespace destroys a Namespace. Returns error if the Namespace does
// not exist or if there are machines associated with it.
func (h *Headscale) DestroyNamespace(name string) error {
n, err := h.GetNamespace(name)
namespace, err := h.GetNamespace(name)
if err != nil {
return errorNamespaceNotFound
return errNamespaceNotFound
}
m, err := h.ListMachinesInNamespace(name)
machines, err := h.ListMachinesInNamespace(name)
if err != nil {
return err
}
if len(m) > 0 {
return errorNamespaceNotEmptyOfNodes
if len(machines) > 0 {
return errNamespaceNotEmptyOfNodes
}
keys, err := h.ListPreAuthKeys(name)
if err != nil {
return err
}
for _, p := range keys {
err = h.DestroyPreAuthKey(&p)
for _, key := range keys {
err = h.DestroyPreAuthKey(key)
if err != nil {
return err
}
}
if result := h.db.Unscoped().Delete(&n); result.Error != nil {
if result := h.db.Unscoped().Delete(&namespace); result.Error != nil {
return result.Error
}
@ -84,25 +86,25 @@ func (h *Headscale) DestroyNamespace(name string) error {
// RenameNamespace renames a Namespace. Returns error if the Namespace does
// not exist or if another Namespace exists with the new name.
func (h *Headscale) RenameNamespace(oldName, newName string) error {
n, err := h.GetNamespace(oldName)
oldNamespace, err := h.GetNamespace(oldName)
if err != nil {
return err
}
_, err = h.GetNamespace(newName)
if err == nil {
return errorNamespaceExists
return errNamespaceExists
}
if !errors.Is(err, errorNamespaceNotFound) {
if !errors.Is(err, errNamespaceNotFound) {
return err
}
n.Name = newName
oldNamespace.Name = newName
if result := h.db.Save(&n); result.Error != nil {
if result := h.db.Save(&oldNamespace); result.Error != nil {
return result.Error
}
err = h.RequestMapUpdates(n.ID)
err = h.RequestMapUpdates(oldNamespace.ID)
if err != nil {
return err
}
@ -110,39 +112,45 @@ func (h *Headscale) RenameNamespace(oldName, newName string) error {
return nil
}
// GetNamespace fetches a namespace by name
// GetNamespace fetches a namespace by name.
func (h *Headscale) GetNamespace(name string) (*Namespace, error) {
n := Namespace{}
if result := h.db.First(&n, "name = ?", name); errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, errorNamespaceNotFound
}
return &n, nil
namespace := Namespace{}
if result := h.db.First(&namespace, "name = ?", name); errors.Is(
result.Error,
gorm.ErrRecordNotFound,
) {
return nil, errNamespaceNotFound
}
// ListNamespaces gets all the existing namespaces
return &namespace, nil
}
// ListNamespaces gets all the existing namespaces.
func (h *Headscale) ListNamespaces() ([]Namespace, error) {
namespaces := []Namespace{}
if err := h.db.Find(&namespaces).Error; err != nil {
return nil, err
}
return namespaces, nil
}
// ListMachinesInNamespace gets all the nodes in a given namespace
// ListMachinesInNamespace gets all the nodes in a given namespace.
func (h *Headscale) ListMachinesInNamespace(name string) ([]Machine, error) {
n, err := h.GetNamespace(name)
namespace, err := h.GetNamespace(name)
if err != nil {
return nil, err
}
machines := []Machine{}
if err := h.db.Preload("AuthKey").Preload("AuthKey.Namespace").Preload("Namespace").Where(&Machine{NamespaceID: n.ID}).Find(&machines).Error; err != nil {
if err := h.db.Preload("AuthKey").Preload("AuthKey.Namespace").Preload("Namespace").Where(&Machine{NamespaceID: namespace.ID}).Find(&machines).Error; err != nil {
return nil, err
}
return machines, nil
}
// ListSharedMachinesInNamespace returns all the machines that are shared to the specified namespace
// ListSharedMachinesInNamespace returns all the machines that are shared to the specified namespace.
func (h *Headscale) ListSharedMachinesInNamespace(name string) ([]Machine, error) {
namespace, err := h.GetNamespace(name)
if err != nil {
@ -155,48 +163,61 @@ func (h *Headscale) ListSharedMachinesInNamespace(name string) ([]Machine, error
machines := []Machine{}
for _, sharedMachine := range sharedMachines {
machine, err := h.GetMachineByID(sharedMachine.MachineID) // otherwise not everything comes filled
machine, err := h.GetMachineByID(
sharedMachine.MachineID,
) // otherwise not everything comes filled
if err != nil {
return nil, err
}
machines = append(machines, *machine)
}
return machines, nil
}
// SetMachineNamespace assigns a Machine to a namespace
func (h *Headscale) SetMachineNamespace(m *Machine, namespaceName string) error {
n, err := h.GetNamespace(namespaceName)
// SetMachineNamespace assigns a Machine to a namespace.
func (h *Headscale) SetMachineNamespace(machine *Machine, namespaceName string) error {
namespace, err := h.GetNamespace(namespaceName)
if err != nil {
return err
}
m.NamespaceID = n.ID
h.db.Save(&m)
machine.NamespaceID = namespace.ID
h.db.Save(&machine)
return nil
}
// RequestMapUpdates signals the KV worker to update the maps for this namespace
// TODO(kradalby): Remove the need for this.
// RequestMapUpdates signals the KV worker to update the maps for this namespace.
func (h *Headscale) RequestMapUpdates(namespaceID uint) error {
namespace := Namespace{}
if err := h.db.First(&namespace, namespaceID).Error; err != nil {
return err
}
v, err := h.getValue("namespaces_pending_updates")
if err != nil || v == "" {
err = h.setValue("namespaces_pending_updates", fmt.Sprintf(`["%s"]`, namespace.Name))
namespacesPendingUpdates, err := h.getValue("namespaces_pending_updates")
if err != nil || namespacesPendingUpdates == "" {
err = h.setValue(
"namespaces_pending_updates",
fmt.Sprintf(`["%s"]`, namespace.Name),
)
if err != nil {
return err
}
return nil
}
names := []string{}
err = json.Unmarshal([]byte(v), &names)
err = json.Unmarshal([]byte(namespacesPendingUpdates), &names)
if err != nil {
err = h.setValue("namespaces_pending_updates", fmt.Sprintf(`["%s"]`, namespace.Name))
err = h.setValue(
"namespaces_pending_updates",
fmt.Sprintf(`["%s"]`, namespace.Name),
)
if err != nil {
return err
}
return nil
}
@ -207,22 +228,24 @@ func (h *Headscale) RequestMapUpdates(namespaceID uint) error {
Str("func", "RequestMapUpdates").
Err(err).
Msg("Could not marshal namespaces_pending_updates")
return err
}
return h.setValue("namespaces_pending_updates", string(data))
}
func (h *Headscale) checkForNamespacesPendingUpdates() {
v, err := h.getValue("namespaces_pending_updates")
namespacesPendingUpdates, err := h.getValue("namespaces_pending_updates")
if err != nil {
return
}
if v == "" {
if namespacesPendingUpdates == "" {
return
}
namespaces := []string{}
err = json.Unmarshal([]byte(v), &namespaces)
err = json.Unmarshal([]byte(namespacesPendingUpdates), &namespaces)
if err != nil {
return
}
@ -233,24 +256,25 @@ func (h *Headscale) checkForNamespacesPendingUpdates() {
Msg("Sending updates to nodes in namespacespace")
h.setLastStateChangeToNow(namespace)
}
newV, err := h.getValue("namespaces_pending_updates")
newPendingUpdateValue, err := h.getValue("namespaces_pending_updates")
if err != nil {
return
}
if v == newV { // only clear when no changes, so we notified everybody
if namespacesPendingUpdates == newPendingUpdateValue { // only clear when no changes, so we notified everybody
err = h.setValue("namespaces_pending_updates", "")
if err != nil {
log.Error().
Str("func", "checkForNamespacesPendingUpdates").
Err(err).
Msg("Could not save to KV")
return
}
}
}
func (n *Namespace) toUser() *tailcfg.User {
u := tailcfg.User{
user := tailcfg.User{
ID: tailcfg.UserID(n.ID),
LoginName: n.Name,
DisplayName: n.Name,
@ -259,25 +283,27 @@ func (n *Namespace) toUser() *tailcfg.User {
Logins: []tailcfg.LoginID{},
Created: time.Time{},
}
return &u
return &user
}
func (n *Namespace) toLogin() *tailcfg.Login {
l := tailcfg.Login{
login := tailcfg.Login{
ID: tailcfg.LoginID(n.ID),
LoginName: n.Name,
DisplayName: n.Name,
ProfilePicURL: "",
Domain: "headscale.net",
}
return &l
return &login
}
func getMapResponseUserProfiles(m Machine, peers Machines) []tailcfg.UserProfile {
func getMapResponseUserProfiles(machine Machine, peers Machines) []tailcfg.UserProfile {
namespaceMap := make(map[string]Namespace)
namespaceMap[m.Namespace.Name] = m.Namespace
for _, p := range peers {
namespaceMap[p.Namespace.Name] = p.Namespace // not worth checking if already is there
namespaceMap[machine.Namespace.Name] = machine.Namespace
for _, peer := range peers {
namespaceMap[peer.Namespace.Name] = peer.Namespace // not worth checking if already is there
}
profiles := []tailcfg.UserProfile{}
@ -289,12 +315,13 @@ func getMapResponseUserProfiles(m Machine, peers Machines) []tailcfg.UserProfile
DisplayName: namespace.Name,
})
}
return profiles
}
func (n *Namespace) toProto() *v1.Namespace {
return &v1.Namespace{
Id: strconv.FormatUint(uint64(n.ID), 10),
Id: strconv.FormatUint(uint64(n.ID), Base10),
Name: n.Name,
CreatedAt: timestamppb.New(n.CreatedAt),
}

View file

@ -7,207 +7,232 @@ import (
)
func (s *Suite) TestCreateAndDestroyNamespace(c *check.C) {
n, err := h.CreateNamespace("test")
namespace, err := app.CreateNamespace("test")
c.Assert(err, check.IsNil)
c.Assert(n.Name, check.Equals, "test")
c.Assert(namespace.Name, check.Equals, "test")
ns, err := h.ListNamespaces()
namespaces, err := app.ListNamespaces()
c.Assert(err, check.IsNil)
c.Assert(len(ns), check.Equals, 1)
c.Assert(len(namespaces), check.Equals, 1)
err = h.DestroyNamespace("test")
err = app.DestroyNamespace("test")
c.Assert(err, check.IsNil)
_, err = h.GetNamespace("test")
_, err = app.GetNamespace("test")
c.Assert(err, check.NotNil)
}
func (s *Suite) TestDestroyNamespaceErrors(c *check.C) {
err := h.DestroyNamespace("test")
c.Assert(err, check.Equals, errorNamespaceNotFound)
err := app.DestroyNamespace("test")
c.Assert(err, check.Equals, errNamespaceNotFound)
n, err := h.CreateNamespace("test")
namespace, err := app.CreateNamespace("test")
c.Assert(err, check.IsNil)
pak, err := h.CreatePreAuthKey(n.Name, false, false, nil)
pak, err := app.CreatePreAuthKey(namespace.Name, false, false, nil)
c.Assert(err, check.IsNil)
err = h.DestroyNamespace("test")
err = app.DestroyNamespace("test")
c.Assert(err, check.IsNil)
result := h.db.Preload("Namespace").First(&pak, "key = ?", pak.Key)
result := app.db.Preload("Namespace").First(&pak, "key = ?", pak.Key)
// destroying a namespace also deletes all associated preauthkeys
c.Assert(result.Error, check.Equals, gorm.ErrRecordNotFound)
n, err = h.CreateNamespace("test")
namespace, err = app.CreateNamespace("test")
c.Assert(err, check.IsNil)
pak, err = h.CreatePreAuthKey(n.Name, false, false, nil)
pak, err = app.CreatePreAuthKey(namespace.Name, false, false, nil)
c.Assert(err, check.IsNil)
m := Machine{
machine := Machine{
ID: 0,
MachineKey: "foo",
NodeKey: "bar",
DiscoKey: "faa",
Name: "testmachine",
NamespaceID: n.ID,
NamespaceID: namespace.ID,
Registered: true,
RegisterMethod: "authKey",
AuthKeyID: uint(pak.ID),
}
h.db.Save(&m)
app.db.Save(&machine)
err = h.DestroyNamespace("test")
c.Assert(err, check.Equals, errorNamespaceNotEmptyOfNodes)
err = app.DestroyNamespace("test")
c.Assert(err, check.Equals, errNamespaceNotEmptyOfNodes)
}
func (s *Suite) TestRenameNamespace(c *check.C) {
n, err := h.CreateNamespace("test")
namespaceTest, err := app.CreateNamespace("test")
c.Assert(err, check.IsNil)
c.Assert(n.Name, check.Equals, "test")
c.Assert(namespaceTest.Name, check.Equals, "test")
ns, err := h.ListNamespaces()
namespaces, err := app.ListNamespaces()
c.Assert(err, check.IsNil)
c.Assert(len(ns), check.Equals, 1)
c.Assert(len(namespaces), check.Equals, 1)
err = h.RenameNamespace("test", "test_renamed")
err = app.RenameNamespace("test", "test_renamed")
c.Assert(err, check.IsNil)
_, err = h.GetNamespace("test")
c.Assert(err, check.Equals, errorNamespaceNotFound)
_, err = app.GetNamespace("test")
c.Assert(err, check.Equals, errNamespaceNotFound)
_, err = h.GetNamespace("test_renamed")
_, err = app.GetNamespace("test_renamed")
c.Assert(err, check.IsNil)
err = h.RenameNamespace("test_does_not_exit", "test")
c.Assert(err, check.Equals, errorNamespaceNotFound)
err = app.RenameNamespace("test_does_not_exit", "test")
c.Assert(err, check.Equals, errNamespaceNotFound)
n2, err := h.CreateNamespace("test2")
namespaceTest2, err := app.CreateNamespace("test2")
c.Assert(err, check.IsNil)
c.Assert(n2.Name, check.Equals, "test2")
c.Assert(namespaceTest2.Name, check.Equals, "test2")
err = h.RenameNamespace("test2", "test_renamed")
c.Assert(err, check.Equals, errorNamespaceExists)
err = app.RenameNamespace("test2", "test_renamed")
c.Assert(err, check.Equals, errNamespaceExists)
}
func (s *Suite) TestGetMapResponseUserProfiles(c *check.C) {
n1, err := h.CreateNamespace("shared1")
namespaceShared1, err := app.CreateNamespace("shared1")
c.Assert(err, check.IsNil)
n2, err := h.CreateNamespace("shared2")
namespaceShared2, err := app.CreateNamespace("shared2")
c.Assert(err, check.IsNil)
n3, err := h.CreateNamespace("shared3")
namespaceShared3, err := app.CreateNamespace("shared3")
c.Assert(err, check.IsNil)
pak1n1, err := h.CreatePreAuthKey(n1.Name, false, false, nil)
preAuthKeyShared1, err := app.CreatePreAuthKey(
namespaceShared1.Name,
false,
false,
nil,
)
c.Assert(err, check.IsNil)
pak2n2, err := h.CreatePreAuthKey(n2.Name, false, false, nil)
preAuthKeyShared2, err := app.CreatePreAuthKey(
namespaceShared2.Name,
false,
false,
nil,
)
c.Assert(err, check.IsNil)
pak3n3, err := h.CreatePreAuthKey(n3.Name, false, false, nil)
preAuthKeyShared3, err := app.CreatePreAuthKey(
namespaceShared3.Name,
false,
false,
nil,
)
c.Assert(err, check.IsNil)
pak4n1, err := h.CreatePreAuthKey(n1.Name, false, false, nil)
preAuthKey2Shared1, err := app.CreatePreAuthKey(
namespaceShared1.Name,
false,
false,
nil,
)
c.Assert(err, check.IsNil)
_, err = h.GetMachine(n1.Name, "test_get_shared_nodes_1")
_, err = app.GetMachine(namespaceShared1.Name, "test_get_shared_nodes_1")
c.Assert(err, check.NotNil)
m1 := &Machine{
machineInShared1 := &Machine{
ID: 1,
MachineKey: "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66",
NodeKey: "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66",
DiscoKey: "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66",
Name: "test_get_shared_nodes_1",
NamespaceID: n1.ID,
Namespace: *n1,
NamespaceID: namespaceShared1.ID,
Namespace: *namespaceShared1,
Registered: true,
RegisterMethod: "authKey",
IPAddress: "100.64.0.1",
AuthKeyID: uint(pak1n1.ID),
AuthKeyID: uint(preAuthKeyShared1.ID),
}
h.db.Save(m1)
app.db.Save(machineInShared1)
_, err = h.GetMachine(n1.Name, m1.Name)
_, err = app.GetMachine(namespaceShared1.Name, machineInShared1.Name)
c.Assert(err, check.IsNil)
m2 := &Machine{
machineInShared2 := &Machine{
ID: 2,
MachineKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
NodeKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
DiscoKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
Name: "test_get_shared_nodes_2",
NamespaceID: n2.ID,
Namespace: *n2,
NamespaceID: namespaceShared2.ID,
Namespace: *namespaceShared2,
Registered: true,
RegisterMethod: "authKey",
IPAddress: "100.64.0.2",
AuthKeyID: uint(pak2n2.ID),
AuthKeyID: uint(preAuthKeyShared2.ID),
}
h.db.Save(m2)
app.db.Save(machineInShared2)
_, err = h.GetMachine(n2.Name, m2.Name)
_, err = app.GetMachine(namespaceShared2.Name, machineInShared2.Name)
c.Assert(err, check.IsNil)
m3 := &Machine{
machineInShared3 := &Machine{
ID: 3,
MachineKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
NodeKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
DiscoKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
Name: "test_get_shared_nodes_3",
NamespaceID: n3.ID,
Namespace: *n3,
NamespaceID: namespaceShared3.ID,
Namespace: *namespaceShared3,
Registered: true,
RegisterMethod: "authKey",
IPAddress: "100.64.0.3",
AuthKeyID: uint(pak3n3.ID),
AuthKeyID: uint(preAuthKeyShared3.ID),
}
h.db.Save(m3)
app.db.Save(machineInShared3)
_, err = h.GetMachine(n3.Name, m3.Name)
_, err = app.GetMachine(namespaceShared3.Name, machineInShared3.Name)
c.Assert(err, check.IsNil)
m4 := &Machine{
machine2InShared1 := &Machine{
ID: 4,
MachineKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
NodeKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
DiscoKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
Name: "test_get_shared_nodes_4",
NamespaceID: n1.ID,
Namespace: *n1,
NamespaceID: namespaceShared1.ID,
Namespace: *namespaceShared1,
Registered: true,
RegisterMethod: "authKey",
IPAddress: "100.64.0.4",
AuthKeyID: uint(pak4n1.ID),
AuthKeyID: uint(preAuthKey2Shared1.ID),
}
h.db.Save(m4)
app.db.Save(machine2InShared1)
err = h.AddSharedMachineToNamespace(m2, n1)
err = app.AddSharedMachineToNamespace(machineInShared2, namespaceShared1)
c.Assert(err, check.IsNil)
m1peers, err := h.getPeers(m1)
peersOfMachine1InShared1, err := app.getPeers(machineInShared1)
c.Assert(err, check.IsNil)
userProfiles := getMapResponseUserProfiles(*m1, m1peers)
userProfiles := getMapResponseUserProfiles(
*machineInShared1,
peersOfMachine1InShared1,
)
log.Trace().Msgf("userProfiles %#v", userProfiles)
c.Assert(len(userProfiles), check.Equals, 2)
found := false
for _, up := range userProfiles {
if up.DisplayName == n1.Name {
for _, userProfiles := range userProfiles {
if userProfiles.DisplayName == namespaceShared1.Name {
found = true
break
}
}
c.Assert(found, check.Equals, true)
found = false
for _, up := range userProfiles {
if up.DisplayName == n2.Name {
for _, userProfile := range userProfiles {
if userProfile.DisplayName == namespaceShared2.Name {
found = true
break
}
}

138
oidc.go
View file

@ -17,6 +17,12 @@ import (
"golang.org/x/oauth2"
)
const (
oidcStateCacheExpiration = time.Minute * 5
oidcStateCacheCleanupInterval = time.Minute * 10
randomByteSize = 16
)
type IDTokenClaims struct {
Name string `json:"name,omitempty"`
Groups []string `json:"groups,omitempty"`
@ -32,6 +38,7 @@ func (h *Headscale) initOIDC() error {
if err != nil {
log.Error().Msgf("Could not retrieve OIDC Config: %s", err.Error())
return err
}
@ -39,14 +46,20 @@ func (h *Headscale) initOIDC() error {
ClientID: h.cfg.OIDC.ClientID,
ClientSecret: h.cfg.OIDC.ClientSecret,
Endpoint: h.oidcProvider.Endpoint(),
RedirectURL: fmt.Sprintf("%s/oidc/callback", strings.TrimSuffix(h.cfg.ServerURL, "/")),
RedirectURL: fmt.Sprintf(
"%s/oidc/callback",
strings.TrimSuffix(h.cfg.ServerURL, "/"),
),
Scopes: []string{oidc.ScopeOpenID, "profile", "email"},
}
}
// init the state cache if it hasn't been already
if h.oidcStateCache == nil {
h.oidcStateCache = cache.New(time.Minute*5, time.Minute*10)
h.oidcStateCache = cache.New(
oidcStateCacheExpiration,
oidcStateCacheCleanupInterval,
)
}
return nil
@ -54,50 +67,53 @@ func (h *Headscale) initOIDC() error {
// RegisterOIDC redirects to the OIDC provider for authentication
// Puts machine key in cache so the callback can retrieve it using the oidc state param
// Listens in /oidc/register/:mKey
func (h *Headscale) RegisterOIDC(c *gin.Context) {
mKeyStr := c.Param("mkey")
// Listens in /oidc/register/:mKey.
func (h *Headscale) RegisterOIDC(ctx *gin.Context) {
mKeyStr := ctx.Param("mkey")
if mKeyStr == "" {
c.String(http.StatusBadRequest, "Wrong params")
ctx.String(http.StatusBadRequest, "Wrong params")
return
}
b := make([]byte, 16)
_, err := rand.Read(b)
if err != nil {
randomBlob := make([]byte, randomByteSize)
if _, err := rand.Read(randomBlob); err != nil {
log.Error().Msg("could not read 16 bytes from rand")
c.String(http.StatusInternalServerError, "could not read 16 bytes from rand")
ctx.String(http.StatusInternalServerError, "could not read 16 bytes from rand")
return
}
stateStr := hex.EncodeToString(b)[:32]
stateStr := hex.EncodeToString(randomBlob)[:32]
// place the machine key into the state cache, so it can be retrieved later
h.oidcStateCache.Set(stateStr, mKeyStr, time.Minute*5)
h.oidcStateCache.Set(stateStr, mKeyStr, oidcStateCacheExpiration)
authUrl := h.oauth2Config.AuthCodeURL(stateStr)
log.Debug().Msgf("Redirecting to %s for authentication", authUrl)
authURL := h.oauth2Config.AuthCodeURL(stateStr)
log.Debug().Msgf("Redirecting to %s for authentication", authURL)
c.Redirect(http.StatusFound, authUrl)
ctx.Redirect(http.StatusFound, authURL)
}
// OIDCCallback handles the callback from the OIDC endpoint
// Retrieves the mkey from the state cache and adds the machine to the users email namespace
// TODO: A confirmation page for new machines should be added to avoid phishing vulnerabilities
// TODO: Add groups information from OIDC tokens into machine HostInfo
// Listens in /oidc/callback
func (h *Headscale) OIDCCallback(c *gin.Context) {
code := c.Query("code")
state := c.Query("state")
// Listens in /oidc/callback.
func (h *Headscale) OIDCCallback(ctx *gin.Context) {
code := ctx.Query("code")
state := ctx.Query("state")
if code == "" || state == "" {
c.String(http.StatusBadRequest, "Wrong params")
ctx.String(http.StatusBadRequest, "Wrong params")
return
}
oauth2Token, err := h.oauth2Config.Exchange(context.Background(), code)
if err != nil {
c.String(http.StatusBadRequest, "Could not exchange code for token")
ctx.String(http.StatusBadRequest, "Could not exchange code for token")
return
}
@ -105,7 +121,8 @@ func (h *Headscale) OIDCCallback(c *gin.Context) {
rawIDToken, rawIDTokenOK := oauth2Token.Extra("id_token").(string)
if !rawIDTokenOK {
c.String(http.StatusBadRequest, "Could not extract ID Token")
ctx.String(http.StatusBadRequest, "Could not extract ID Token")
return
}
@ -113,7 +130,8 @@ func (h *Headscale) OIDCCallback(c *gin.Context) {
idToken, err := verifier.Verify(context.Background(), rawIDToken)
if err != nil {
c.String(http.StatusBadRequest, "Failed to verify id token: %s", err.Error())
ctx.String(http.StatusBadRequest, "Failed to verify id token: %s", err.Error())
return
}
@ -127,7 +145,11 @@ func (h *Headscale) OIDCCallback(c *gin.Context) {
// Extract custom claims
var claims IDTokenClaims
if err = idToken.Claims(&claims); err != nil {
c.String(http.StatusBadRequest, fmt.Sprintf("Failed to decode id token claims: %s", err))
ctx.String(
http.StatusBadRequest,
fmt.Sprintf("Failed to decode id token claims: %s", err),
)
return
}
@ -135,62 +157,80 @@ func (h *Headscale) OIDCCallback(c *gin.Context) {
mKeyIf, mKeyFound := h.oidcStateCache.Get(state)
if !mKeyFound {
log.Error().Msg("requested machine state key expired before authorisation completed")
c.String(http.StatusBadRequest, "state has expired")
log.Error().
Msg("requested machine state key expired before authorisation completed")
ctx.String(http.StatusBadRequest, "state has expired")
return
}
mKeyStr, mKeyOK := mKeyIf.(string)
if !mKeyOK {
log.Error().Msg("could not get machine key from cache")
c.String(http.StatusInternalServerError, "could not get machine key from cache")
ctx.String(
http.StatusInternalServerError,
"could not get machine key from cache",
)
return
}
// retrieve machine information
m, err := h.GetMachineByMachineKey(mKeyStr)
machine, err := h.GetMachineByMachineKey(mKeyStr)
if err != nil {
log.Error().Msg("machine key not found in database")
c.String(http.StatusInternalServerError, "could not get machine info from database")
ctx.String(
http.StatusInternalServerError,
"could not get machine info from database",
)
return
}
now := time.Now().UTC()
if nsName, ok := h.getNamespaceFromEmail(claims.Email); ok {
if namespaceName, ok := h.getNamespaceFromEmail(claims.Email); ok {
// register the machine if it's new
if !m.Registered {
if !machine.Registered {
log.Debug().Msg("Registering new machine after successful callback")
ns, err := h.GetNamespace(nsName)
namespace, err := h.GetNamespace(namespaceName)
if err != nil {
ns, err = h.CreateNamespace(nsName)
namespace, err = h.CreateNamespace(namespaceName)
if err != nil {
log.Error().Msgf("could not create new namespace '%s'", claims.Email)
c.String(http.StatusInternalServerError, "could not create new namespace")
log.Error().
Msgf("could not create new namespace '%s'", claims.Email)
ctx.String(
http.StatusInternalServerError,
"could not create new namespace",
)
return
}
}
ip, err := h.getAvailableIP()
if err != nil {
c.String(http.StatusInternalServerError, "could not get an IP from the pool")
ctx.String(
http.StatusInternalServerError,
"could not get an IP from the pool",
)
return
}
m.IPAddress = ip.String()
m.NamespaceID = ns.ID
m.Registered = true
m.RegisterMethod = "oidc"
m.LastSuccessfulUpdate = &now
h.db.Save(&m)
machine.IPAddress = ip.String()
machine.NamespaceID = namespace.ID
machine.Registered = true
machine.RegisterMethod = "oidc"
machine.LastSuccessfulUpdate = &now
h.db.Save(&machine)
}
h.updateMachineExpiry(m)
h.updateMachineExpiry(machine)
c.Data(http.StatusOK, "text/html; charset=utf-8", []byte(fmt.Sprintf(`
ctx.Data(http.StatusOK, "text/html; charset=utf-8", []byte(fmt.Sprintf(`
<html>
<body>
<h1>headscale</h1>
@ -201,15 +241,17 @@ func (h *Headscale) OIDCCallback(c *gin.Context) {
</html>
`, claims.Email)))
}
log.Error().
Str("email", claims.Email).
Str("username", claims.Username).
Str("machine", m.Name).
Str("machine", machine.Name).
Msg("Email could not be mapped to a namespace")
c.String(http.StatusBadRequest, "email from claim could not be mapped to a namespace")
ctx.String(
http.StatusBadRequest,
"email from claim could not be mapped to a namespace",
)
}
// getNamespaceFromEmail passes the users email through a list of "matchers"

View file

@ -145,29 +145,37 @@ func TestHeadscale_getNamespaceFromEmail(t *testing.T) {
},
}
//nolint
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
h := &Headscale{
cfg: tt.fields.cfg,
db: tt.fields.db,
dbString: tt.fields.dbString,
dbType: tt.fields.dbType,
dbDebug: tt.fields.dbDebug,
publicKey: tt.fields.publicKey,
privateKey: tt.fields.privateKey,
aclPolicy: tt.fields.aclPolicy,
aclRules: tt.fields.aclRules,
lastStateChange: tt.fields.lastStateChange,
oidcProvider: tt.fields.oidcProvider,
oauth2Config: tt.fields.oauth2Config,
oidcStateCache: tt.fields.oidcStateCache,
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
app := &Headscale{
cfg: test.fields.cfg,
db: test.fields.db,
dbString: test.fields.dbString,
dbType: test.fields.dbType,
dbDebug: test.fields.dbDebug,
publicKey: test.fields.publicKey,
privateKey: test.fields.privateKey,
aclPolicy: test.fields.aclPolicy,
aclRules: test.fields.aclRules,
lastStateChange: test.fields.lastStateChange,
oidcProvider: test.fields.oidcProvider,
oauth2Config: test.fields.oauth2Config,
oidcStateCache: test.fields.oidcStateCache,
}
got, got1 := h.getNamespaceFromEmail(tt.args.email)
if got != tt.want {
t.Errorf("Headscale.getNamespaceFromEmail() got = %v, want %v", got, tt.want)
got, got1 := app.getNamespaceFromEmail(test.args.email)
if got != test.want {
t.Errorf(
"Headscale.getNamespaceFromEmail() got = %v, want %v",
got,
test.want,
)
}
if got1 != tt.want1 {
t.Errorf("Headscale.getNamespaceFromEmail() got1 = %v, want %v", got1, tt.want1)
if got1 != test.want1 {
t.Errorf(
"Headscale.getNamespaceFromEmail() got1 = %v, want %v",
got1,
test.want1,
)
}
})
}

263
poll.go
View file

@ -15,6 +15,11 @@ import (
"tailscale.com/types/wgkey"
)
const (
keepAliveInterval = 60 * time.Second
updateCheckInterval = 10 * time.Second
)
// PollNetMapHandler takes care of /machine/:id/map
//
// This is the busiest endpoint, as it keeps the HTTP long poll that updates
@ -24,20 +29,21 @@ import (
// only after their first request (marked with the ReadOnly field).
//
// At this moment the updates are sent in a quite horrendous way, but they kinda work.
func (h *Headscale) PollNetMapHandler(c *gin.Context) {
func (h *Headscale) PollNetMapHandler(ctx *gin.Context) {
log.Trace().
Str("handler", "PollNetMap").
Str("id", c.Param("id")).
Str("id", ctx.Param("id")).
Msg("PollNetMapHandler called")
body, _ := io.ReadAll(c.Request.Body)
mKeyStr := c.Param("id")
body, _ := io.ReadAll(ctx.Request.Body)
mKeyStr := ctx.Param("id")
mKey, err := wgkey.ParseHex(mKeyStr)
if err != nil {
log.Error().
Str("handler", "PollNetMap").
Err(err).
Msg("Cannot parse client key")
c.String(http.StatusBadRequest, "")
ctx.String(http.StatusBadRequest, "")
return
}
req := tailcfg.MapRequest{}
@ -47,34 +53,36 @@ func (h *Headscale) PollNetMapHandler(c *gin.Context) {
Str("handler", "PollNetMap").
Err(err).
Msg("Cannot decode message")
c.String(http.StatusBadRequest, "")
ctx.String(http.StatusBadRequest, "")
return
}
m, err := h.GetMachineByMachineKey(mKey.HexString())
machine, err := h.GetMachineByMachineKey(mKey.HexString())
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
log.Warn().
Str("handler", "PollNetMap").
Msgf("Ignoring request, cannot find machine with key %s", mKey.HexString())
c.String(http.StatusUnauthorized, "")
ctx.String(http.StatusUnauthorized, "")
return
}
log.Error().
Str("handler", "PollNetMap").
Msgf("Failed to fetch machine from the database with Machine key: %s", mKey.HexString())
c.String(http.StatusInternalServerError, "")
ctx.String(http.StatusInternalServerError, "")
}
log.Trace().
Str("handler", "PollNetMap").
Str("id", c.Param("id")).
Str("machine", m.Name).
Str("id", ctx.Param("id")).
Str("machine", machine.Name).
Msg("Found machine in database")
hostinfo, _ := json.Marshal(req.Hostinfo)
m.Name = req.Hostinfo.Hostname
m.HostInfo = datatypes.JSON(hostinfo)
m.DiscoKey = wgkey.Key(req.DiscoKey).HexString()
machine.Name = req.Hostinfo.Hostname
machine.HostInfo = datatypes.JSON(hostinfo)
machine.DiscoKey = wgkey.Key(req.DiscoKey).HexString()
now := time.Now().UTC()
// From Tailscale client:
@ -87,20 +95,21 @@ func (h *Headscale) PollNetMapHandler(c *gin.Context) {
// before their first real endpoint update.
if !req.ReadOnly {
endpoints, _ := json.Marshal(req.Endpoints)
m.Endpoints = datatypes.JSON(endpoints)
m.LastSeen = &now
machine.Endpoints = datatypes.JSON(endpoints)
machine.LastSeen = &now
}
h.db.Save(&m)
h.db.Save(&machine)
data, err := h.getMapResponse(mKey, req, m)
data, err := h.getMapResponse(mKey, req, machine)
if err != nil {
log.Error().
Str("handler", "PollNetMap").
Str("id", c.Param("id")).
Str("machine", m.Name).
Str("id", ctx.Param("id")).
Str("machine", machine.Name).
Err(err).
Msg("Failed to get Map response")
c.String(http.StatusInternalServerError, ":(")
ctx.String(http.StatusInternalServerError, ":(")
return
}
@ -111,8 +120,8 @@ func (h *Headscale) PollNetMapHandler(c *gin.Context) {
// Details on the protocol can be found in https://github.com/tailscale/tailscale/blob/main/tailcfg/tailcfg.go#L696
log.Debug().
Str("handler", "PollNetMap").
Str("id", c.Param("id")).
Str("machine", m.Name).
Str("id", ctx.Param("id")).
Str("machine", machine.Name).
Bool("readOnly", req.ReadOnly).
Bool("omitPeers", req.OmitPeers).
Bool("stream", req.Stream).
@ -121,15 +130,16 @@ func (h *Headscale) PollNetMapHandler(c *gin.Context) {
if req.ReadOnly {
log.Info().
Str("handler", "PollNetMap").
Str("machine", m.Name).
Str("machine", machine.Name).
Msg("Client is starting up. Probably interested in a DERP map")
c.Data(200, "application/json; charset=utf-8", data)
ctx.Data(http.StatusOK, "application/json; charset=utf-8", data)
return
}
// There has been an update to _any_ of the nodes that the other nodes would
// need to know about
h.setLastStateChangeToNow(m.Namespace.Name)
h.setLastStateChangeToNow(machine.Namespace.Name)
// The request is not ReadOnly, so we need to set up channels for updating
// peers via longpoll
@ -137,8 +147,8 @@ func (h *Headscale) PollNetMapHandler(c *gin.Context) {
// Only create update channel if it has not been created
log.Trace().
Str("handler", "PollNetMap").
Str("id", c.Param("id")).
Str("machine", m.Name).
Str("id", ctx.Param("id")).
Str("machine", machine.Name).
Msg("Loading or creating update channel")
updateChan := make(chan struct{})
@ -152,46 +162,59 @@ func (h *Headscale) PollNetMapHandler(c *gin.Context) {
if req.OmitPeers && !req.Stream {
log.Info().
Str("handler", "PollNetMap").
Str("machine", m.Name).
Str("machine", machine.Name).
Msg("Client sent endpoint update and is ok with a response without peer list")
c.Data(200, "application/json; charset=utf-8", data)
ctx.Data(http.StatusOK, "application/json; charset=utf-8", data)
// It sounds like we should update the nodes when we have received a endpoint update
// even tho the comments in the tailscale code dont explicitly say so.
updateRequestsFromNode.WithLabelValues(m.Name, m.Namespace.Name, "endpoint-update").Inc()
updateRequestsFromNode.WithLabelValues(machine.Name, machine.Namespace.Name, "endpoint-update").
Inc()
go func() { updateChan <- struct{}{} }()
return
} else if req.OmitPeers && req.Stream {
log.Warn().
Str("handler", "PollNetMap").
Str("machine", m.Name).
Str("machine", machine.Name).
Msg("Ignoring request, don't know how to handle it")
c.String(http.StatusBadRequest, "")
ctx.String(http.StatusBadRequest, "")
return
}
log.Info().
Str("handler", "PollNetMap").
Str("machine", m.Name).
Str("machine", machine.Name).
Msg("Client is ready to access the tailnet")
log.Info().
Str("handler", "PollNetMap").
Str("machine", m.Name).
Str("machine", machine.Name).
Msg("Sending initial map")
go func() { pollDataChan <- data }()
log.Info().
Str("handler", "PollNetMap").
Str("machine", m.Name).
Str("machine", machine.Name).
Msg("Notifying peers")
updateRequestsFromNode.WithLabelValues(m.Name, m.Namespace.Name, "full-update").Inc()
updateRequestsFromNode.WithLabelValues(machine.Name, machine.Namespace.Name, "full-update").
Inc()
go func() { updateChan <- struct{}{} }()
h.PollNetMapStream(c, m, req, mKey, pollDataChan, keepAliveChan, updateChan, cancelKeepAlive)
h.PollNetMapStream(
ctx,
machine,
req,
mKey,
pollDataChan,
keepAliveChan,
updateChan,
cancelKeepAlive,
)
log.Trace().
Str("handler", "PollNetMap").
Str("id", c.Param("id")).
Str("machine", m.Name).
Str("id", ctx.Param("id")).
Str("machine", machine.Name).
Msg("Finished stream, closing PollNetMap session")
}
@ -199,165 +222,181 @@ func (h *Headscale) PollNetMapHandler(c *gin.Context) {
// stream logic, ensuring we communicate updates and data
// to the connected clients.
func (h *Headscale) PollNetMapStream(
c *gin.Context,
m *Machine,
req tailcfg.MapRequest,
mKey wgkey.Key,
ctx *gin.Context,
machine *Machine,
mapRequest tailcfg.MapRequest,
machineKey wgkey.Key,
pollDataChan chan []byte,
keepAliveChan chan []byte,
updateChan chan struct{},
cancelKeepAlive chan struct{},
) {
go h.scheduledPollWorker(cancelKeepAlive, updateChan, keepAliveChan, mKey, req, m)
go h.scheduledPollWorker(
cancelKeepAlive,
updateChan,
keepAliveChan,
machineKey,
mapRequest,
machine,
)
c.Stream(func(w io.Writer) bool {
ctx.Stream(func(writer io.Writer) bool {
log.Trace().
Str("handler", "PollNetMapStream").
Str("machine", m.Name).
Str("machine", machine.Name).
Msg("Waiting for data to stream...")
log.Trace().
Str("handler", "PollNetMapStream").
Str("machine", m.Name).
Str("machine", machine.Name).
Msgf("pollData is %#v, keepAliveChan is %#v, updateChan is %#v", pollDataChan, keepAliveChan, updateChan)
select {
case data := <-pollDataChan:
log.Trace().
Str("handler", "PollNetMapStream").
Str("machine", m.Name).
Str("machine", machine.Name).
Str("channel", "pollData").
Int("bytes", len(data)).
Msg("Sending data received via pollData channel")
_, err := w.Write(data)
_, err := writer.Write(data)
if err != nil {
log.Error().
Str("handler", "PollNetMapStream").
Str("machine", m.Name).
Str("machine", machine.Name).
Str("channel", "pollData").
Err(err).
Msg("Cannot write data")
return false
}
log.Trace().
Str("handler", "PollNetMapStream").
Str("machine", m.Name).
Str("machine", machine.Name).
Str("channel", "pollData").
Int("bytes", len(data)).
Msg("Data from pollData channel written successfully")
// TODO(kradalby): Abstract away all the database calls, this can cause race conditions
// when an outdated machine object is kept alive, e.g. db is update from
// command line, but then overwritten.
err = h.UpdateMachine(m)
err = h.UpdateMachine(machine)
if err != nil {
log.Error().
Str("handler", "PollNetMapStream").
Str("machine", m.Name).
Str("machine", machine.Name).
Str("channel", "pollData").
Err(err).
Msg("Cannot update machine from database")
}
now := time.Now().UTC()
m.LastSeen = &now
machine.LastSeen = &now
lastStateUpdate.WithLabelValues(m.Namespace.Name, m.Name).Set(float64(now.Unix()))
m.LastSuccessfulUpdate = &now
lastStateUpdate.WithLabelValues(machine.Namespace.Name, machine.Name).
Set(float64(now.Unix()))
machine.LastSuccessfulUpdate = &now
h.db.Save(&m)
h.db.Save(&machine)
log.Trace().
Str("handler", "PollNetMapStream").
Str("machine", m.Name).
Str("machine", machine.Name).
Str("channel", "pollData").
Int("bytes", len(data)).
Msg("Machine entry in database updated successfully after sending pollData")
return true
case data := <-keepAliveChan:
log.Trace().
Str("handler", "PollNetMapStream").
Str("machine", m.Name).
Str("machine", machine.Name).
Str("channel", "keepAlive").
Int("bytes", len(data)).
Msg("Sending keep alive message")
_, err := w.Write(data)
_, err := writer.Write(data)
if err != nil {
log.Error().
Str("handler", "PollNetMapStream").
Str("machine", m.Name).
Str("machine", machine.Name).
Str("channel", "keepAlive").
Err(err).
Msg("Cannot write keep alive message")
return false
}
log.Trace().
Str("handler", "PollNetMapStream").
Str("machine", m.Name).
Str("machine", machine.Name).
Str("channel", "keepAlive").
Int("bytes", len(data)).
Msg("Keep alive sent successfully")
// TODO(kradalby): Abstract away all the database calls, this can cause race conditions
// when an outdated machine object is kept alive, e.g. db is update from
// command line, but then overwritten.
err = h.UpdateMachine(m)
err = h.UpdateMachine(machine)
if err != nil {
log.Error().
Str("handler", "PollNetMapStream").
Str("machine", m.Name).
Str("machine", machine.Name).
Str("channel", "keepAlive").
Err(err).
Msg("Cannot update machine from database")
}
now := time.Now().UTC()
m.LastSeen = &now
h.db.Save(&m)
machine.LastSeen = &now
h.db.Save(&machine)
log.Trace().
Str("handler", "PollNetMapStream").
Str("machine", m.Name).
Str("machine", machine.Name).
Str("channel", "keepAlive").
Int("bytes", len(data)).
Msg("Machine updated successfully after sending keep alive")
return true
case <-updateChan:
log.Trace().
Str("handler", "PollNetMapStream").
Str("machine", m.Name).
Str("machine", machine.Name).
Str("channel", "update").
Msg("Received a request for update")
updateRequestsReceivedOnChannel.WithLabelValues(m.Name, m.Namespace.Name).Inc()
if h.isOutdated(m) {
updateRequestsReceivedOnChannel.WithLabelValues(machine.Name, machine.Namespace.Name).
Inc()
if h.isOutdated(machine) {
log.Debug().
Str("handler", "PollNetMapStream").
Str("machine", m.Name).
Time("last_successful_update", *m.LastSuccessfulUpdate).
Time("last_state_change", h.getLastStateChange(m.Namespace.Name)).
Msgf("There has been updates since the last successful update to %s", m.Name)
data, err := h.getMapResponse(mKey, req, m)
Str("machine", machine.Name).
Time("last_successful_update", *machine.LastSuccessfulUpdate).
Time("last_state_change", h.getLastStateChange(machine.Namespace.Name)).
Msgf("There has been updates since the last successful update to %s", machine.Name)
data, err := h.getMapResponse(machineKey, mapRequest, machine)
if err != nil {
log.Error().
Str("handler", "PollNetMapStream").
Str("machine", m.Name).
Str("machine", machine.Name).
Str("channel", "update").
Err(err).
Msg("Could not get the map update")
}
_, err = w.Write(data)
_, err = writer.Write(data)
if err != nil {
log.Error().
Str("handler", "PollNetMapStream").
Str("machine", m.Name).
Str("machine", machine.Name).
Str("channel", "update").
Err(err).
Msg("Could not write the map response")
updateRequestsSentToNode.WithLabelValues(m.Name, m.Namespace.Name, "failed").Inc()
updateRequestsSentToNode.WithLabelValues(machine.Name, machine.Namespace.Name, "failed").
Inc()
return false
}
log.Trace().
Str("handler", "PollNetMapStream").
Str("machine", m.Name).
Str("machine", machine.Name).
Str("channel", "update").
Msg("Updated Map has been sent")
updateRequestsSentToNode.WithLabelValues(m.Name, m.Namespace.Name, "success").Inc()
updateRequestsSentToNode.WithLabelValues(machine.Name, machine.Namespace.Name, "success").
Inc()
// Keep track of the last successful update,
// we sometimes end in a state were the update
@ -366,62 +405,64 @@ func (h *Headscale) PollNetMapStream(
// TODO(kradalby): Abstract away all the database calls, this can cause race conditions
// when an outdated machine object is kept alive, e.g. db is update from
// command line, but then overwritten.
err = h.UpdateMachine(m)
err = h.UpdateMachine(machine)
if err != nil {
log.Error().
Str("handler", "PollNetMapStream").
Str("machine", m.Name).
Str("machine", machine.Name).
Str("channel", "update").
Err(err).
Msg("Cannot update machine from database")
}
now := time.Now().UTC()
lastStateUpdate.WithLabelValues(m.Namespace.Name, m.Name).Set(float64(now.Unix()))
m.LastSuccessfulUpdate = &now
lastStateUpdate.WithLabelValues(machine.Namespace.Name, machine.Name).
Set(float64(now.Unix()))
machine.LastSuccessfulUpdate = &now
h.db.Save(&m)
h.db.Save(&machine)
} else {
log.Trace().
Str("handler", "PollNetMapStream").
Str("machine", m.Name).
Time("last_successful_update", *m.LastSuccessfulUpdate).
Time("last_state_change", h.getLastStateChange(m.Namespace.Name)).
Msgf("%s is up to date", m.Name)
Str("machine", machine.Name).
Time("last_successful_update", *machine.LastSuccessfulUpdate).
Time("last_state_change", h.getLastStateChange(machine.Namespace.Name)).
Msgf("%s is up to date", machine.Name)
}
return true
case <-c.Request.Context().Done():
case <-ctx.Request.Context().Done():
log.Info().
Str("handler", "PollNetMapStream").
Str("machine", m.Name).
Str("machine", machine.Name).
Msg("The client has closed the connection")
// TODO: Abstract away all the database calls, this can cause race conditions
// when an outdated machine object is kept alive, e.g. db is update from
// command line, but then overwritten.
err := h.UpdateMachine(m)
err := h.UpdateMachine(machine)
if err != nil {
log.Error().
Str("handler", "PollNetMapStream").
Str("machine", m.Name).
Str("machine", machine.Name).
Str("channel", "Done").
Err(err).
Msg("Cannot update machine from database")
}
now := time.Now().UTC()
m.LastSeen = &now
h.db.Save(&m)
machine.LastSeen = &now
h.db.Save(&machine)
log.Trace().
Str("handler", "PollNetMapStream").
Str("machine", m.Name).
Str("machine", machine.Name).
Str("channel", "Done").
Msg("Cancelling keepAlive channel")
cancelKeepAlive <- struct{}{}
log.Trace().
Str("handler", "PollNetMapStream").
Str("machine", m.Name).
Str("machine", machine.Name).
Str("channel", "Done").
Msg("Closing update channel")
// h.closeUpdateChannel(m)
@ -429,14 +470,14 @@ func (h *Headscale) PollNetMapStream(
log.Trace().
Str("handler", "PollNetMapStream").
Str("machine", m.Name).
Str("machine", machine.Name).
Str("channel", "Done").
Msg("Closing pollData channel")
close(pollDataChan)
log.Trace().
Str("handler", "PollNetMapStream").
Str("machine", m.Name).
Str("machine", machine.Name).
Str("channel", "Done").
Msg("Closing keepAliveChan channel")
close(keepAliveChan)
@ -450,12 +491,12 @@ func (h *Headscale) scheduledPollWorker(
cancelChan <-chan struct{},
updateChan chan<- struct{},
keepAliveChan chan<- []byte,
mKey wgkey.Key,
req tailcfg.MapRequest,
m *Machine,
machineKey wgkey.Key,
mapRequest tailcfg.MapRequest,
machine *Machine,
) {
keepAliveTicker := time.NewTicker(60 * time.Second)
updateCheckerTicker := time.NewTicker(10 * time.Second)
keepAliveTicker := time.NewTicker(keepAliveInterval)
updateCheckerTicker := time.NewTicker(updateCheckInterval)
for {
select {
@ -463,27 +504,29 @@ func (h *Headscale) scheduledPollWorker(
return
case <-keepAliveTicker.C:
data, err := h.getMapKeepAliveResponse(mKey, req, m)
data, err := h.getMapKeepAliveResponse(machineKey, mapRequest)
if err != nil {
log.Error().
Str("func", "keepAlive").
Err(err).
Msg("Error generating the keep alive msg")
return
}
log.Debug().
Str("func", "keepAlive").
Str("machine", m.Name).
Str("machine", machine.Name).
Msg("Sending keepalive")
keepAliveChan <- data
case <-updateCheckerTicker.C:
log.Debug().
Str("func", "scheduledPollWorker").
Str("machine", m.Name).
Str("machine", machine.Name).
Msg("Sending update request")
updateRequestsFromNode.WithLabelValues(m.Name, m.Namespace.Name, "scheduled-update").Inc()
updateRequestsFromNode.WithLabelValues(machine.Name, machine.Namespace.Name, "scheduled-update").
Inc()
updateChan <- struct{}{}
}
}

View file

@ -7,19 +7,19 @@ import (
"strconv"
"time"
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
"google.golang.org/protobuf/types/known/timestamppb"
"gorm.io/gorm"
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
)
const (
errorAuthKeyNotFound = Error("AuthKey not found")
errorAuthKeyExpired = Error("AuthKey expired")
errPreAuthKeyNotFound = Error("AuthKey not found")
errPreAuthKeyExpired = Error("AuthKey expired")
errSingleUseAuthKeyHasBeenUsed = Error("AuthKey has already been used")
errNamespaceMismatch = Error("namespace mismatch")
)
// PreAuthKey describes a pre-authorization key usable in a particular namespace
// PreAuthKey describes a pre-authorization key usable in a particular namespace.
type PreAuthKey struct {
ID uint64 `gorm:"primary_key"`
Key string
@ -33,14 +33,14 @@ type PreAuthKey struct {
Expiration *time.Time
}
// CreatePreAuthKey creates a new PreAuthKey in a namespace, and returns it
// CreatePreAuthKey creates a new PreAuthKey in a namespace, and returns it.
func (h *Headscale) CreatePreAuthKey(
namespaceName string,
reusable bool,
ephemeral bool,
expiration *time.Time,
) (*PreAuthKey, error) {
n, err := h.GetNamespace(namespaceName)
namespace, err := h.GetNamespace(namespaceName)
if err != nil {
return nil, err
}
@ -51,35 +51,36 @@ func (h *Headscale) CreatePreAuthKey(
return nil, err
}
k := PreAuthKey{
key := PreAuthKey{
Key: kstr,
NamespaceID: n.ID,
Namespace: *n,
NamespaceID: namespace.ID,
Namespace: *namespace,
Reusable: reusable,
Ephemeral: ephemeral,
CreatedAt: &now,
Expiration: expiration,
}
h.db.Save(&k)
h.db.Save(&key)
return &k, nil
return &key, nil
}
// ListPreAuthKeys returns the list of PreAuthKeys for a namespace
// ListPreAuthKeys returns the list of PreAuthKeys for a namespace.
func (h *Headscale) ListPreAuthKeys(namespaceName string) ([]PreAuthKey, error) {
n, err := h.GetNamespace(namespaceName)
namespace, err := h.GetNamespace(namespaceName)
if err != nil {
return nil, err
}
keys := []PreAuthKey{}
if err := h.db.Preload("Namespace").Where(&PreAuthKey{NamespaceID: n.ID}).Find(&keys).Error; err != nil {
if err := h.db.Preload("Namespace").Where(&PreAuthKey{NamespaceID: namespace.ID}).Find(&keys).Error; err != nil {
return nil, err
}
return keys, nil
}
// GetPreAuthKey returns a PreAuthKey for a given key
// GetPreAuthKey returns a PreAuthKey for a given key.
func (h *Headscale) GetPreAuthKey(namespace string, key string) (*PreAuthKey, error) {
pak, err := h.checkKeyValidity(key)
if err != nil {
@ -87,7 +88,7 @@ func (h *Headscale) GetPreAuthKey(namespace string, key string) (*PreAuthKey, er
}
if pak.Namespace.Name != namespace {
return nil, errors.New("Namespace mismatch")
return nil, errNamespaceMismatch
}
return pak, nil
@ -95,32 +96,36 @@ func (h *Headscale) GetPreAuthKey(namespace string, key string) (*PreAuthKey, er
// DestroyPreAuthKey destroys a preauthkey. Returns error if the PreAuthKey
// does not exist.
func (h *Headscale) DestroyPreAuthKey(pak *PreAuthKey) error {
if result := h.db.Unscoped().Delete(&pak); result.Error != nil {
func (h *Headscale) DestroyPreAuthKey(pak PreAuthKey) error {
if result := h.db.Unscoped().Delete(pak); result.Error != nil {
return result.Error
}
return nil
}
// MarkExpirePreAuthKey marks a PreAuthKey as expired
// MarkExpirePreAuthKey marks a PreAuthKey as expired.
func (h *Headscale) ExpirePreAuthKey(k *PreAuthKey) error {
if err := h.db.Model(&k).Update("Expiration", time.Now()).Error; err != nil {
return err
}
return nil
}
// checkKeyValidity does the heavy lifting for validation of the PreAuthKey coming from a node
// If returns no error and a PreAuthKey, it can be used
// If returns no error and a PreAuthKey, it can be used.
func (h *Headscale) checkKeyValidity(k string) (*PreAuthKey, error) {
pak := PreAuthKey{}
if result := h.db.Preload("Namespace").First(&pak, "key = ?", k); errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, errorAuthKeyNotFound
if result := h.db.Preload("Namespace").First(&pak, "key = ?", k); errors.Is(
result.Error,
gorm.ErrRecordNotFound,
) {
return nil, errPreAuthKeyNotFound
}
if pak.Expiration != nil && pak.Expiration.Before(time.Now()) {
return nil, errorAuthKeyExpired
return nil, errPreAuthKeyExpired
}
if pak.Reusable || pak.Ephemeral { // we don't need to check if has been used before
@ -145,13 +150,14 @@ func (h *Headscale) generateKey() (string, error) {
if _, err := rand.Read(bytes); err != nil {
return "", err
}
return hex.EncodeToString(bytes), nil
}
func (key *PreAuthKey) toProto() *v1.PreAuthKey {
protoKey := v1.PreAuthKey{
Namespace: key.Namespace.Name,
Id: strconv.FormatUint(key.ID, 10),
Id: strconv.FormatUint(key.ID, Base10),
Key: key.Key,
Ephemeral: key.Ephemeral,
Reusable: key.Reusable,

View file

@ -7,189 +7,189 @@ import (
)
func (*Suite) TestCreatePreAuthKey(c *check.C) {
_, err := h.CreatePreAuthKey("bogus", true, false, nil)
_, err := app.CreatePreAuthKey("bogus", true, false, nil)
c.Assert(err, check.NotNil)
n, err := h.CreateNamespace("test")
namespace, err := app.CreateNamespace("test")
c.Assert(err, check.IsNil)
k, err := h.CreatePreAuthKey(n.Name, true, false, nil)
key, err := app.CreatePreAuthKey(namespace.Name, true, false, nil)
c.Assert(err, check.IsNil)
// Did we get a valid key?
c.Assert(k.Key, check.NotNil)
c.Assert(len(k.Key), check.Equals, 48)
c.Assert(key.Key, check.NotNil)
c.Assert(len(key.Key), check.Equals, 48)
// Make sure the Namespace association is populated
c.Assert(k.Namespace.Name, check.Equals, n.Name)
c.Assert(key.Namespace.Name, check.Equals, namespace.Name)
_, err = h.ListPreAuthKeys("bogus")
_, err = app.ListPreAuthKeys("bogus")
c.Assert(err, check.NotNil)
keys, err := h.ListPreAuthKeys(n.Name)
keys, err := app.ListPreAuthKeys(namespace.Name)
c.Assert(err, check.IsNil)
c.Assert(len(keys), check.Equals, 1)
// Make sure the Namespace association is populated
c.Assert((keys)[0].Namespace.Name, check.Equals, n.Name)
c.Assert((keys)[0].Namespace.Name, check.Equals, namespace.Name)
}
func (*Suite) TestExpiredPreAuthKey(c *check.C) {
n, err := h.CreateNamespace("test2")
namespace, err := app.CreateNamespace("test2")
c.Assert(err, check.IsNil)
now := time.Now()
pak, err := h.CreatePreAuthKey(n.Name, true, false, &now)
pak, err := app.CreatePreAuthKey(namespace.Name, true, false, &now)
c.Assert(err, check.IsNil)
p, err := h.checkKeyValidity(pak.Key)
c.Assert(err, check.Equals, errorAuthKeyExpired)
c.Assert(p, check.IsNil)
key, err := app.checkKeyValidity(pak.Key)
c.Assert(err, check.Equals, errPreAuthKeyExpired)
c.Assert(key, check.IsNil)
}
func (*Suite) TestPreAuthKeyDoesNotExist(c *check.C) {
p, err := h.checkKeyValidity("potatoKey")
c.Assert(err, check.Equals, errorAuthKeyNotFound)
c.Assert(p, check.IsNil)
key, err := app.checkKeyValidity("potatoKey")
c.Assert(err, check.Equals, errPreAuthKeyNotFound)
c.Assert(key, check.IsNil)
}
func (*Suite) TestValidateKeyOk(c *check.C) {
n, err := h.CreateNamespace("test3")
namespace, err := app.CreateNamespace("test3")
c.Assert(err, check.IsNil)
pak, err := h.CreatePreAuthKey(n.Name, true, false, nil)
pak, err := app.CreatePreAuthKey(namespace.Name, true, false, nil)
c.Assert(err, check.IsNil)
p, err := h.checkKeyValidity(pak.Key)
key, err := app.checkKeyValidity(pak.Key)
c.Assert(err, check.IsNil)
c.Assert(p.ID, check.Equals, pak.ID)
c.Assert(key.ID, check.Equals, pak.ID)
}
func (*Suite) TestAlreadyUsedKey(c *check.C) {
n, err := h.CreateNamespace("test4")
namespace, err := app.CreateNamespace("test4")
c.Assert(err, check.IsNil)
pak, err := h.CreatePreAuthKey(n.Name, false, false, nil)
pak, err := app.CreatePreAuthKey(namespace.Name, false, false, nil)
c.Assert(err, check.IsNil)
m := Machine{
machine := Machine{
ID: 0,
MachineKey: "foo",
NodeKey: "bar",
DiscoKey: "faa",
Name: "testest",
NamespaceID: n.ID,
NamespaceID: namespace.ID,
Registered: true,
RegisterMethod: "authKey",
AuthKeyID: uint(pak.ID),
}
h.db.Save(&m)
app.db.Save(&machine)
p, err := h.checkKeyValidity(pak.Key)
key, err := app.checkKeyValidity(pak.Key)
c.Assert(err, check.Equals, errSingleUseAuthKeyHasBeenUsed)
c.Assert(p, check.IsNil)
c.Assert(key, check.IsNil)
}
func (*Suite) TestReusableBeingUsedKey(c *check.C) {
n, err := h.CreateNamespace("test5")
namespace, err := app.CreateNamespace("test5")
c.Assert(err, check.IsNil)
pak, err := h.CreatePreAuthKey(n.Name, true, false, nil)
pak, err := app.CreatePreAuthKey(namespace.Name, true, false, nil)
c.Assert(err, check.IsNil)
m := Machine{
machine := Machine{
ID: 1,
MachineKey: "foo",
NodeKey: "bar",
DiscoKey: "faa",
Name: "testest",
NamespaceID: n.ID,
NamespaceID: namespace.ID,
Registered: true,
RegisterMethod: "authKey",
AuthKeyID: uint(pak.ID),
}
h.db.Save(&m)
app.db.Save(&machine)
p, err := h.checkKeyValidity(pak.Key)
key, err := app.checkKeyValidity(pak.Key)
c.Assert(err, check.IsNil)
c.Assert(p.ID, check.Equals, pak.ID)
c.Assert(key.ID, check.Equals, pak.ID)
}
func (*Suite) TestNotReusableNotBeingUsedKey(c *check.C) {
n, err := h.CreateNamespace("test6")
namespace, err := app.CreateNamespace("test6")
c.Assert(err, check.IsNil)
pak, err := h.CreatePreAuthKey(n.Name, false, false, nil)
pak, err := app.CreatePreAuthKey(namespace.Name, false, false, nil)
c.Assert(err, check.IsNil)
p, err := h.checkKeyValidity(pak.Key)
key, err := app.checkKeyValidity(pak.Key)
c.Assert(err, check.IsNil)
c.Assert(p.ID, check.Equals, pak.ID)
c.Assert(key.ID, check.Equals, pak.ID)
}
func (*Suite) TestEphemeralKey(c *check.C) {
n, err := h.CreateNamespace("test7")
namespace, err := app.CreateNamespace("test7")
c.Assert(err, check.IsNil)
pak, err := h.CreatePreAuthKey(n.Name, false, true, nil)
pak, err := app.CreatePreAuthKey(namespace.Name, false, true, nil)
c.Assert(err, check.IsNil)
now := time.Now()
m := Machine{
machine := Machine{
ID: 0,
MachineKey: "foo",
NodeKey: "bar",
DiscoKey: "faa",
Name: "testest",
NamespaceID: n.ID,
NamespaceID: namespace.ID,
Registered: true,
RegisterMethod: "authKey",
LastSeen: &now,
AuthKeyID: uint(pak.ID),
}
h.db.Save(&m)
app.db.Save(&machine)
_, err = h.checkKeyValidity(pak.Key)
_, err = app.checkKeyValidity(pak.Key)
// Ephemeral keys are by definition reusable
c.Assert(err, check.IsNil)
_, err = h.GetMachine("test7", "testest")
_, err = app.GetMachine("test7", "testest")
c.Assert(err, check.IsNil)
h.expireEphemeralNodesWorker()
app.expireEphemeralNodesWorker()
// The machine record should have been deleted
_, err = h.GetMachine("test7", "testest")
_, err = app.GetMachine("test7", "testest")
c.Assert(err, check.NotNil)
}
func (*Suite) TestExpirePreauthKey(c *check.C) {
n, err := h.CreateNamespace("test3")
namespace, err := app.CreateNamespace("test3")
c.Assert(err, check.IsNil)
pak, err := h.CreatePreAuthKey(n.Name, true, false, nil)
pak, err := app.CreatePreAuthKey(namespace.Name, true, false, nil)
c.Assert(err, check.IsNil)
c.Assert(pak.Expiration, check.IsNil)
err = h.ExpirePreAuthKey(pak)
err = app.ExpirePreAuthKey(pak)
c.Assert(err, check.IsNil)
c.Assert(pak.Expiration, check.NotNil)
p, err := h.checkKeyValidity(pak.Key)
c.Assert(err, check.Equals, errorAuthKeyExpired)
c.Assert(p, check.IsNil)
key, err := app.checkKeyValidity(pak.Key)
c.Assert(err, check.Equals, errPreAuthKeyExpired)
c.Assert(key, check.IsNil)
}
func (*Suite) TestNotReusableMarkedAsUsed(c *check.C) {
n, err := h.CreateNamespace("test6")
namespace, err := app.CreateNamespace("test6")
c.Assert(err, check.IsNil)
pak, err := h.CreatePreAuthKey(n.Name, false, false, nil)
pak, err := app.CreatePreAuthKey(namespace.Name, false, false, nil)
c.Assert(err, check.IsNil)
pak.Used = true
h.db.Save(&pak)
app.db.Save(&pak)
_, err = h.checkKeyValidity(pak.Key)
_, err = app.checkKeyValidity(pak.Key)
c.Assert(err, check.Equals, errSingleUseAuthKeyHasBeenUsed)
}

View file

@ -2,38 +2,48 @@ package headscale
import (
"encoding/json"
"fmt"
"gorm.io/datatypes"
"inet.af/netaddr"
)
const (
errRouteIsNotAvailable = Error("route is not available")
)
// Deprecated: use machine function instead
// GetAdvertisedNodeRoutes returns the subnet routes advertised by a node (identified by
// namespace and node name)
func (h *Headscale) GetAdvertisedNodeRoutes(namespace string, nodeName string) (*[]netaddr.IPPrefix, error) {
m, err := h.GetMachine(namespace, nodeName)
// namespace and node name).
func (h *Headscale) GetAdvertisedNodeRoutes(
namespace string,
nodeName string,
) (*[]netaddr.IPPrefix, error) {
machine, err := h.GetMachine(namespace, nodeName)
if err != nil {
return nil, err
}
hostInfo, err := m.GetHostInfo()
hostInfo, err := machine.GetHostInfo()
if err != nil {
return nil, err
}
return &hostInfo.RoutableIPs, nil
}
// Deprecated: use machine function instead
// GetEnabledNodeRoutes returns the subnet routes enabled by a node (identified by
// namespace and node name)
func (h *Headscale) GetEnabledNodeRoutes(namespace string, nodeName string) ([]netaddr.IPPrefix, error) {
m, err := h.GetMachine(namespace, nodeName)
// namespace and node name).
func (h *Headscale) GetEnabledNodeRoutes(
namespace string,
nodeName string,
) ([]netaddr.IPPrefix, error) {
machine, err := h.GetMachine(namespace, nodeName)
if err != nil {
return nil, err
}
data, err := m.EnabledRoutes.MarshalJSON()
data, err := machine.EnabledRoutes.MarshalJSON()
if err != nil {
return nil, err
}
@ -57,8 +67,12 @@ func (h *Headscale) GetEnabledNodeRoutes(namespace string, nodeName string) ([]n
}
// Deprecated: use machine function instead
// IsNodeRouteEnabled checks if a certain route has been enabled
func (h *Headscale) IsNodeRouteEnabled(namespace string, nodeName string, routeStr string) bool {
// IsNodeRouteEnabled checks if a certain route has been enabled.
func (h *Headscale) IsNodeRouteEnabled(
namespace string,
nodeName string,
routeStr string,
) bool {
route, err := netaddr.ParseIPPrefix(routeStr)
if err != nil {
return false
@ -74,14 +88,19 @@ func (h *Headscale) IsNodeRouteEnabled(namespace string, nodeName string, routeS
return true
}
}
return false
}
// Deprecated: use EnableRoute in machine.go
// EnableNodeRoute enables a subnet route advertised by a node (identified by
// namespace and node name)
func (h *Headscale) EnableNodeRoute(namespace string, nodeName string, routeStr string) error {
m, err := h.GetMachine(namespace, nodeName)
// namespace and node name).
func (h *Headscale) EnableNodeRoute(
namespace string,
nodeName string,
routeStr string,
) error {
machine, err := h.GetMachine(namespace, nodeName)
if err != nil {
return err
}
@ -113,7 +132,7 @@ func (h *Headscale) EnableNodeRoute(namespace string, nodeName string, routeStr
}
if !available {
return fmt.Errorf("route (%s) is not available on node %s", nodeName, routeStr)
return errRouteIsNotAvailable
}
routes, err := json.Marshal(enabledRoutes)
@ -121,10 +140,10 @@ func (h *Headscale) EnableNodeRoute(namespace string, nodeName string, routeStr
return err
}
m.EnabledRoutes = datatypes.JSON(routes)
h.db.Save(&m)
machine.EnabledRoutes = datatypes.JSON(routes)
h.db.Save(&machine)
err = h.RequestMapUpdates(m.NamespaceID)
err = h.RequestMapUpdates(machine.NamespaceID)
if err != nil {
return err
}

View file

@ -10,57 +10,60 @@ import (
)
func (s *Suite) TestGetRoutes(c *check.C) {
n, err := h.CreateNamespace("test")
namespace, err := app.CreateNamespace("test")
c.Assert(err, check.IsNil)
pak, err := h.CreatePreAuthKey(n.Name, false, false, nil)
pak, err := app.CreatePreAuthKey(namespace.Name, false, false, nil)
c.Assert(err, check.IsNil)
_, err = h.GetMachine("test", "test_get_route_machine")
_, err = app.GetMachine("test", "test_get_route_machine")
c.Assert(err, check.NotNil)
route, err := netaddr.ParseIPPrefix("10.0.0.0/24")
c.Assert(err, check.IsNil)
hi := tailcfg.Hostinfo{
hostInfo := tailcfg.Hostinfo{
RoutableIPs: []netaddr.IPPrefix{route},
}
hostinfo, err := json.Marshal(hi)
hostinfo, err := json.Marshal(hostInfo)
c.Assert(err, check.IsNil)
m := Machine{
machine := Machine{
ID: 0,
MachineKey: "foo",
NodeKey: "bar",
DiscoKey: "faa",
Name: "test_get_route_machine",
NamespaceID: n.ID,
NamespaceID: namespace.ID,
Registered: true,
RegisterMethod: "authKey",
AuthKeyID: uint(pak.ID),
HostInfo: datatypes.JSON(hostinfo),
}
h.db.Save(&m)
app.db.Save(&machine)
r, err := h.GetAdvertisedNodeRoutes("test", "test_get_route_machine")
advertisedRoutes, err := app.GetAdvertisedNodeRoutes(
"test",
"test_get_route_machine",
)
c.Assert(err, check.IsNil)
c.Assert(len(*r), check.Equals, 1)
c.Assert(len(*advertisedRoutes), check.Equals, 1)
err = h.EnableNodeRoute("test", "test_get_route_machine", "192.168.0.0/24")
err = app.EnableNodeRoute("test", "test_get_route_machine", "192.168.0.0/24")
c.Assert(err, check.NotNil)
err = h.EnableNodeRoute("test", "test_get_route_machine", "10.0.0.0/24")
err = app.EnableNodeRoute("test", "test_get_route_machine", "10.0.0.0/24")
c.Assert(err, check.IsNil)
}
func (s *Suite) TestGetEnableRoutes(c *check.C) {
n, err := h.CreateNamespace("test")
namespace, err := app.CreateNamespace("test")
c.Assert(err, check.IsNil)
pak, err := h.CreatePreAuthKey(n.Name, false, false, nil)
pak, err := app.CreatePreAuthKey(namespace.Name, false, false, nil)
c.Assert(err, check.IsNil)
_, err = h.GetMachine("test", "test_enable_route_machine")
_, err = app.GetMachine("test", "test_enable_route_machine")
c.Assert(err, check.NotNil)
route, err := netaddr.ParseIPPrefix(
@ -73,56 +76,68 @@ func (s *Suite) TestGetEnableRoutes(c *check.C) {
)
c.Assert(err, check.IsNil)
hi := tailcfg.Hostinfo{
hostInfo := tailcfg.Hostinfo{
RoutableIPs: []netaddr.IPPrefix{route, route2},
}
hostinfo, err := json.Marshal(hi)
hostinfo, err := json.Marshal(hostInfo)
c.Assert(err, check.IsNil)
m := Machine{
machine := Machine{
ID: 0,
MachineKey: "foo",
NodeKey: "bar",
DiscoKey: "faa",
Name: "test_enable_route_machine",
NamespaceID: n.ID,
NamespaceID: namespace.ID,
Registered: true,
RegisterMethod: "authKey",
AuthKeyID: uint(pak.ID),
HostInfo: datatypes.JSON(hostinfo),
}
h.db.Save(&m)
app.db.Save(&machine)
availableRoutes, err := h.GetAdvertisedNodeRoutes("test", "test_enable_route_machine")
availableRoutes, err := app.GetAdvertisedNodeRoutes(
"test",
"test_enable_route_machine",
)
c.Assert(err, check.IsNil)
c.Assert(len(*availableRoutes), check.Equals, 2)
enabledRoutes, err := h.GetEnabledNodeRoutes("test", "test_enable_route_machine")
noEnabledRoutes, err := app.GetEnabledNodeRoutes(
"test",
"test_enable_route_machine",
)
c.Assert(err, check.IsNil)
c.Assert(len(enabledRoutes), check.Equals, 0)
c.Assert(len(noEnabledRoutes), check.Equals, 0)
err = h.EnableNodeRoute("test", "test_enable_route_machine", "192.168.0.0/24")
err = app.EnableNodeRoute("test", "test_enable_route_machine", "192.168.0.0/24")
c.Assert(err, check.NotNil)
err = h.EnableNodeRoute("test", "test_enable_route_machine", "10.0.0.0/24")
err = app.EnableNodeRoute("test", "test_enable_route_machine", "10.0.0.0/24")
c.Assert(err, check.IsNil)
enabledRoutes1, err := h.GetEnabledNodeRoutes("test", "test_enable_route_machine")
enabledRoutes, err := app.GetEnabledNodeRoutes("test", "test_enable_route_machine")
c.Assert(err, check.IsNil)
c.Assert(len(enabledRoutes1), check.Equals, 1)
c.Assert(len(enabledRoutes), check.Equals, 1)
// Adding it twice will just let it pass through
err = h.EnableNodeRoute("test", "test_enable_route_machine", "10.0.0.0/24")
err = app.EnableNodeRoute("test", "test_enable_route_machine", "10.0.0.0/24")
c.Assert(err, check.IsNil)
enabledRoutes2, err := h.GetEnabledNodeRoutes("test", "test_enable_route_machine")
enableRoutesAfterDoubleApply, err := app.GetEnabledNodeRoutes(
"test",
"test_enable_route_machine",
)
c.Assert(err, check.IsNil)
c.Assert(len(enabledRoutes2), check.Equals, 1)
c.Assert(len(enableRoutesAfterDoubleApply), check.Equals, 1)
err = h.EnableNodeRoute("test", "test_enable_route_machine", "150.0.10.0/25")
err = app.EnableNodeRoute("test", "test_enable_route_machine", "150.0.10.0/25")
c.Assert(err, check.IsNil)
enabledRoutes3, err := h.GetEnabledNodeRoutes("test", "test_enable_route_machine")
enabledRoutesWithAdditionalRoute, err := app.GetEnabledNodeRoutes(
"test",
"test_enable_route_machine",
)
c.Assert(err, check.IsNil)
c.Assert(len(enabledRoutes3), check.Equals, 2)
c.Assert(len(enabledRoutesWithAdditionalRoute), check.Equals, 2)
}

View file

@ -2,11 +2,13 @@ package headscale
import "gorm.io/gorm"
const errorSameNamespace = Error("Destination namespace same as origin")
const errorMachineAlreadyShared = Error("Node already shared to this namespace")
const errorMachineNotShared = Error("Machine not shared to this namespace")
const (
errSameNamespace = Error("Destination namespace same as origin")
errMachineAlreadyShared = Error("Node already shared to this namespace")
errMachineNotShared = Error("Machine not shared to this namespace")
)
// SharedMachine is a join table to support sharing nodes between namespaces
// SharedMachine is a join table to support sharing nodes between namespaces.
type SharedMachine struct {
gorm.Model
MachineID uint64
@ -15,49 +17,57 @@ type SharedMachine struct {
Namespace Namespace
}
// AddSharedMachineToNamespace adds a machine as a shared node to a namespace
func (h *Headscale) AddSharedMachineToNamespace(m *Machine, ns *Namespace) error {
if m.NamespaceID == ns.ID {
return errorSameNamespace
// AddSharedMachineToNamespace adds a machine as a shared node to a namespace.
func (h *Headscale) AddSharedMachineToNamespace(
machine *Machine,
namespace *Namespace,
) error {
if machine.NamespaceID == namespace.ID {
return errSameNamespace
}
sharedMachines := []SharedMachine{}
if err := h.db.Where("machine_id = ? AND namespace_id = ?", m.ID, ns.ID).Find(&sharedMachines).Error; err != nil {
if err := h.db.Where("machine_id = ? AND namespace_id = ?", machine.ID, namespace.ID).Find(&sharedMachines).Error; err != nil {
return err
}
if len(sharedMachines) > 0 {
return errorMachineAlreadyShared
return errMachineAlreadyShared
}
sharedMachine := SharedMachine{
MachineID: m.ID,
Machine: *m,
NamespaceID: ns.ID,
Namespace: *ns,
MachineID: machine.ID,
Machine: *machine,
NamespaceID: namespace.ID,
Namespace: *namespace,
}
h.db.Save(&sharedMachine)
return nil
}
// RemoveSharedMachineFromNamespace removes a shared machine from a namespace
func (h *Headscale) RemoveSharedMachineFromNamespace(m *Machine, ns *Namespace) error {
if m.NamespaceID == ns.ID {
// RemoveSharedMachineFromNamespace removes a shared machine from a namespace.
func (h *Headscale) RemoveSharedMachineFromNamespace(
machine *Machine,
namespace *Namespace,
) error {
if machine.NamespaceID == namespace.ID {
// Can't unshare from primary namespace
return errorMachineNotShared
return errMachineNotShared
}
sharedMachine := SharedMachine{}
result := h.db.Where("machine_id = ? AND namespace_id = ?", m.ID, ns.ID).Unscoped().Delete(&sharedMachine)
result := h.db.Where("machine_id = ? AND namespace_id = ?", machine.ID, namespace.ID).
Unscoped().
Delete(&sharedMachine)
if result.Error != nil {
return result.Error
}
if result.RowsAffected == 0 {
return errorMachineNotShared
return errMachineNotShared
}
err := h.RequestMapUpdates(ns.ID)
err := h.RequestMapUpdates(namespace.ID)
if err != nil {
return err
}
@ -65,10 +75,10 @@ func (h *Headscale) RemoveSharedMachineFromNamespace(m *Machine, ns *Namespace)
return nil
}
// RemoveSharedMachineFromAllNamespaces removes a machine as a shared node from all namespaces
func (h *Headscale) RemoveSharedMachineFromAllNamespaces(m *Machine) error {
// RemoveSharedMachineFromAllNamespaces removes a machine as a shared node from all namespaces.
func (h *Headscale) RemoveSharedMachineFromAllNamespaces(machine *Machine) error {
sharedMachine := SharedMachine{}
if result := h.db.Where("machine_id = ?", m.ID).Unscoped().Delete(&sharedMachine); result.Error != nil {
if result := h.db.Where("machine_id = ?", machine.ID).Unscoped().Delete(&sharedMachine); result.Error != nil {
return result.Error
}

View file

@ -4,45 +4,48 @@ import (
"gopkg.in/check.v1"
)
func CreateNodeNamespace(c *check.C, namespace, node, key, IP string) (*Namespace, *Machine) {
n1, err := h.CreateNamespace(namespace)
func CreateNodeNamespace(
c *check.C,
namespaceName, node, key, ip string,
) (*Namespace, *Machine) {
namespace, err := app.CreateNamespace(namespaceName)
c.Assert(err, check.IsNil)
pak1, err := h.CreatePreAuthKey(n1.Name, false, false, nil)
pak1, err := app.CreatePreAuthKey(namespace.Name, false, false, nil)
c.Assert(err, check.IsNil)
_, err = h.GetMachine(n1.Name, node)
_, err = app.GetMachine(namespace.Name, node)
c.Assert(err, check.NotNil)
m1 := &Machine{
machine := &Machine{
ID: 0,
MachineKey: key,
NodeKey: key,
DiscoKey: key,
Name: node,
NamespaceID: n1.ID,
NamespaceID: namespace.ID,
Registered: true,
RegisterMethod: "authKey",
IPAddress: IP,
IPAddress: ip,
AuthKeyID: uint(pak1.ID),
}
h.db.Save(m1)
app.db.Save(machine)
_, err = h.GetMachine(n1.Name, m1.Name)
_, err = app.GetMachine(namespace.Name, machine.Name)
c.Assert(err, check.IsNil)
return n1, m1
return namespace, machine
}
func (s *Suite) TestBasicSharedNodesInNamespace(c *check.C) {
n1, m1 := CreateNodeNamespace(
namespace1, machine1 := CreateNodeNamespace(
c,
"shared1",
"test_get_shared_nodes_1",
"686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66",
"100.64.0.1",
)
_, m2 := CreateNodeNamespace(
_, machine2 := CreateNodeNamespace(
c,
"shared2",
"test_get_shared_nodes_2",
@ -50,21 +53,21 @@ func (s *Suite) TestBasicSharedNodesInNamespace(c *check.C) {
"100.64.0.2",
)
p1s, err := h.getPeers(m1)
peersOfMachine1BeforeShared, err := app.getPeers(machine1)
c.Assert(err, check.IsNil)
c.Assert(len(p1s), check.Equals, 0)
c.Assert(len(peersOfMachine1BeforeShared), check.Equals, 0)
err = h.AddSharedMachineToNamespace(m2, n1)
err = app.AddSharedMachineToNamespace(machine2, namespace1)
c.Assert(err, check.IsNil)
p1sAfter, err := h.getPeers(m1)
peersOfMachine1AfterShared, err := app.getPeers(machine1)
c.Assert(err, check.IsNil)
c.Assert(len(p1sAfter), check.Equals, 1)
c.Assert(p1sAfter[0].ID, check.Equals, m2.ID)
c.Assert(len(peersOfMachine1AfterShared), check.Equals, 1)
c.Assert(peersOfMachine1AfterShared[0].ID, check.Equals, machine2.ID)
}
func (s *Suite) TestSameNamespace(c *check.C) {
n1, m1 := CreateNodeNamespace(
namespace1, machine1 := CreateNodeNamespace(
c,
"shared1",
"test_get_shared_nodes_1",
@ -72,23 +75,23 @@ func (s *Suite) TestSameNamespace(c *check.C) {
"100.64.0.1",
)
p1s, err := h.getPeers(m1)
peersOfMachine1BeforeShare, err := app.getPeers(machine1)
c.Assert(err, check.IsNil)
c.Assert(len(p1s), check.Equals, 0)
c.Assert(len(peersOfMachine1BeforeShare), check.Equals, 0)
err = h.AddSharedMachineToNamespace(m1, n1)
c.Assert(err, check.Equals, errorSameNamespace)
err = app.AddSharedMachineToNamespace(machine1, namespace1)
c.Assert(err, check.Equals, errSameNamespace)
}
func (s *Suite) TestUnshare(c *check.C) {
n1, m1 := CreateNodeNamespace(
namespace1, machine1 := CreateNodeNamespace(
c,
"shared1",
"test_unshare_1",
"686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66",
"100.64.0.1",
)
_, m2 := CreateNodeNamespace(
_, machine2 := CreateNodeNamespace(
c,
"shared2",
"test_unshare_2",
@ -96,40 +99,40 @@ func (s *Suite) TestUnshare(c *check.C) {
"100.64.0.2",
)
p1s, err := h.getPeers(m1)
peersOfMachine1BeforeShare, err := app.getPeers(machine1)
c.Assert(err, check.IsNil)
c.Assert(len(p1s), check.Equals, 0)
c.Assert(len(peersOfMachine1BeforeShare), check.Equals, 0)
err = h.AddSharedMachineToNamespace(m2, n1)
err = app.AddSharedMachineToNamespace(machine2, namespace1)
c.Assert(err, check.IsNil)
p1s, err = h.getShared(m1)
peersOfMachine1BeforeShare, err = app.getShared(machine1)
c.Assert(err, check.IsNil)
c.Assert(len(p1s), check.Equals, 1)
c.Assert(len(peersOfMachine1BeforeShare), check.Equals, 1)
err = h.RemoveSharedMachineFromNamespace(m2, n1)
err = app.RemoveSharedMachineFromNamespace(machine2, namespace1)
c.Assert(err, check.IsNil)
p1s, err = h.getShared(m1)
peersOfMachine1BeforeShare, err = app.getShared(machine1)
c.Assert(err, check.IsNil)
c.Assert(len(p1s), check.Equals, 0)
c.Assert(len(peersOfMachine1BeforeShare), check.Equals, 0)
err = h.RemoveSharedMachineFromNamespace(m2, n1)
c.Assert(err, check.Equals, errorMachineNotShared)
err = app.RemoveSharedMachineFromNamespace(machine2, namespace1)
c.Assert(err, check.Equals, errMachineNotShared)
err = h.RemoveSharedMachineFromNamespace(m1, n1)
c.Assert(err, check.Equals, errorMachineNotShared)
err = app.RemoveSharedMachineFromNamespace(machine1, namespace1)
c.Assert(err, check.Equals, errMachineNotShared)
}
func (s *Suite) TestAlreadyShared(c *check.C) {
n1, m1 := CreateNodeNamespace(
namespace1, machine1 := CreateNodeNamespace(
c,
"shared1",
"test_get_shared_nodes_1",
"686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66",
"100.64.0.1",
)
_, m2 := CreateNodeNamespace(
_, machine2 := CreateNodeNamespace(
c,
"shared2",
"test_get_shared_nodes_2",
@ -137,25 +140,25 @@ func (s *Suite) TestAlreadyShared(c *check.C) {
"100.64.0.2",
)
p1s, err := h.getPeers(m1)
peersOfMachine1BeforeShare, err := app.getPeers(machine1)
c.Assert(err, check.IsNil)
c.Assert(len(p1s), check.Equals, 0)
c.Assert(len(peersOfMachine1BeforeShare), check.Equals, 0)
err = h.AddSharedMachineToNamespace(m2, n1)
err = app.AddSharedMachineToNamespace(machine2, namespace1)
c.Assert(err, check.IsNil)
err = h.AddSharedMachineToNamespace(m2, n1)
c.Assert(err, check.Equals, errorMachineAlreadyShared)
err = app.AddSharedMachineToNamespace(machine2, namespace1)
c.Assert(err, check.Equals, errMachineAlreadyShared)
}
func (s *Suite) TestDoNotIncludeRoutesOnShared(c *check.C) {
n1, m1 := CreateNodeNamespace(
namespace1, machine1 := CreateNodeNamespace(
c,
"shared1",
"test_get_shared_nodes_1",
"686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66",
"100.64.0.1",
)
_, m2 := CreateNodeNamespace(
_, machine2 := CreateNodeNamespace(
c,
"shared2",
"test_get_shared_nodes_2",
@ -163,35 +166,35 @@ func (s *Suite) TestDoNotIncludeRoutesOnShared(c *check.C) {
"100.64.0.2",
)
p1s, err := h.getPeers(m1)
peersOfMachine1BeforeShare, err := app.getPeers(machine1)
c.Assert(err, check.IsNil)
c.Assert(len(p1s), check.Equals, 0)
c.Assert(len(peersOfMachine1BeforeShare), check.Equals, 0)
err = h.AddSharedMachineToNamespace(m2, n1)
err = app.AddSharedMachineToNamespace(machine2, namespace1)
c.Assert(err, check.IsNil)
p1sAfter, err := h.getPeers(m1)
peersOfMachine1AfterShare, err := app.getPeers(machine1)
c.Assert(err, check.IsNil)
c.Assert(len(p1sAfter), check.Equals, 1)
c.Assert(p1sAfter[0].Name, check.Equals, "test_get_shared_nodes_2")
c.Assert(len(peersOfMachine1AfterShare), check.Equals, 1)
c.Assert(peersOfMachine1AfterShare[0].Name, check.Equals, "test_get_shared_nodes_2")
}
func (s *Suite) TestComplexSharingAcrossNamespaces(c *check.C) {
n1, m1 := CreateNodeNamespace(
namespace1, machine1 := CreateNodeNamespace(
c,
"shared1",
"test_get_shared_nodes_1",
"686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66",
"100.64.0.1",
)
_, m2 := CreateNodeNamespace(
_, machine2 := CreateNodeNamespace(
c,
"shared2",
"test_get_shared_nodes_2",
"dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
"100.64.0.2",
)
_, m3 := CreateNodeNamespace(
_, machine3 := CreateNodeNamespace(
c,
"shared3",
"test_get_shared_nodes_3",
@ -199,76 +202,80 @@ func (s *Suite) TestComplexSharingAcrossNamespaces(c *check.C) {
"100.64.0.3",
)
pak4, err := h.CreatePreAuthKey(n1.Name, false, false, nil)
pak4, err := app.CreatePreAuthKey(namespace1.Name, false, false, nil)
c.Assert(err, check.IsNil)
m4 := &Machine{
machine4 := &Machine{
ID: 4,
MachineKey: "4c3e07c3ecd40e9c945bb6797557c451850691c0409740578325e17009dd298f",
NodeKey: "4c3e07c3ecd40e9c945bb6797557c451850691c0409740578325e17009dd298f",
DiscoKey: "4c3e07c3ecd40e9c945bb6797557c451850691c0409740578325e17009dd298f",
Name: "test_get_shared_nodes_4",
NamespaceID: n1.ID,
NamespaceID: namespace1.ID,
Registered: true,
RegisterMethod: "authKey",
IPAddress: "100.64.0.4",
AuthKeyID: uint(pak4.ID),
}
h.db.Save(m4)
app.db.Save(machine4)
_, err = h.GetMachine(n1.Name, m4.Name)
_, err = app.GetMachine(namespace1.Name, machine4.Name)
c.Assert(err, check.IsNil)
p1s, err := h.getPeers(m1)
peersOfMachine1BeforeShare, err := app.getPeers(machine1)
c.Assert(err, check.IsNil)
c.Assert(len(p1s), check.Equals, 1) // node1 can see node4
c.Assert(p1s[0].Name, check.Equals, m4.Name)
c.Assert(len(peersOfMachine1BeforeShare), check.Equals, 1) // node1 can see node4
c.Assert(peersOfMachine1BeforeShare[0].Name, check.Equals, machine4.Name)
err = h.AddSharedMachineToNamespace(m2, n1)
err = app.AddSharedMachineToNamespace(machine2, namespace1)
c.Assert(err, check.IsNil)
p1sAfter, err := h.getPeers(m1)
c.Assert(err, check.IsNil)
c.Assert(len(p1sAfter), check.Equals, 2) // node1 can see node2 (shared) and node4 (same namespace)
c.Assert(p1sAfter[0].Name, check.Equals, m2.Name)
c.Assert(p1sAfter[1].Name, check.Equals, m4.Name)
node1shared, err := h.getShared(m1)
c.Assert(err, check.IsNil)
c.Assert(len(node1shared), check.Equals, 1) // node1 can see node2 as shared
c.Assert(node1shared[0].Name, check.Equals, m2.Name)
pAlone, err := h.getPeers(m3)
c.Assert(err, check.IsNil)
c.Assert(len(pAlone), check.Equals, 0) // node3 is alone
pSharedTo, err := h.getPeers(m2)
peersOfMachine1AfterShare, err := app.getPeers(machine1)
c.Assert(err, check.IsNil)
c.Assert(
len(pSharedTo),
len(peersOfMachine1AfterShare),
check.Equals,
2,
) // node1 can see node2 (shared) and node4 (same namespace)
c.Assert(peersOfMachine1AfterShare[0].Name, check.Equals, machine2.Name)
c.Assert(peersOfMachine1AfterShare[1].Name, check.Equals, machine4.Name)
sharedOfMachine1, err := app.getShared(machine1)
c.Assert(err, check.IsNil)
c.Assert(len(sharedOfMachine1), check.Equals, 1) // node1 can see node2 as shared
c.Assert(sharedOfMachine1[0].Name, check.Equals, machine2.Name)
peersOfMachine3, err := app.getPeers(machine3)
c.Assert(err, check.IsNil)
c.Assert(len(peersOfMachine3), check.Equals, 0) // node3 is alone
peersOfMachine2, err := app.getPeers(machine2)
c.Assert(err, check.IsNil)
c.Assert(
len(peersOfMachine2),
check.Equals,
2,
) // node2 should see node1 (sharedTo) and node4 (sharedTo), as is shared in namespace1
c.Assert(pSharedTo[0].Name, check.Equals, m1.Name)
c.Assert(pSharedTo[1].Name, check.Equals, m4.Name)
c.Assert(peersOfMachine2[0].Name, check.Equals, machine1.Name)
c.Assert(peersOfMachine2[1].Name, check.Equals, machine4.Name)
}
func (s *Suite) TestDeleteSharedMachine(c *check.C) {
n1, m1 := CreateNodeNamespace(
namespace1, machine1 := CreateNodeNamespace(
c,
"shared1",
"test_get_shared_nodes_1",
"686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66",
"100.64.0.1",
)
_, m2 := CreateNodeNamespace(
_, machine2 := CreateNodeNamespace(
c,
"shared2",
"test_get_shared_nodes_2",
"dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
"100.64.0.2",
)
_, m3 := CreateNodeNamespace(
_, machine3 := CreateNodeNamespace(
c,
"shared3",
"test_get_shared_nodes_3",
@ -276,56 +283,58 @@ func (s *Suite) TestDeleteSharedMachine(c *check.C) {
"100.64.0.3",
)
pak4n1, err := h.CreatePreAuthKey(n1.Name, false, false, nil)
pak4n1, err := app.CreatePreAuthKey(namespace1.Name, false, false, nil)
c.Assert(err, check.IsNil)
m4 := &Machine{
machine4 := &Machine{
ID: 4,
MachineKey: "4c3e07c3ecd40e9c945bb6797557c451850691c0409740578325e17009dd298f",
NodeKey: "4c3e07c3ecd40e9c945bb6797557c451850691c0409740578325e17009dd298f",
DiscoKey: "4c3e07c3ecd40e9c945bb6797557c451850691c0409740578325e17009dd298f",
Name: "test_get_shared_nodes_4",
NamespaceID: n1.ID,
NamespaceID: namespace1.ID,
Registered: true,
RegisterMethod: "authKey",
IPAddress: "100.64.0.4",
AuthKeyID: uint(pak4n1.ID),
}
h.db.Save(m4)
app.db.Save(machine4)
_, err = h.GetMachine(n1.Name, m4.Name)
_, err = app.GetMachine(namespace1.Name, machine4.Name)
c.Assert(err, check.IsNil)
p1s, err := h.getPeers(m1)
peersOfMachine1BeforeShare, err := app.getPeers(machine1)
c.Assert(err, check.IsNil)
c.Assert(len(p1s), check.Equals, 1) // nodes 1 and 4
c.Assert(p1s[0].Name, check.Equals, m4.Name)
c.Assert(len(peersOfMachine1BeforeShare), check.Equals, 1) // nodes 1 and 4
c.Assert(peersOfMachine1BeforeShare[0].Name, check.Equals, machine4.Name)
err = h.AddSharedMachineToNamespace(m2, n1)
err = app.AddSharedMachineToNamespace(machine2, namespace1)
c.Assert(err, check.IsNil)
p1sAfter, err := h.getPeers(m1)
peersOfMachine1AfterShare, err := app.getPeers(machine1)
c.Assert(err, check.IsNil)
c.Assert(len(p1sAfter), check.Equals, 2) // nodes 1, 2, 4
c.Assert(p1sAfter[0].Name, check.Equals, m2.Name)
c.Assert(p1sAfter[1].Name, check.Equals, m4.Name)
c.Assert(len(peersOfMachine1AfterShare), check.Equals, 2) // nodes 1, 2, 4
c.Assert(peersOfMachine1AfterShare[0].Name, check.Equals, machine2.Name)
c.Assert(peersOfMachine1AfterShare[1].Name, check.Equals, machine4.Name)
node1shared, err := h.getShared(m1)
sharedOfMachine1, err := app.getShared(machine1)
c.Assert(err, check.IsNil)
c.Assert(len(node1shared), check.Equals, 1) // nodes 1, 2, 4
c.Assert(node1shared[0].Name, check.Equals, m2.Name)
c.Assert(len(sharedOfMachine1), check.Equals, 1) // nodes 1, 2, 4
c.Assert(sharedOfMachine1[0].Name, check.Equals, machine2.Name)
pAlone, err := h.getPeers(m3)
peersOfMachine3, err := app.getPeers(machine3)
c.Assert(err, check.IsNil)
c.Assert(len(pAlone), check.Equals, 0) // node 3 is alone
c.Assert(len(peersOfMachine3), check.Equals, 0) // node 3 is alone
sharedMachines, err := h.ListSharedMachinesInNamespace(n1.Name)
sharedMachinesInNamespace1, err := app.ListSharedMachinesInNamespace(
namespace1.Name,
)
c.Assert(err, check.IsNil)
c.Assert(len(sharedMachines), check.Equals, 1)
c.Assert(len(sharedMachinesInNamespace1), check.Equals, 1)
err = h.DeleteMachine(m2)
err = app.DeleteMachine(machine2)
c.Assert(err, check.IsNil)
sharedMachines, err = h.ListSharedMachinesInNamespace(n1.Name)
sharedMachinesInNamespace1, err = app.ListSharedMachinesInNamespace(namespace1.Name)
c.Assert(err, check.IsNil)
c.Assert(len(sharedMachines), check.Equals, 0)
c.Assert(len(sharedMachinesInNamespace1), check.Equals, 0)
}

View file

@ -6,16 +6,15 @@ import (
"net/http"
"text/template"
"github.com/rs/zerolog/log"
"github.com/gin-gonic/gin"
"github.com/rs/zerolog/log"
)
//go:embed gen/openapiv2/headscale/v1/headscale.swagger.json
var apiV1JSON []byte
func SwaggerUI(c *gin.Context) {
t := template.Must(template.New("swagger").Parse(`
func SwaggerUI(ctx *gin.Context) {
swaggerTemplate := template.Must(template.New("swagger").Parse(`
<html>
<head>
<link rel="stylesheet" type="text/css" href="https://unpkg.com/swagger-ui-dist@3/swagger-ui.css">
@ -48,18 +47,23 @@ func SwaggerUI(c *gin.Context) {
</html>`))
var payload bytes.Buffer
if err := t.Execute(&payload, struct{}{}); err != nil {
if err := swaggerTemplate.Execute(&payload, struct{}{}); err != nil {
log.Error().
Caller().
Err(err).
Msg("Could not render Swagger")
c.Data(http.StatusInternalServerError, "text/html; charset=utf-8", []byte("Could not render Swagger"))
ctx.Data(
http.StatusInternalServerError,
"text/html; charset=utf-8",
[]byte("Could not render Swagger"),
)
return
}
c.Data(http.StatusOK, "text/html; charset=utf-8", payload.Bytes())
ctx.Data(http.StatusOK, "text/html; charset=utf-8", payload.Bytes())
}
func SwaggerAPIv1(c *gin.Context) {
c.Data(http.StatusOK, "application/json; charset=utf-8", apiV1JSON)
func SwaggerAPIv1(ctx *gin.Context) {
ctx.Data(http.StatusOK, "application/json; charset=utf-8", apiV1JSON)
}

View file

@ -20,31 +20,48 @@ import (
"tailscale.com/types/wgkey"
)
const (
errCannotDecryptReponse = Error("cannot decrypt response")
errResponseMissingNonce = Error("response missing nonce")
errCouldNotAllocateIP = Error("could not find any suitable IP")
)
// Error is used to compare errors as per https://dave.cheney.net/2016/04/07/constant-errors
type Error string
func (e Error) Error() string { return string(e) }
func decode(msg []byte, v interface{}, pubKey *wgkey.Key, privKey *wgkey.Private) error {
func decode(
msg []byte,
v interface{},
pubKey *wgkey.Key,
privKey *wgkey.Private,
) error {
return decodeMsg(msg, v, pubKey, privKey)
}
func decodeMsg(msg []byte, v interface{}, pubKey *wgkey.Key, privKey *wgkey.Private) error {
func decodeMsg(
msg []byte,
output interface{},
pubKey *wgkey.Key,
privKey *wgkey.Private,
) error {
decrypted, err := decryptMsg(msg, pubKey, privKey)
if err != nil {
return err
}
// fmt.Println(string(decrypted))
if err := json.Unmarshal(decrypted, v); err != nil {
return fmt.Errorf("response: %v", err)
if err := json.Unmarshal(decrypted, output); err != nil {
return err
}
return nil
}
func decryptMsg(msg []byte, pubKey *wgkey.Key, privKey *wgkey.Private) ([]byte, error) {
var nonce [24]byte
if len(msg) < len(nonce)+1 {
return nil, fmt.Errorf("response missing nonce, len=%d", len(msg))
return nil, errResponseMissingNonce
}
copy(nonce[:], msg)
msg = msg[len(nonce):]
@ -52,8 +69,9 @@ func decryptMsg(msg []byte, pubKey *wgkey.Key, privKey *wgkey.Private) ([]byte,
pub, pri := (*[32]byte)(pubKey), (*[32]byte)(privKey)
decrypted, ok := box.Open(nil, msg, &nonce, pub, pri)
if !ok {
return nil, fmt.Errorf("cannot decrypt response")
return nil, errCannotDecryptReponse
}
return decrypted, nil
}
@ -66,13 +84,18 @@ func encode(v interface{}, pubKey *wgkey.Key, privKey *wgkey.Private) ([]byte, e
return encodeMsg(b, pubKey, privKey)
}
func encodeMsg(b []byte, pubKey *wgkey.Key, privKey *wgkey.Private) ([]byte, error) {
func encodeMsg(
payload []byte,
pubKey *wgkey.Key,
privKey *wgkey.Private,
) ([]byte, error) {
var nonce [24]byte
if _, err := io.ReadFull(rand.Reader, nonce[:]); err != nil {
panic(err)
}
pub, pri := (*[32]byte)(pubKey), (*[32]byte)(privKey)
msg := box.Seal(nonce[:], b, &nonce, pub, pri)
msg := box.Seal(nonce[:], payload, &nonce, pub, pri)
return msg, nil
}
@ -89,7 +112,7 @@ func (h *Headscale) getAvailableIP() (*netaddr.IP, error) {
for {
if !ipPrefix.Contains(ip) {
return nil, fmt.Errorf("could not find any suitable IP in %s", ipPrefix)
return nil, errCouldNotAllocateIP
}
// Some OS (including Linux) does not like when IPs ends with 0 or 255, which
@ -98,13 +121,14 @@ func (h *Headscale) getAvailableIP() (*netaddr.IP, error) {
ipRaw := ip.As4()
if ipRaw[3] == 0 || ipRaw[3] == 255 {
ip = ip.Next()
continue
}
if ip.IsZero() &&
ip.IsLoopback() {
ip = ip.Next()
continue
}
@ -125,7 +149,7 @@ func (h *Headscale) getUsedIPs() ([]netaddr.IP, error) {
if addr != "" {
ip, err := netaddr.ParseIP(addr)
if err != nil {
return nil, fmt.Errorf("failed to parse ip from database, %w", err)
return nil, fmt.Errorf("failed to parse ip from database: %w", err)
}
ips[index] = ip
@ -156,11 +180,16 @@ func tailNodesToString(nodes []*tailcfg.Node) string {
}
func tailMapResponseToString(resp tailcfg.MapResponse) string {
return fmt.Sprintf("{ Node: %s, Peers: %s }", resp.Node.Name, tailNodesToString(resp.Peers))
return fmt.Sprintf(
"{ Node: %s, Peers: %s }",
resp.Node.Name,
tailNodesToString(resp.Peers),
)
}
func GrpcSocketDialer(ctx context.Context, addr string) (net.Conn, error) {
var d net.Dialer
return d.DialContext(ctx, "unix", addr)
}
@ -174,7 +203,7 @@ func ipPrefixToString(prefixes []netaddr.IPPrefix) []string {
return result
}
func stringToIpPrefix(prefixes []string) ([]netaddr.IPPrefix, error) {
func stringToIPPrefix(prefixes []string) ([]netaddr.IPPrefix, error) {
result := make([]netaddr.IPPrefix, len(prefixes))
for index, prefixStr := range prefixes {
@ -189,7 +218,7 @@ func stringToIpPrefix(prefixes []string) ([]netaddr.IPPrefix, error) {
return result, nil
}
func containsIpPrefix(prefixes []netaddr.IPPrefix, prefix netaddr.IPPrefix) bool {
func containsIPPrefix(prefixes []netaddr.IPPrefix, prefix netaddr.IPPrefix) bool {
for _, p := range prefixes {
if prefix == p {
return true

View file

@ -6,7 +6,7 @@ import (
)
func (s *Suite) TestGetAvailableIp(c *check.C) {
ip, err := h.getAvailableIP()
ip, err := app.getAvailableIP()
c.Assert(err, check.IsNil)
@ -16,33 +16,33 @@ func (s *Suite) TestGetAvailableIp(c *check.C) {
}
func (s *Suite) TestGetUsedIps(c *check.C) {
ip, err := h.getAvailableIP()
ip, err := app.getAvailableIP()
c.Assert(err, check.IsNil)
n, err := h.CreateNamespace("test_ip")
namespace, err := app.CreateNamespace("test_ip")
c.Assert(err, check.IsNil)
pak, err := h.CreatePreAuthKey(n.Name, false, false, nil)
pak, err := app.CreatePreAuthKey(namespace.Name, false, false, nil)
c.Assert(err, check.IsNil)
_, err = h.GetMachine("test", "testmachine")
_, err = app.GetMachine("test", "testmachine")
c.Assert(err, check.NotNil)
m := Machine{
machine := Machine{
ID: 0,
MachineKey: "foo",
NodeKey: "bar",
DiscoKey: "faa",
Name: "testmachine",
NamespaceID: n.ID,
NamespaceID: namespace.ID,
Registered: true,
RegisterMethod: "authKey",
AuthKeyID: uint(pak.ID),
IPAddress: ip.String(),
}
h.db.Save(&m)
app.db.Save(&machine)
ips, err := h.getUsedIPs()
ips, err := app.getUsedIPs()
c.Assert(err, check.IsNil)
@ -50,42 +50,42 @@ func (s *Suite) TestGetUsedIps(c *check.C) {
c.Assert(ips[0], check.Equals, expected)
m1, err := h.GetMachineByID(0)
machine1, err := app.GetMachineByID(0)
c.Assert(err, check.IsNil)
c.Assert(m1.IPAddress, check.Equals, expected.String())
c.Assert(machine1.IPAddress, check.Equals, expected.String())
}
func (s *Suite) TestGetMultiIp(c *check.C) {
n, err := h.CreateNamespace("test-ip-multi")
namespace, err := app.CreateNamespace("test-ip-multi")
c.Assert(err, check.IsNil)
for i := 1; i <= 350; i++ {
ip, err := h.getAvailableIP()
for index := 1; index <= 350; index++ {
ip, err := app.getAvailableIP()
c.Assert(err, check.IsNil)
pak, err := h.CreatePreAuthKey(n.Name, false, false, nil)
pak, err := app.CreatePreAuthKey(namespace.Name, false, false, nil)
c.Assert(err, check.IsNil)
_, err = h.GetMachine("test", "testmachine")
_, err = app.GetMachine("test", "testmachine")
c.Assert(err, check.NotNil)
m := Machine{
ID: uint64(i),
machine := Machine{
ID: uint64(index),
MachineKey: "foo",
NodeKey: "bar",
DiscoKey: "faa",
Name: "testmachine",
NamespaceID: n.ID,
NamespaceID: namespace.ID,
Registered: true,
RegisterMethod: "authKey",
AuthKeyID: uint(pak.ID),
IPAddress: ip.String(),
}
h.db.Save(&m)
app.db.Save(&machine)
}
ips, err := h.getUsedIPs()
ips, err := app.getUsedIPs()
c.Assert(err, check.IsNil)
@ -96,59 +96,67 @@ func (s *Suite) TestGetMultiIp(c *check.C) {
c.Assert(ips[300], check.Equals, netaddr.MustParseIP("10.27.1.47"))
// Check that we can read back the IPs
m1, err := h.GetMachineByID(1)
machine1, err := app.GetMachineByID(1)
c.Assert(err, check.IsNil)
c.Assert(m1.IPAddress, check.Equals, netaddr.MustParseIP("10.27.0.1").String())
c.Assert(
machine1.IPAddress,
check.Equals,
netaddr.MustParseIP("10.27.0.1").String(),
)
m50, err := h.GetMachineByID(50)
machine50, err := app.GetMachineByID(50)
c.Assert(err, check.IsNil)
c.Assert(m50.IPAddress, check.Equals, netaddr.MustParseIP("10.27.0.50").String())
c.Assert(
machine50.IPAddress,
check.Equals,
netaddr.MustParseIP("10.27.0.50").String(),
)
expectedNextIP := netaddr.MustParseIP("10.27.1.97")
nextIP, err := h.getAvailableIP()
nextIP, err := app.getAvailableIP()
c.Assert(err, check.IsNil)
c.Assert(nextIP.String(), check.Equals, expectedNextIP.String())
// If we call get Available again, we should receive
// the same IP, as it has not been reserved.
nextIP2, err := h.getAvailableIP()
nextIP2, err := app.getAvailableIP()
c.Assert(err, check.IsNil)
c.Assert(nextIP2.String(), check.Equals, expectedNextIP.String())
}
func (s *Suite) TestGetAvailableIpMachineWithoutIP(c *check.C) {
ip, err := h.getAvailableIP()
ip, err := app.getAvailableIP()
c.Assert(err, check.IsNil)
expected := netaddr.MustParseIP("10.27.0.1")
c.Assert(ip.String(), check.Equals, expected.String())
n, err := h.CreateNamespace("test_ip")
namespace, err := app.CreateNamespace("test_ip")
c.Assert(err, check.IsNil)
pak, err := h.CreatePreAuthKey(n.Name, false, false, nil)
pak, err := app.CreatePreAuthKey(namespace.Name, false, false, nil)
c.Assert(err, check.IsNil)
_, err = h.GetMachine("test", "testmachine")
_, err = app.GetMachine("test", "testmachine")
c.Assert(err, check.NotNil)
m := Machine{
machine := Machine{
ID: 0,
MachineKey: "foo",
NodeKey: "bar",
DiscoKey: "faa",
Name: "testmachine",
NamespaceID: n.ID,
NamespaceID: namespace.ID,
Registered: true,
RegisterMethod: "authKey",
AuthKeyID: uint(pak.ID),
}
h.db.Save(&m)
app.db.Save(&machine)
ip2, err := h.getAvailableIP()
ip2, err := app.getAvailableIP()
c.Assert(err, check.IsNil)
c.Assert(ip2.String(), check.Equals, expected.String())