Merge branch 'main' into patch-1
This commit is contained in:
commit
bd7b5e97cb
66 changed files with 2981 additions and 1869 deletions
3
.github/workflows/build.yml
vendored
3
.github/workflows/build.yml
vendored
|
@ -18,12 +18,11 @@ jobs:
|
|||
- name: Setup Go
|
||||
uses: actions/setup-go@v2
|
||||
with:
|
||||
go-version: "1.16.3"
|
||||
go-version: "1.17.3"
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
go version
|
||||
go install golang.org/x/lint/golint@latest
|
||||
sudo apt update
|
||||
sudo apt install -y make
|
||||
|
||||
|
|
39
.github/workflows/lint.yml
vendored
39
.github/workflows/lint.yml
vendored
|
@ -1,20 +1,37 @@
|
|||
---
|
||||
name: CI
|
||||
|
||||
on: [push, pull_request]
|
||||
|
||||
jobs:
|
||||
# The "build" workflow
|
||||
lint:
|
||||
# The type of runner that the job will run on
|
||||
golangci-lint:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
# Steps represent a sequence of tasks that will be executed as part of the job
|
||||
steps:
|
||||
# Checks-out your repository under $GITHUB_WORKSPACE, so your job can access it
|
||||
- uses: actions/checkout@v2
|
||||
|
||||
# Install and run golangci-lint as a separate step, it's much faster this
|
||||
# way because this action has caching. It'll get run again in `make lint`
|
||||
# below, but it's still much faster in the end than installing
|
||||
# golangci-lint manually in the `Run lint` step.
|
||||
- uses: golangci/golangci-lint-action@v2
|
||||
- name: golangci-lint
|
||||
uses: golangci/golangci-lint-action@v2
|
||||
with:
|
||||
version: latest
|
||||
|
||||
prettier-lint:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
|
||||
- name: Prettify code
|
||||
uses: creyD/prettier_action@v4.0
|
||||
with:
|
||||
prettier_options: >-
|
||||
--check **/*.{ts,js,md,yaml,yml,sass,css,scss,html}
|
||||
only_changed: false
|
||||
dry: true
|
||||
|
||||
proto-lint:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
- uses: bufbuild/buf-setup-action@v0.7.0
|
||||
- uses: bufbuild/buf-lint-action@v1
|
||||
with:
|
||||
input: "proto"
|
||||
|
|
|
@ -1,7 +1,53 @@
|
|||
---
|
||||
run:
|
||||
timeout: 5m
|
||||
timeout: 10m
|
||||
|
||||
issues:
|
||||
skip-dirs:
|
||||
- gen
|
||||
linters:
|
||||
enable-all: true
|
||||
disable:
|
||||
- exhaustivestruct
|
||||
- revive
|
||||
- lll
|
||||
- interfacer
|
||||
- scopelint
|
||||
- maligned
|
||||
- golint
|
||||
- gofmt
|
||||
- gochecknoglobals
|
||||
- gochecknoinits
|
||||
- gocognit
|
||||
- funlen
|
||||
- exhaustivestruct
|
||||
- tagliatelle
|
||||
- godox
|
||||
- ireturn
|
||||
|
||||
# In progress
|
||||
- gocritic
|
||||
|
||||
# We should strive to enable these:
|
||||
- wrapcheck
|
||||
- dupl
|
||||
- makezero
|
||||
|
||||
# We might want to enable this, but it might be a lot of work
|
||||
- cyclop
|
||||
- nestif
|
||||
- wsl # might be incompatible with gofumpt
|
||||
- testpackage
|
||||
- paralleltest
|
||||
|
||||
linters-settings:
|
||||
varnamelen:
|
||||
ignore-type-assert-ok: true
|
||||
ignore-map-index-ok: true
|
||||
ignore-names:
|
||||
- err
|
||||
- db
|
||||
- id
|
||||
- ip
|
||||
- ok
|
||||
- c
|
||||
|
|
15
Makefile
15
Makefile
|
@ -1,6 +1,14 @@
|
|||
# Calculate version
|
||||
version = $(shell ./scripts/version-at-commit.sh)
|
||||
|
||||
rwildcard=$(foreach d,$(wildcard $1*),$(call rwildcard,$d/,$2) $(filter $(subst *,%,$2),$d))
|
||||
|
||||
# GO_SOURCES = $(wildcard *.go)
|
||||
# PROTO_SOURCES = $(wildcard **/*.proto)
|
||||
GO_SOURCES = $(call rwildcard,,*.go)
|
||||
PROTO_SOURCES = $(call rwildcard,,*.proto)
|
||||
|
||||
|
||||
build:
|
||||
go build -ldflags "-s -w -X github.com/juanfont/headscale/cmd/headscale/cli.Version=$(version)" cmd/headscale/headscale.go
|
||||
|
||||
|
@ -19,7 +27,12 @@ coverprofile_html:
|
|||
go tool cover -html=coverage.out
|
||||
|
||||
lint:
|
||||
golangci-lint run --fix
|
||||
golangci-lint run --fix --timeout 10m
|
||||
|
||||
fmt:
|
||||
prettier --write '**/**.{ts,js,md,yaml,yml,sass,css,scss,html}'
|
||||
golines --max-len=88 --base-formatter=gofumpt -w $(GO_SOURCES)
|
||||
clang-format -style="{BasedOnStyle: Google, IndentWidth: 4, AlignConsecutiveDeclarations: true, AlignConsecutiveAssignments: true, ColumnLimit: 0}" -i $(PROTO_SOURCES)
|
||||
|
||||
proto-lint:
|
||||
cd proto/ && buf lint
|
||||
|
|
23
README.md
23
README.md
|
@ -54,7 +54,6 @@ Suggestions/PRs welcomed!
|
|||
|
||||
Please have a look at the documentation under [`docs/`](docs/).
|
||||
|
||||
|
||||
## Disclaimer
|
||||
|
||||
1. We have nothing to do with Tailscale, or Tailscale Inc.
|
||||
|
@ -64,13 +63,30 @@ Please have a look at the documentation under [`docs/`](docs/).
|
|||
|
||||
To contribute to Headscale you would need the lastest version of [Go](https://golang.org) and [Buf](https://buf.build)(Protobuf generator).
|
||||
|
||||
### Code style
|
||||
|
||||
To ensure we have some consistency with a growing number of contributes, this project has adopted linting and style/formatting rules:
|
||||
|
||||
The **Go** code is linted with [`golangci-lint`](https://golangci-lint.run) and
|
||||
formatted with [`golines`](https://github.com/segmentio/golines) (width 88) and
|
||||
[`gofumpt`](https://github.com/mvdan/gofumpt).
|
||||
Please configure your editor to run the tools while developing and make sure to
|
||||
run `make lint` and `make fmt` before committing any code.
|
||||
|
||||
The **Proto** code is linted with [`buf`](https://docs.buf.build/lint/overview) and
|
||||
formatted with [`clang-format`](https://clang.llvm.org/docs/ClangFormat.html).
|
||||
|
||||
The **rest** (markdown, yaml, etc) is formatted with [`prettier`](https://prettier.io).
|
||||
|
||||
Check out the `.golangci.yaml` and `Makefile` to see the specific configuration.
|
||||
|
||||
### Install development tools
|
||||
|
||||
- Go
|
||||
- Buf
|
||||
- Protobuf tools:
|
||||
|
||||
```shell
|
||||
```shell
|
||||
make install-protobuf-plugins
|
||||
```
|
||||
|
||||
|
@ -81,6 +97,7 @@ Some parts of the project requires the generation of Go code from Protobuf (if c
|
|||
```shell
|
||||
make generate
|
||||
```
|
||||
|
||||
**Note**: Please check in changes from `gen/` in a separate commit to make it easier to review.
|
||||
|
||||
To run the tests:
|
||||
|
@ -261,5 +278,3 @@ make build
|
|||
</td>
|
||||
</tr>
|
||||
</table>
|
||||
|
||||
|
||||
|
|
149
acls.go
149
acls.go
|
@ -9,23 +9,30 @@ import (
|
|||
"strings"
|
||||
|
||||
"github.com/rs/zerolog/log"
|
||||
|
||||
"github.com/tailscale/hujson"
|
||||
"inet.af/netaddr"
|
||||
"tailscale.com/tailcfg"
|
||||
)
|
||||
|
||||
const (
|
||||
errorEmptyPolicy = Error("empty policy")
|
||||
errorInvalidAction = Error("invalid action")
|
||||
errorInvalidUserSection = Error("invalid user section")
|
||||
errorInvalidGroup = Error("invalid group")
|
||||
errorInvalidTag = Error("invalid tag")
|
||||
errorInvalidNamespace = Error("invalid namespace")
|
||||
errorInvalidPortFormat = Error("invalid port format")
|
||||
errEmptyPolicy = Error("empty policy")
|
||||
errInvalidAction = Error("invalid action")
|
||||
errInvalidUserSection = Error("invalid user section")
|
||||
errInvalidGroup = Error("invalid group")
|
||||
errInvalidTag = Error("invalid tag")
|
||||
errInvalidNamespace = Error("invalid namespace")
|
||||
errInvalidPortFormat = Error("invalid port format")
|
||||
)
|
||||
|
||||
// LoadACLPolicy loads the ACL policy from the specify path, and generates the ACL rules
|
||||
const (
|
||||
Base10 = 10
|
||||
BitSize16 = 16
|
||||
portRangeBegin = 0
|
||||
portRangeEnd = 65535
|
||||
expectedTokenItems = 2
|
||||
)
|
||||
|
||||
// LoadACLPolicy loads the ACL policy from the specify path, and generates the ACL rules.
|
||||
func (h *Headscale) LoadACLPolicy(path string) error {
|
||||
policyFile, err := os.Open(path)
|
||||
if err != nil {
|
||||
|
@ -34,23 +41,23 @@ func (h *Headscale) LoadACLPolicy(path string) error {
|
|||
defer policyFile.Close()
|
||||
|
||||
var policy ACLPolicy
|
||||
b, err := io.ReadAll(policyFile)
|
||||
policyBytes, err := io.ReadAll(policyFile)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ast, err := hujson.Parse(b)
|
||||
ast, err := hujson.Parse(policyBytes)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
ast.Standardize()
|
||||
b = ast.Pack()
|
||||
err = json.Unmarshal(b, &policy)
|
||||
policyBytes = ast.Pack()
|
||||
err = json.Unmarshal(policyBytes, &policy)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if policy.IsZero() {
|
||||
return errorEmptyPolicy
|
||||
return errEmptyPolicy
|
||||
}
|
||||
|
||||
h.aclPolicy = &policy
|
||||
|
@ -59,37 +66,40 @@ func (h *Headscale) LoadACLPolicy(path string) error {
|
|||
return err
|
||||
}
|
||||
h.aclRules = rules
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *Headscale) generateACLRules() ([]tailcfg.FilterRule, error) {
|
||||
rules := []tailcfg.FilterRule{}
|
||||
|
||||
for i, a := range h.aclPolicy.ACLs {
|
||||
if a.Action != "accept" {
|
||||
return nil, errorInvalidAction
|
||||
for index, acl := range h.aclPolicy.ACLs {
|
||||
if acl.Action != "accept" {
|
||||
return nil, errInvalidAction
|
||||
}
|
||||
|
||||
r := tailcfg.FilterRule{}
|
||||
filterRule := tailcfg.FilterRule{}
|
||||
|
||||
srcIPs := []string{}
|
||||
for j, u := range a.Users {
|
||||
srcs, err := h.generateACLPolicySrcIP(u)
|
||||
for innerIndex, user := range acl.Users {
|
||||
srcs, err := h.generateACLPolicySrcIP(user)
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Msgf("Error parsing ACL %d, User %d", i, j)
|
||||
Msgf("Error parsing ACL %d, User %d", index, innerIndex)
|
||||
|
||||
return nil, err
|
||||
}
|
||||
srcIPs = append(srcIPs, srcs...)
|
||||
}
|
||||
r.SrcIPs = srcIPs
|
||||
filterRule.SrcIPs = srcIPs
|
||||
|
||||
destPorts := []tailcfg.NetPortRange{}
|
||||
for j, d := range a.Ports {
|
||||
dests, err := h.generateACLPolicyDestPorts(d)
|
||||
for innerIndex, ports := range acl.Ports {
|
||||
dests, err := h.generateACLPolicyDestPorts(ports)
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Msgf("Error parsing ACL %d, Port %d", i, j)
|
||||
Msgf("Error parsing ACL %d, Port %d", index, innerIndex)
|
||||
|
||||
return nil, err
|
||||
}
|
||||
destPorts = append(destPorts, dests...)
|
||||
|
@ -108,10 +118,12 @@ func (h *Headscale) generateACLPolicySrcIP(u string) ([]string, error) {
|
|||
return h.expandAlias(u)
|
||||
}
|
||||
|
||||
func (h *Headscale) generateACLPolicyDestPorts(d string) ([]tailcfg.NetPortRange, error) {
|
||||
func (h *Headscale) generateACLPolicyDestPorts(
|
||||
d string,
|
||||
) ([]tailcfg.NetPortRange, error) {
|
||||
tokens := strings.Split(d, ":")
|
||||
if len(tokens) < 2 || len(tokens) > 3 {
|
||||
return nil, errorInvalidPortFormat
|
||||
if len(tokens) < expectedTokenItems || len(tokens) > 3 {
|
||||
return nil, errInvalidPortFormat
|
||||
}
|
||||
|
||||
var alias string
|
||||
|
@ -121,7 +133,7 @@ func (h *Headscale) generateACLPolicyDestPorts(d string) ([]tailcfg.NetPortRange
|
|||
// tag:montreal-webserver:80,443
|
||||
// tag:api-server:443
|
||||
// example-host-1:*
|
||||
if len(tokens) == 2 {
|
||||
if len(tokens) == expectedTokenItems {
|
||||
alias = tokens[0]
|
||||
} else {
|
||||
alias = fmt.Sprintf("%s:%s", tokens[0], tokens[1])
|
||||
|
@ -146,34 +158,36 @@ func (h *Headscale) generateACLPolicyDestPorts(d string) ([]tailcfg.NetPortRange
|
|||
dests = append(dests, pr)
|
||||
}
|
||||
}
|
||||
|
||||
return dests, nil
|
||||
}
|
||||
|
||||
func (h *Headscale) expandAlias(s string) ([]string, error) {
|
||||
if s == "*" {
|
||||
func (h *Headscale) expandAlias(alias string) ([]string, error) {
|
||||
if alias == "*" {
|
||||
return []string{"*"}, nil
|
||||
}
|
||||
|
||||
if strings.HasPrefix(s, "group:") {
|
||||
if _, ok := h.aclPolicy.Groups[s]; !ok {
|
||||
return nil, errorInvalidGroup
|
||||
if strings.HasPrefix(alias, "group:") {
|
||||
if _, ok := h.aclPolicy.Groups[alias]; !ok {
|
||||
return nil, errInvalidGroup
|
||||
}
|
||||
ips := []string{}
|
||||
for _, n := range h.aclPolicy.Groups[s] {
|
||||
for _, n := range h.aclPolicy.Groups[alias] {
|
||||
nodes, err := h.ListMachinesInNamespace(n)
|
||||
if err != nil {
|
||||
return nil, errorInvalidNamespace
|
||||
return nil, errInvalidNamespace
|
||||
}
|
||||
for _, node := range nodes {
|
||||
ips = append(ips, node.IPAddress)
|
||||
}
|
||||
}
|
||||
|
||||
return ips, nil
|
||||
}
|
||||
|
||||
if strings.HasPrefix(s, "tag:") {
|
||||
if _, ok := h.aclPolicy.TagOwners[s]; !ok {
|
||||
return nil, errorInvalidTag
|
||||
if strings.HasPrefix(alias, "tag:") {
|
||||
if _, ok := h.aclPolicy.TagOwners[alias]; !ok {
|
||||
return nil, errInvalidTag
|
||||
}
|
||||
|
||||
// This will have HORRIBLE performance.
|
||||
|
@ -183,10 +197,10 @@ func (h *Headscale) expandAlias(s string) ([]string, error) {
|
|||
return nil, err
|
||||
}
|
||||
ips := []string{}
|
||||
for _, m := range machines {
|
||||
for _, machine := range machines {
|
||||
hostinfo := tailcfg.Hostinfo{}
|
||||
if len(m.HostInfo) != 0 {
|
||||
hi, err := m.HostInfo.MarshalJSON()
|
||||
if len(machine.HostInfo) != 0 {
|
||||
hi, err := machine.HostInfo.MarshalJSON()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -197,17 +211,19 @@ func (h *Headscale) expandAlias(s string) ([]string, error) {
|
|||
|
||||
// FIXME: Check TagOwners allows this
|
||||
for _, t := range hostinfo.RequestTags {
|
||||
if s[4:] == t {
|
||||
ips = append(ips, m.IPAddress)
|
||||
if alias[4:] == t {
|
||||
ips = append(ips, machine.IPAddress)
|
||||
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return ips, nil
|
||||
}
|
||||
|
||||
n, err := h.GetNamespace(s)
|
||||
n, err := h.GetNamespace(alias)
|
||||
if err == nil {
|
||||
nodes, err := h.ListMachinesInNamespace(n.Name)
|
||||
if err != nil {
|
||||
|
@ -217,49 +233,54 @@ func (h *Headscale) expandAlias(s string) ([]string, error) {
|
|||
for _, n := range nodes {
|
||||
ips = append(ips, n.IPAddress)
|
||||
}
|
||||
|
||||
return ips, nil
|
||||
}
|
||||
|
||||
if h, ok := h.aclPolicy.Hosts[s]; ok {
|
||||
if h, ok := h.aclPolicy.Hosts[alias]; ok {
|
||||
return []string{h.String()}, nil
|
||||
}
|
||||
|
||||
ip, err := netaddr.ParseIP(s)
|
||||
ip, err := netaddr.ParseIP(alias)
|
||||
if err == nil {
|
||||
return []string{ip.String()}, nil
|
||||
}
|
||||
|
||||
cidr, err := netaddr.ParseIPPrefix(s)
|
||||
cidr, err := netaddr.ParseIPPrefix(alias)
|
||||
if err == nil {
|
||||
return []string{cidr.String()}, nil
|
||||
}
|
||||
|
||||
return nil, errorInvalidUserSection
|
||||
return nil, errInvalidUserSection
|
||||
}
|
||||
|
||||
func (h *Headscale) expandPorts(s string) (*[]tailcfg.PortRange, error) {
|
||||
if s == "*" {
|
||||
return &[]tailcfg.PortRange{{First: 0, Last: 65535}}, nil
|
||||
func (h *Headscale) expandPorts(portsStr string) (*[]tailcfg.PortRange, error) {
|
||||
if portsStr == "*" {
|
||||
return &[]tailcfg.PortRange{
|
||||
{First: portRangeBegin, Last: portRangeEnd},
|
||||
}, nil
|
||||
}
|
||||
|
||||
ports := []tailcfg.PortRange{}
|
||||
for _, p := range strings.Split(s, ",") {
|
||||
rang := strings.Split(p, "-")
|
||||
if len(rang) == 1 {
|
||||
pi, err := strconv.ParseUint(rang[0], 10, 16)
|
||||
for _, portStr := range strings.Split(portsStr, ",") {
|
||||
rang := strings.Split(portStr, "-")
|
||||
switch len(rang) {
|
||||
case 1:
|
||||
port, err := strconv.ParseUint(rang[0], Base10, BitSize16)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ports = append(ports, tailcfg.PortRange{
|
||||
First: uint16(pi),
|
||||
Last: uint16(pi),
|
||||
First: uint16(port),
|
||||
Last: uint16(port),
|
||||
})
|
||||
} else if len(rang) == 2 {
|
||||
start, err := strconv.ParseUint(rang[0], 10, 16)
|
||||
|
||||
case expectedTokenItems:
|
||||
start, err := strconv.ParseUint(rang[0], Base10, BitSize16)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
last, err := strconv.ParseUint(rang[1], 10, 16)
|
||||
last, err := strconv.ParseUint(rang[1], Base10, BitSize16)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -267,9 +288,11 @@ func (h *Headscale) expandPorts(s string) (*[]tailcfg.PortRange, error) {
|
|||
First: uint16(start),
|
||||
Last: uint16(last),
|
||||
})
|
||||
} else {
|
||||
return nil, errorInvalidPortFormat
|
||||
|
||||
default:
|
||||
return nil, errInvalidPortFormat
|
||||
}
|
||||
}
|
||||
|
||||
return &ports, nil
|
||||
}
|
||||
|
|
76
acls_test.go
76
acls_test.go
|
@ -5,54 +5,58 @@ import (
|
|||
)
|
||||
|
||||
func (s *Suite) TestWrongPath(c *check.C) {
|
||||
err := h.LoadACLPolicy("asdfg")
|
||||
err := app.LoadACLPolicy("asdfg")
|
||||
c.Assert(err, check.NotNil)
|
||||
}
|
||||
|
||||
func (s *Suite) TestBrokenHuJson(c *check.C) {
|
||||
err := h.LoadACLPolicy("./tests/acls/broken.hujson")
|
||||
err := app.LoadACLPolicy("./tests/acls/broken.hujson")
|
||||
c.Assert(err, check.NotNil)
|
||||
}
|
||||
|
||||
func (s *Suite) TestInvalidPolicyHuson(c *check.C) {
|
||||
err := h.LoadACLPolicy("./tests/acls/invalid.hujson")
|
||||
err := app.LoadACLPolicy("./tests/acls/invalid.hujson")
|
||||
c.Assert(err, check.NotNil)
|
||||
c.Assert(err, check.Equals, errorEmptyPolicy)
|
||||
c.Assert(err, check.Equals, errEmptyPolicy)
|
||||
}
|
||||
|
||||
func (s *Suite) TestParseHosts(c *check.C) {
|
||||
var hs Hosts
|
||||
err := hs.UnmarshalJSON([]byte(`{"example-host-1": "100.100.100.100","example-host-2": "100.100.101.100/24"}`))
|
||||
c.Assert(hs, check.NotNil)
|
||||
var hosts Hosts
|
||||
err := hosts.UnmarshalJSON(
|
||||
[]byte(
|
||||
`{"example-host-1": "100.100.100.100","example-host-2": "100.100.101.100/24"}`,
|
||||
),
|
||||
)
|
||||
c.Assert(hosts, check.NotNil)
|
||||
c.Assert(err, check.IsNil)
|
||||
}
|
||||
|
||||
func (s *Suite) TestParseInvalidCIDR(c *check.C) {
|
||||
var hs Hosts
|
||||
err := hs.UnmarshalJSON([]byte(`{"example-host-1": "100.100.100.100/42"}`))
|
||||
c.Assert(hs, check.IsNil)
|
||||
var hosts Hosts
|
||||
err := hosts.UnmarshalJSON([]byte(`{"example-host-1": "100.100.100.100/42"}`))
|
||||
c.Assert(hosts, check.IsNil)
|
||||
c.Assert(err, check.NotNil)
|
||||
}
|
||||
|
||||
func (s *Suite) TestRuleInvalidGeneration(c *check.C) {
|
||||
err := h.LoadACLPolicy("./tests/acls/acl_policy_invalid.hujson")
|
||||
err := app.LoadACLPolicy("./tests/acls/acl_policy_invalid.hujson")
|
||||
c.Assert(err, check.NotNil)
|
||||
}
|
||||
|
||||
func (s *Suite) TestBasicRule(c *check.C) {
|
||||
err := h.LoadACLPolicy("./tests/acls/acl_policy_basic_1.hujson")
|
||||
err := app.LoadACLPolicy("./tests/acls/acl_policy_basic_1.hujson")
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
rules, err := h.generateACLRules()
|
||||
rules, err := app.generateACLRules()
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(rules, check.NotNil)
|
||||
}
|
||||
|
||||
func (s *Suite) TestPortRange(c *check.C) {
|
||||
err := h.LoadACLPolicy("./tests/acls/acl_policy_basic_range.hujson")
|
||||
err := app.LoadACLPolicy("./tests/acls/acl_policy_basic_range.hujson")
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
rules, err := h.generateACLRules()
|
||||
rules, err := app.generateACLRules()
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(rules, check.NotNil)
|
||||
|
||||
|
@ -63,10 +67,10 @@ func (s *Suite) TestPortRange(c *check.C) {
|
|||
}
|
||||
|
||||
func (s *Suite) TestPortWildcard(c *check.C) {
|
||||
err := h.LoadACLPolicy("./tests/acls/acl_policy_basic_wildcards.hujson")
|
||||
err := app.LoadACLPolicy("./tests/acls/acl_policy_basic_wildcards.hujson")
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
rules, err := h.generateACLRules()
|
||||
rules, err := app.generateACLRules()
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(rules, check.NotNil)
|
||||
|
||||
|
@ -79,33 +83,35 @@ func (s *Suite) TestPortWildcard(c *check.C) {
|
|||
}
|
||||
|
||||
func (s *Suite) TestPortNamespace(c *check.C) {
|
||||
n, err := h.CreateNamespace("testnamespace")
|
||||
namespace, err := app.CreateNamespace("testnamespace")
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
pak, err := h.CreatePreAuthKey(n.Name, false, false, nil)
|
||||
pak, err := app.CreatePreAuthKey(namespace.Name, false, false, nil)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
_, err = h.GetMachine("testnamespace", "testmachine")
|
||||
_, err = app.GetMachine("testnamespace", "testmachine")
|
||||
c.Assert(err, check.NotNil)
|
||||
ip, _ := h.getAvailableIP()
|
||||
m := Machine{
|
||||
ip, _ := app.getAvailableIP()
|
||||
machine := Machine{
|
||||
ID: 0,
|
||||
MachineKey: "foo",
|
||||
NodeKey: "bar",
|
||||
DiscoKey: "faa",
|
||||
Name: "testmachine",
|
||||
NamespaceID: n.ID,
|
||||
NamespaceID: namespace.ID,
|
||||
Registered: true,
|
||||
RegisterMethod: "authKey",
|
||||
IPAddress: ip.String(),
|
||||
AuthKeyID: uint(pak.ID),
|
||||
}
|
||||
h.db.Save(&m)
|
||||
app.db.Save(&machine)
|
||||
|
||||
err = h.LoadACLPolicy("./tests/acls/acl_policy_basic_namespace_as_user.hujson")
|
||||
err = app.LoadACLPolicy(
|
||||
"./tests/acls/acl_policy_basic_namespace_as_user.hujson",
|
||||
)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
rules, err := h.generateACLRules()
|
||||
rules, err := app.generateACLRules()
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(rules, check.NotNil)
|
||||
|
||||
|
@ -119,33 +125,33 @@ func (s *Suite) TestPortNamespace(c *check.C) {
|
|||
}
|
||||
|
||||
func (s *Suite) TestPortGroup(c *check.C) {
|
||||
n, err := h.CreateNamespace("testnamespace")
|
||||
namespace, err := app.CreateNamespace("testnamespace")
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
pak, err := h.CreatePreAuthKey(n.Name, false, false, nil)
|
||||
pak, err := app.CreatePreAuthKey(namespace.Name, false, false, nil)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
_, err = h.GetMachine("testnamespace", "testmachine")
|
||||
_, err = app.GetMachine("testnamespace", "testmachine")
|
||||
c.Assert(err, check.NotNil)
|
||||
ip, _ := h.getAvailableIP()
|
||||
m := Machine{
|
||||
ip, _ := app.getAvailableIP()
|
||||
machine := Machine{
|
||||
ID: 0,
|
||||
MachineKey: "foo",
|
||||
NodeKey: "bar",
|
||||
DiscoKey: "faa",
|
||||
Name: "testmachine",
|
||||
NamespaceID: n.ID,
|
||||
NamespaceID: namespace.ID,
|
||||
Registered: true,
|
||||
RegisterMethod: "authKey",
|
||||
IPAddress: ip.String(),
|
||||
AuthKeyID: uint(pak.ID),
|
||||
}
|
||||
h.db.Save(&m)
|
||||
app.db.Save(&machine)
|
||||
|
||||
err = h.LoadACLPolicy("./tests/acls/acl_policy_basic_groups.hujson")
|
||||
err = app.LoadACLPolicy("./tests/acls/acl_policy_basic_groups.hujson")
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
rules, err := h.generateACLRules()
|
||||
rules, err := app.generateACLRules()
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(rules, check.NotNil)
|
||||
|
||||
|
|
|
@ -8,7 +8,7 @@ import (
|
|||
"inet.af/netaddr"
|
||||
)
|
||||
|
||||
// ACLPolicy represents a Tailscale ACL Policy
|
||||
// ACLPolicy represents a Tailscale ACL Policy.
|
||||
type ACLPolicy struct {
|
||||
Groups Groups `json:"Groups"`
|
||||
Hosts Hosts `json:"Hosts"`
|
||||
|
@ -17,61 +17,63 @@ type ACLPolicy struct {
|
|||
Tests []ACLTest `json:"Tests"`
|
||||
}
|
||||
|
||||
// ACL is a basic rule for the ACL Policy
|
||||
// ACL is a basic rule for the ACL Policy.
|
||||
type ACL struct {
|
||||
Action string `json:"Action"`
|
||||
Users []string `json:"Users"`
|
||||
Ports []string `json:"Ports"`
|
||||
}
|
||||
|
||||
// Groups references a series of alias in the ACL rules
|
||||
// Groups references a series of alias in the ACL rules.
|
||||
type Groups map[string][]string
|
||||
|
||||
// Hosts are alias for IP addresses or subnets
|
||||
// Hosts are alias for IP addresses or subnets.
|
||||
type Hosts map[string]netaddr.IPPrefix
|
||||
|
||||
// TagOwners specify what users (namespaces?) are allow to use certain tags
|
||||
// TagOwners specify what users (namespaces?) are allow to use certain tags.
|
||||
type TagOwners map[string][]string
|
||||
|
||||
// ACLTest is not implemented, but should be use to check if a certain rule is allowed
|
||||
// ACLTest is not implemented, but should be use to check if a certain rule is allowed.
|
||||
type ACLTest struct {
|
||||
User string `json:"User"`
|
||||
Allow []string `json:"Allow"`
|
||||
Deny []string `json:"Deny,omitempty"`
|
||||
}
|
||||
|
||||
// UnmarshalJSON allows to parse the Hosts directly into netaddr objects
|
||||
func (h *Hosts) UnmarshalJSON(data []byte) error {
|
||||
hosts := Hosts{}
|
||||
hs := make(map[string]string)
|
||||
// UnmarshalJSON allows to parse the Hosts directly into netaddr objects.
|
||||
func (hosts *Hosts) UnmarshalJSON(data []byte) error {
|
||||
newHosts := Hosts{}
|
||||
hostIPPrefixMap := make(map[string]string)
|
||||
ast, err := hujson.Parse(data)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
ast.Standardize()
|
||||
data = ast.Pack()
|
||||
err = json.Unmarshal(data, &hs)
|
||||
err = json.Unmarshal(data, &hostIPPrefixMap)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for k, v := range hs {
|
||||
if !strings.Contains(v, "/") {
|
||||
v = v + "/32"
|
||||
for host, prefixStr := range hostIPPrefixMap {
|
||||
if !strings.Contains(prefixStr, "/") {
|
||||
prefixStr += "/32"
|
||||
}
|
||||
prefix, err := netaddr.ParseIPPrefix(v)
|
||||
prefix, err := netaddr.ParseIPPrefix(prefixStr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
hosts[k] = prefix
|
||||
newHosts[host] = prefix
|
||||
}
|
||||
*h = hosts
|
||||
*hosts = newHosts
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// IsZero is perhaps a bit naive here
|
||||
func (p ACLPolicy) IsZero() bool {
|
||||
if len(p.Groups) == 0 && len(p.Hosts) == 0 && len(p.ACLs) == 0 {
|
||||
// IsZero is perhaps a bit naive here.
|
||||
func (policy ACLPolicy) IsZero() bool {
|
||||
if len(policy.Groups) == 0 && len(policy.Hosts) == 0 && len(policy.ACLs) == 0 {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
|
297
api.go
297
api.go
|
@ -10,31 +10,37 @@ import (
|
|||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/rs/zerolog/log"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/klauspost/compress/zstd"
|
||||
"github.com/rs/zerolog/log"
|
||||
"gorm.io/gorm"
|
||||
"tailscale.com/tailcfg"
|
||||
"tailscale.com/types/wgkey"
|
||||
)
|
||||
|
||||
const reservedResponseHeaderSize = 4
|
||||
|
||||
// KeyHandler provides the Headscale pub key
|
||||
// Listens in /key
|
||||
func (h *Headscale) KeyHandler(c *gin.Context) {
|
||||
c.Data(200, "text/plain; charset=utf-8", []byte(h.publicKey.HexString()))
|
||||
// Listens in /key.
|
||||
func (h *Headscale) KeyHandler(ctx *gin.Context) {
|
||||
ctx.Data(
|
||||
http.StatusOK,
|
||||
"text/plain; charset=utf-8",
|
||||
[]byte(h.publicKey.HexString()),
|
||||
)
|
||||
}
|
||||
|
||||
// RegisterWebAPI shows a simple message in the browser to point to the CLI
|
||||
// Listens in /register
|
||||
func (h *Headscale) RegisterWebAPI(c *gin.Context) {
|
||||
mKeyStr := c.Query("key")
|
||||
if mKeyStr == "" {
|
||||
c.String(http.StatusBadRequest, "Wrong params")
|
||||
// Listens in /register.
|
||||
func (h *Headscale) RegisterWebAPI(ctx *gin.Context) {
|
||||
machineKeyStr := ctx.Query("key")
|
||||
if machineKeyStr == "" {
|
||||
ctx.String(http.StatusBadRequest, "Wrong params")
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
c.Data(http.StatusOK, "text/html; charset=utf-8", []byte(fmt.Sprintf(`
|
||||
ctx.Data(http.StatusOK, "text/html; charset=utf-8", []byte(fmt.Sprintf(`
|
||||
<html>
|
||||
<body>
|
||||
<h1>headscale</h1>
|
||||
|
@ -51,43 +57,45 @@ func (h *Headscale) RegisterWebAPI(c *gin.Context) {
|
|||
</body>
|
||||
</html>
|
||||
|
||||
`, mKeyStr)))
|
||||
`, machineKeyStr)))
|
||||
}
|
||||
|
||||
// RegistrationHandler handles the actual registration process of a machine
|
||||
// Endpoint /machine/:id
|
||||
func (h *Headscale) RegistrationHandler(c *gin.Context) {
|
||||
body, _ := io.ReadAll(c.Request.Body)
|
||||
mKeyStr := c.Param("id")
|
||||
mKey, err := wgkey.ParseHex(mKeyStr)
|
||||
// Endpoint /machine/:id.
|
||||
func (h *Headscale) RegistrationHandler(ctx *gin.Context) {
|
||||
body, _ := io.ReadAll(ctx.Request.Body)
|
||||
machineKeyStr := ctx.Param("id")
|
||||
machineKey, err := wgkey.ParseHex(machineKeyStr)
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Str("handler", "Registration").
|
||||
Err(err).
|
||||
Msg("Cannot parse machine key")
|
||||
machineRegistrations.WithLabelValues("unknown", "web", "error", "unknown").Inc()
|
||||
c.String(http.StatusInternalServerError, "Sad!")
|
||||
ctx.String(http.StatusInternalServerError, "Sad!")
|
||||
|
||||
return
|
||||
}
|
||||
req := tailcfg.RegisterRequest{}
|
||||
err = decode(body, &req, &mKey, h.privateKey)
|
||||
err = decode(body, &req, &machineKey, h.privateKey)
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Str("handler", "Registration").
|
||||
Err(err).
|
||||
Msg("Cannot decode message")
|
||||
machineRegistrations.WithLabelValues("unknown", "web", "error", "unknown").Inc()
|
||||
c.String(http.StatusInternalServerError, "Very sad!")
|
||||
ctx.String(http.StatusInternalServerError, "Very sad!")
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
now := time.Now().UTC()
|
||||
m, err := h.GetMachineByMachineKey(mKey.HexString())
|
||||
machine, err := h.GetMachineByMachineKey(machineKey.HexString())
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
log.Info().Str("machine", req.Hostinfo.Hostname).Msg("New machine")
|
||||
newMachine := Machine{
|
||||
Expiry: &time.Time{},
|
||||
MachineKey: mKey.HexString(),
|
||||
MachineKey: machineKey.HexString(),
|
||||
Name: req.Hostinfo.Hostname,
|
||||
}
|
||||
if err := h.db.Create(&newMachine).Error; err != nil {
|
||||
|
@ -95,88 +103,96 @@ func (h *Headscale) RegistrationHandler(c *gin.Context) {
|
|||
Str("handler", "Registration").
|
||||
Err(err).
|
||||
Msg("Could not create row")
|
||||
machineRegistrations.WithLabelValues("unknown", "web", "error", m.Namespace.Name).Inc()
|
||||
machineRegistrations.WithLabelValues("unknown", "web", "error", machine.Namespace.Name).
|
||||
Inc()
|
||||
|
||||
return
|
||||
}
|
||||
m = &newMachine
|
||||
machine = &newMachine
|
||||
}
|
||||
|
||||
if !m.Registered && req.Auth.AuthKey != "" {
|
||||
h.handleAuthKey(c, h.db, mKey, req, *m)
|
||||
if !machine.Registered && req.Auth.AuthKey != "" {
|
||||
h.handleAuthKey(ctx, h.db, machineKey, req, *machine)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
resp := tailcfg.RegisterResponse{}
|
||||
|
||||
// We have the updated key!
|
||||
if m.NodeKey == wgkey.Key(req.NodeKey).HexString() {
|
||||
|
||||
if machine.NodeKey == wgkey.Key(req.NodeKey).HexString() {
|
||||
// The client sends an Expiry in the past if the client is requesting to expire the key (aka logout)
|
||||
// https://github.com/tailscale/tailscale/blob/main/tailcfg/tailcfg.go#L648
|
||||
if !req.Expiry.IsZero() && req.Expiry.UTC().Before(now) {
|
||||
log.Info().
|
||||
Str("handler", "Registration").
|
||||
Str("machine", m.Name).
|
||||
Str("machine", machine.Name).
|
||||
Msg("Client requested logout")
|
||||
|
||||
m.Expiry = &req.Expiry // save the expiry so that the machine is marked as expired
|
||||
h.db.Save(&m)
|
||||
machine.Expiry = &req.Expiry // save the expiry so that the machine is marked as expired
|
||||
h.db.Save(&machine)
|
||||
|
||||
resp.AuthURL = ""
|
||||
resp.MachineAuthorized = false
|
||||
resp.User = *m.Namespace.toUser()
|
||||
respBody, err := encode(resp, &mKey, h.privateKey)
|
||||
resp.User = *machine.Namespace.toUser()
|
||||
respBody, err := encode(resp, &machineKey, h.privateKey)
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Str("handler", "Registration").
|
||||
Err(err).
|
||||
Msg("Cannot encode message")
|
||||
c.String(http.StatusInternalServerError, "")
|
||||
ctx.String(http.StatusInternalServerError, "")
|
||||
|
||||
return
|
||||
}
|
||||
c.Data(200, "application/json; charset=utf-8", respBody)
|
||||
ctx.Data(http.StatusOK, "application/json; charset=utf-8", respBody)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if m.Registered && m.Expiry.UTC().After(now) {
|
||||
if machine.Registered && machine.Expiry.UTC().After(now) {
|
||||
// The machine registration is valid, respond with redirect to /map
|
||||
log.Debug().
|
||||
Str("handler", "Registration").
|
||||
Str("machine", m.Name).
|
||||
Str("machine", machine.Name).
|
||||
Msg("Client is registered and we have the current NodeKey. All clear to /map")
|
||||
|
||||
resp.AuthURL = ""
|
||||
resp.MachineAuthorized = true
|
||||
resp.User = *m.Namespace.toUser()
|
||||
resp.Login = *m.Namespace.toLogin()
|
||||
resp.User = *machine.Namespace.toUser()
|
||||
resp.Login = *machine.Namespace.toLogin()
|
||||
|
||||
respBody, err := encode(resp, &mKey, h.privateKey)
|
||||
respBody, err := encode(resp, &machineKey, h.privateKey)
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Str("handler", "Registration").
|
||||
Err(err).
|
||||
Msg("Cannot encode message")
|
||||
machineRegistrations.WithLabelValues("update", "web", "error", m.Namespace.Name).Inc()
|
||||
c.String(http.StatusInternalServerError, "")
|
||||
machineRegistrations.WithLabelValues("update", "web", "error", machine.Namespace.Name).
|
||||
Inc()
|
||||
ctx.String(http.StatusInternalServerError, "")
|
||||
|
||||
return
|
||||
}
|
||||
machineRegistrations.WithLabelValues("update", "web", "success", m.Namespace.Name).Inc()
|
||||
c.Data(200, "application/json; charset=utf-8", respBody)
|
||||
machineRegistrations.WithLabelValues("update", "web", "success", machine.Namespace.Name).
|
||||
Inc()
|
||||
ctx.Data(http.StatusOK, "application/json; charset=utf-8", respBody)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// The client has registered before, but has expired
|
||||
log.Debug().
|
||||
Str("handler", "Registration").
|
||||
Str("machine", m.Name).
|
||||
Str("machine", machine.Name).
|
||||
Msg("Machine registration has expired. Sending a authurl to register")
|
||||
|
||||
if h.cfg.OIDC.Issuer != "" {
|
||||
resp.AuthURL = fmt.Sprintf("%s/oidc/register/%s",
|
||||
strings.TrimSuffix(h.cfg.ServerURL, "/"), mKey.HexString())
|
||||
strings.TrimSuffix(h.cfg.ServerURL, "/"), machineKey.HexString())
|
||||
} else {
|
||||
resp.AuthURL = fmt.Sprintf("%s/register?key=%s",
|
||||
strings.TrimSuffix(h.cfg.ServerURL, "/"), mKey.HexString())
|
||||
strings.TrimSuffix(h.cfg.ServerURL, "/"), machineKey.HexString())
|
||||
}
|
||||
|
||||
// When a client connects, it may request a specific expiry time in its
|
||||
|
@ -185,102 +201,120 @@ func (h *Headscale) RegistrationHandler(c *gin.Context) {
|
|||
// into two steps (which cant pass arbitrary data between them easily) and needs to be
|
||||
// retrieved again after the user has authenticated. After the authentication flow
|
||||
// completes, RequestedExpiry is copied into Expiry.
|
||||
m.RequestedExpiry = &req.Expiry
|
||||
machine.RequestedExpiry = &req.Expiry
|
||||
|
||||
h.db.Save(&m)
|
||||
h.db.Save(&machine)
|
||||
|
||||
respBody, err := encode(resp, &mKey, h.privateKey)
|
||||
respBody, err := encode(resp, &machineKey, h.privateKey)
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Str("handler", "Registration").
|
||||
Err(err).
|
||||
Msg("Cannot encode message")
|
||||
machineRegistrations.WithLabelValues("new", "web", "error", m.Namespace.Name).Inc()
|
||||
c.String(http.StatusInternalServerError, "")
|
||||
machineRegistrations.WithLabelValues("new", "web", "error", machine.Namespace.Name).
|
||||
Inc()
|
||||
ctx.String(http.StatusInternalServerError, "")
|
||||
|
||||
return
|
||||
}
|
||||
machineRegistrations.WithLabelValues("new", "web", "success", m.Namespace.Name).Inc()
|
||||
c.Data(200, "application/json; charset=utf-8", respBody)
|
||||
machineRegistrations.WithLabelValues("new", "web", "success", machine.Namespace.Name).
|
||||
Inc()
|
||||
ctx.Data(http.StatusOK, "application/json; charset=utf-8", respBody)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// The NodeKey we have matches OldNodeKey, which means this is a refresh after a key expiration
|
||||
if m.NodeKey == wgkey.Key(req.OldNodeKey).HexString() && m.Expiry.UTC().After(now) {
|
||||
if machine.NodeKey == wgkey.Key(req.OldNodeKey).HexString() &&
|
||||
machine.Expiry.UTC().After(now) {
|
||||
log.Debug().
|
||||
Str("handler", "Registration").
|
||||
Str("machine", m.Name).
|
||||
Str("machine", machine.Name).
|
||||
Msg("We have the OldNodeKey in the database. This is a key refresh")
|
||||
m.NodeKey = wgkey.Key(req.NodeKey).HexString()
|
||||
h.db.Save(&m)
|
||||
machine.NodeKey = wgkey.Key(req.NodeKey).HexString()
|
||||
h.db.Save(&machine)
|
||||
|
||||
resp.AuthURL = ""
|
||||
resp.User = *m.Namespace.toUser()
|
||||
respBody, err := encode(resp, &mKey, h.privateKey)
|
||||
resp.User = *machine.Namespace.toUser()
|
||||
respBody, err := encode(resp, &machineKey, h.privateKey)
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Str("handler", "Registration").
|
||||
Err(err).
|
||||
Msg("Cannot encode message")
|
||||
c.String(http.StatusInternalServerError, "Extremely sad!")
|
||||
ctx.String(http.StatusInternalServerError, "Extremely sad!")
|
||||
|
||||
return
|
||||
}
|
||||
c.Data(200, "application/json; charset=utf-8", respBody)
|
||||
ctx.Data(http.StatusOK, "application/json; charset=utf-8", respBody)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// The machine registration is new, redirect the client to the registration URL
|
||||
log.Debug().
|
||||
Str("handler", "Registration").
|
||||
Str("machine", m.Name).
|
||||
Str("machine", machine.Name).
|
||||
Msg("The node is sending us a new NodeKey, sending auth url")
|
||||
if h.cfg.OIDC.Issuer != "" {
|
||||
resp.AuthURL = fmt.Sprintf("%s/oidc/register/%s", strings.TrimSuffix(h.cfg.ServerURL, "/"), mKey.HexString())
|
||||
resp.AuthURL = fmt.Sprintf(
|
||||
"%s/oidc/register/%s",
|
||||
strings.TrimSuffix(h.cfg.ServerURL, "/"),
|
||||
machineKey.HexString(),
|
||||
)
|
||||
} else {
|
||||
resp.AuthURL = fmt.Sprintf("%s/register?key=%s",
|
||||
strings.TrimSuffix(h.cfg.ServerURL, "/"), mKey.HexString())
|
||||
strings.TrimSuffix(h.cfg.ServerURL, "/"), machineKey.HexString())
|
||||
}
|
||||
|
||||
// save the requested expiry time for retrieval later in the authentication flow
|
||||
m.RequestedExpiry = &req.Expiry
|
||||
m.NodeKey = wgkey.Key(req.NodeKey).HexString() // save the NodeKey
|
||||
h.db.Save(&m)
|
||||
machine.RequestedExpiry = &req.Expiry
|
||||
machine.NodeKey = wgkey.Key(req.NodeKey).HexString() // save the NodeKey
|
||||
h.db.Save(&machine)
|
||||
|
||||
respBody, err := encode(resp, &mKey, h.privateKey)
|
||||
respBody, err := encode(resp, &machineKey, h.privateKey)
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Str("handler", "Registration").
|
||||
Err(err).
|
||||
Msg("Cannot encode message")
|
||||
c.String(http.StatusInternalServerError, "")
|
||||
ctx.String(http.StatusInternalServerError, "")
|
||||
|
||||
return
|
||||
}
|
||||
c.Data(200, "application/json; charset=utf-8", respBody)
|
||||
ctx.Data(http.StatusOK, "application/json; charset=utf-8", respBody)
|
||||
}
|
||||
|
||||
func (h *Headscale) getMapResponse(mKey wgkey.Key, req tailcfg.MapRequest, m *Machine) ([]byte, error) {
|
||||
func (h *Headscale) getMapResponse(
|
||||
machineKey wgkey.Key,
|
||||
req tailcfg.MapRequest,
|
||||
machine *Machine,
|
||||
) ([]byte, error) {
|
||||
log.Trace().
|
||||
Str("func", "getMapResponse").
|
||||
Str("machine", req.Hostinfo.Hostname).
|
||||
Msg("Creating Map response")
|
||||
node, err := m.toNode(h.cfg.BaseDomain, h.cfg.DNSConfig, true)
|
||||
node, err := machine.toNode(h.cfg.BaseDomain, h.cfg.DNSConfig, true)
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Str("func", "getMapResponse").
|
||||
Err(err).
|
||||
Msg("Cannot convert to node")
|
||||
|
||||
return nil, err
|
||||
}
|
||||
|
||||
peers, err := h.getPeers(m)
|
||||
peers, err := h.getPeers(machine)
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Str("func", "getMapResponse").
|
||||
Err(err).
|
||||
Msg("Cannot fetch peers")
|
||||
|
||||
return nil, err
|
||||
}
|
||||
|
||||
profiles := getMapResponseUserProfiles(*m, peers)
|
||||
profiles := getMapResponseUserProfiles(*machine, peers)
|
||||
|
||||
nodePeers, err := peers.toNodes(h.cfg.BaseDomain, h.cfg.DNSConfig, true)
|
||||
if err != nil {
|
||||
|
@ -288,17 +322,16 @@ func (h *Headscale) getMapResponse(mKey wgkey.Key, req tailcfg.MapRequest, m *Ma
|
|||
Str("func", "getMapResponse").
|
||||
Err(err).
|
||||
Msg("Failed to convert peers to Tailscale nodes")
|
||||
|
||||
return nil, err
|
||||
}
|
||||
|
||||
dnsConfig, err := getMapResponseDNSConfig(h.cfg.DNSConfig, h.cfg.BaseDomain, *m, peers)
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Str("func", "getMapResponse").
|
||||
Err(err).
|
||||
Msg("Failed generate the DNSConfig")
|
||||
return nil, err
|
||||
}
|
||||
dnsConfig := getMapResponseDNSConfig(
|
||||
h.cfg.DNSConfig,
|
||||
h.cfg.BaseDomain,
|
||||
*machine,
|
||||
peers,
|
||||
)
|
||||
|
||||
resp := tailcfg.MapResponse{
|
||||
KeepAlive: false,
|
||||
|
@ -323,66 +356,71 @@ func (h *Headscale) getMapResponse(mKey wgkey.Key, req tailcfg.MapRequest, m *Ma
|
|||
|
||||
encoder, _ := zstd.NewWriter(nil)
|
||||
srcCompressed := encoder.EncodeAll(src, nil)
|
||||
respBody, err = encodeMsg(srcCompressed, &mKey, h.privateKey)
|
||||
respBody, err = encodeMsg(srcCompressed, &machineKey, h.privateKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
respBody, err = encode(resp, &mKey, h.privateKey)
|
||||
respBody, err = encode(resp, &machineKey, h.privateKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
// declare the incoming size on the first 4 bytes
|
||||
data := make([]byte, 4)
|
||||
data := make([]byte, reservedResponseHeaderSize)
|
||||
binary.LittleEndian.PutUint32(data, uint32(len(respBody)))
|
||||
data = append(data, respBody...)
|
||||
|
||||
return data, nil
|
||||
}
|
||||
|
||||
func (h *Headscale) getMapKeepAliveResponse(mKey wgkey.Key, req tailcfg.MapRequest, m *Machine) ([]byte, error) {
|
||||
resp := tailcfg.MapResponse{
|
||||
func (h *Headscale) getMapKeepAliveResponse(
|
||||
machineKey wgkey.Key,
|
||||
mapRequest tailcfg.MapRequest,
|
||||
) ([]byte, error) {
|
||||
mapResponse := tailcfg.MapResponse{
|
||||
KeepAlive: true,
|
||||
}
|
||||
var respBody []byte
|
||||
var err error
|
||||
if req.Compress == "zstd" {
|
||||
src, _ := json.Marshal(resp)
|
||||
if mapRequest.Compress == "zstd" {
|
||||
src, _ := json.Marshal(mapResponse)
|
||||
encoder, _ := zstd.NewWriter(nil)
|
||||
srcCompressed := encoder.EncodeAll(src, nil)
|
||||
respBody, err = encodeMsg(srcCompressed, &mKey, h.privateKey)
|
||||
respBody, err = encodeMsg(srcCompressed, &machineKey, h.privateKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
respBody, err = encode(resp, &mKey, h.privateKey)
|
||||
respBody, err = encode(mapResponse, &machineKey, h.privateKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
data := make([]byte, 4)
|
||||
data := make([]byte, reservedResponseHeaderSize)
|
||||
binary.LittleEndian.PutUint32(data, uint32(len(respBody)))
|
||||
data = append(data, respBody...)
|
||||
|
||||
return data, nil
|
||||
}
|
||||
|
||||
func (h *Headscale) handleAuthKey(
|
||||
c *gin.Context,
|
||||
ctx *gin.Context,
|
||||
db *gorm.DB,
|
||||
idKey wgkey.Key,
|
||||
req tailcfg.RegisterRequest,
|
||||
m Machine,
|
||||
reqisterRequest tailcfg.RegisterRequest,
|
||||
machine Machine,
|
||||
) {
|
||||
log.Debug().
|
||||
Str("func", "handleAuthKey").
|
||||
Str("machine", req.Hostinfo.Hostname).
|
||||
Msgf("Processing auth key for %s", req.Hostinfo.Hostname)
|
||||
Str("machine", reqisterRequest.Hostinfo.Hostname).
|
||||
Msgf("Processing auth key for %s", reqisterRequest.Hostinfo.Hostname)
|
||||
resp := tailcfg.RegisterResponse{}
|
||||
pak, err := h.checkKeyValidity(req.Auth.AuthKey)
|
||||
pak, err := h.checkKeyValidity(reqisterRequest.Auth.AuthKey)
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Str("func", "handleAuthKey").
|
||||
Str("machine", m.Name).
|
||||
Str("machine", machine.Name).
|
||||
Err(err).
|
||||
Msg("Failed authentication via AuthKey")
|
||||
resp.MachineAuthorized = false
|
||||
|
@ -390,48 +428,56 @@ func (h *Headscale) handleAuthKey(
|
|||
if err != nil {
|
||||
log.Error().
|
||||
Str("func", "handleAuthKey").
|
||||
Str("machine", m.Name).
|
||||
Str("machine", machine.Name).
|
||||
Err(err).
|
||||
Msg("Cannot encode message")
|
||||
c.String(http.StatusInternalServerError, "")
|
||||
machineRegistrations.WithLabelValues("new", "authkey", "error", m.Namespace.Name).Inc()
|
||||
ctx.String(http.StatusInternalServerError, "")
|
||||
machineRegistrations.WithLabelValues("new", "authkey", "error", machine.Namespace.Name).
|
||||
Inc()
|
||||
|
||||
return
|
||||
}
|
||||
c.Data(401, "application/json; charset=utf-8", respBody)
|
||||
ctx.Data(http.StatusUnauthorized, "application/json; charset=utf-8", respBody)
|
||||
log.Error().
|
||||
Str("func", "handleAuthKey").
|
||||
Str("machine", m.Name).
|
||||
Str("machine", machine.Name).
|
||||
Msg("Failed authentication via AuthKey")
|
||||
machineRegistrations.WithLabelValues("new", "authkey", "error", m.Namespace.Name).Inc()
|
||||
machineRegistrations.WithLabelValues("new", "authkey", "error", machine.Namespace.Name).
|
||||
Inc()
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
log.Debug().
|
||||
Str("func", "handleAuthKey").
|
||||
Str("machine", m.Name).
|
||||
Str("machine", machine.Name).
|
||||
Msg("Authentication key was valid, proceeding to acquire an IP address")
|
||||
ip, err := h.getAvailableIP()
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Str("func", "handleAuthKey").
|
||||
Str("machine", m.Name).
|
||||
Str("machine", machine.Name).
|
||||
Msg("Failed to find an available IP")
|
||||
machineRegistrations.WithLabelValues("new", "authkey", "error", m.Namespace.Name).Inc()
|
||||
machineRegistrations.WithLabelValues("new", "authkey", "error", machine.Namespace.Name).
|
||||
Inc()
|
||||
|
||||
return
|
||||
}
|
||||
log.Info().
|
||||
Str("func", "handleAuthKey").
|
||||
Str("machine", m.Name).
|
||||
Str("machine", machine.Name).
|
||||
Str("ip", ip.String()).
|
||||
Msgf("Assigning %s to %s", ip, m.Name)
|
||||
Msgf("Assigning %s to %s", ip, machine.Name)
|
||||
|
||||
m.AuthKeyID = uint(pak.ID)
|
||||
m.IPAddress = ip.String()
|
||||
m.NamespaceID = pak.NamespaceID
|
||||
m.NodeKey = wgkey.Key(req.NodeKey).HexString() // we update it just in case
|
||||
m.Registered = true
|
||||
m.RegisterMethod = "authKey"
|
||||
db.Save(&m)
|
||||
machine.AuthKeyID = uint(pak.ID)
|
||||
machine.IPAddress = ip.String()
|
||||
machine.NamespaceID = pak.NamespaceID
|
||||
machine.NodeKey = wgkey.Key(reqisterRequest.NodeKey).
|
||||
HexString()
|
||||
// we update it just in case
|
||||
machine.Registered = true
|
||||
machine.RegisterMethod = "authKey"
|
||||
db.Save(&machine)
|
||||
|
||||
pak.Used = true
|
||||
db.Save(&pak)
|
||||
|
@ -442,18 +488,21 @@ func (h *Headscale) handleAuthKey(
|
|||
if err != nil {
|
||||
log.Error().
|
||||
Str("func", "handleAuthKey").
|
||||
Str("machine", m.Name).
|
||||
Str("machine", machine.Name).
|
||||
Err(err).
|
||||
Msg("Cannot encode message")
|
||||
machineRegistrations.WithLabelValues("new", "authkey", "error", m.Namespace.Name).Inc()
|
||||
c.String(http.StatusInternalServerError, "Extremely sad!")
|
||||
machineRegistrations.WithLabelValues("new", "authkey", "error", machine.Namespace.Name).
|
||||
Inc()
|
||||
ctx.String(http.StatusInternalServerError, "Extremely sad!")
|
||||
|
||||
return
|
||||
}
|
||||
machineRegistrations.WithLabelValues("new", "authkey", "success", m.Namespace.Name).Inc()
|
||||
c.Data(200, "application/json; charset=utf-8", respBody)
|
||||
machineRegistrations.WithLabelValues("new", "authkey", "success", machine.Namespace.Name).
|
||||
Inc()
|
||||
ctx.Data(http.StatusOK, "application/json; charset=utf-8", respBody)
|
||||
log.Info().
|
||||
Str("func", "handleAuthKey").
|
||||
Str("machine", m.Name).
|
||||
Str("machine", machine.Name).
|
||||
Str("ip", ip.String()).
|
||||
Msg("Successfully authenticated via AuthKey")
|
||||
}
|
||||
|
|
279
app.go
279
app.go
|
@ -18,20 +18,19 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/coreos/go-oidc/v3/oidc"
|
||||
"github.com/patrickmn/go-cache"
|
||||
"golang.org/x/oauth2"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/grpc-ecosystem/go-grpc-middleware"
|
||||
grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware"
|
||||
"github.com/grpc-ecosystem/grpc-gateway/v2/runtime"
|
||||
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
|
||||
"github.com/philip-bui/grpc-zerolog"
|
||||
"github.com/patrickmn/go-cache"
|
||||
zerolog "github.com/philip-bui/grpc-zerolog"
|
||||
zl "github.com/rs/zerolog"
|
||||
"github.com/rs/zerolog/log"
|
||||
"github.com/soheilhy/cmux"
|
||||
ginprometheus "github.com/zsais/go-gin-prometheus"
|
||||
"golang.org/x/crypto/acme"
|
||||
"golang.org/x/crypto/acme/autocert"
|
||||
"golang.org/x/oauth2"
|
||||
"golang.org/x/sync/errgroup"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/codes"
|
||||
|
@ -48,7 +47,16 @@ import (
|
|||
)
|
||||
|
||||
const (
|
||||
AUTH_PREFIX = "Bearer "
|
||||
AuthPrefix = "Bearer "
|
||||
Postgres = "postgresql"
|
||||
Sqlite = "sqlite3"
|
||||
updateInterval = 5000
|
||||
HTTPReadTimeout = 30 * time.Second
|
||||
|
||||
errUnsupportedDatabase = Error("unsupported DB")
|
||||
errUnsupportedLetsEncryptChallengeType = Error(
|
||||
"unknown value for Lets Encrypt challenge type",
|
||||
)
|
||||
)
|
||||
|
||||
// Config contains the initial Headscale configuration.
|
||||
|
@ -151,16 +159,22 @@ func NewHeadscale(cfg Config) (*Headscale, error) {
|
|||
|
||||
var dbString string
|
||||
switch cfg.DBtype {
|
||||
case "postgres":
|
||||
dbString = fmt.Sprintf("host=%s port=%d dbname=%s user=%s password=%s sslmode=disable", cfg.DBhost,
|
||||
cfg.DBport, cfg.DBname, cfg.DBuser, cfg.DBpass)
|
||||
case "sqlite3":
|
||||
case Postgres:
|
||||
dbString = fmt.Sprintf(
|
||||
"host=%s port=%d dbname=%s user=%s password=%s sslmode=disable",
|
||||
cfg.DBhost,
|
||||
cfg.DBport,
|
||||
cfg.DBname,
|
||||
cfg.DBuser,
|
||||
cfg.DBpass,
|
||||
)
|
||||
case Sqlite:
|
||||
dbString = cfg.DBpath
|
||||
default:
|
||||
return nil, errors.New("unsupported DB")
|
||||
return nil, errUnsupportedDatabase
|
||||
}
|
||||
|
||||
h := Headscale{
|
||||
app := Headscale{
|
||||
cfg: cfg,
|
||||
dbType: cfg.DBtype,
|
||||
dbString: dbString,
|
||||
|
@ -169,33 +183,32 @@ func NewHeadscale(cfg Config) (*Headscale, error) {
|
|||
aclRules: tailcfg.FilterAllowAll, // default allowall
|
||||
}
|
||||
|
||||
err = h.initDB()
|
||||
err = app.initDB()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if cfg.OIDC.Issuer != "" {
|
||||
err = h.initOIDC()
|
||||
err = app.initOIDC()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
if h.cfg.DNSConfig != nil && h.cfg.DNSConfig.Proxied { // if MagicDNS
|
||||
magicDNSDomains, err := generateMagicDNSRootDomains(h.cfg.IPPrefix, h.cfg.BaseDomain)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if app.cfg.DNSConfig != nil && app.cfg.DNSConfig.Proxied { // if MagicDNS
|
||||
magicDNSDomains := generateMagicDNSRootDomains(
|
||||
app.cfg.IPPrefix,
|
||||
)
|
||||
// we might have routes already from Split DNS
|
||||
if h.cfg.DNSConfig.Routes == nil {
|
||||
h.cfg.DNSConfig.Routes = make(map[string][]dnstype.Resolver)
|
||||
if app.cfg.DNSConfig.Routes == nil {
|
||||
app.cfg.DNSConfig.Routes = make(map[string][]dnstype.Resolver)
|
||||
}
|
||||
for _, d := range magicDNSDomains {
|
||||
h.cfg.DNSConfig.Routes[d.WithoutTrailingDot()] = nil
|
||||
app.cfg.DNSConfig.Routes[d.WithoutTrailingDot()] = nil
|
||||
}
|
||||
}
|
||||
|
||||
return &h, nil
|
||||
return &app, nil
|
||||
}
|
||||
|
||||
// Redirect to our TLS url.
|
||||
|
@ -221,30 +234,37 @@ func (h *Headscale) expireEphemeralNodesWorker() {
|
|||
return
|
||||
}
|
||||
|
||||
for _, ns := range namespaces {
|
||||
machines, err := h.ListMachinesInNamespace(ns.Name)
|
||||
for _, namespace := range namespaces {
|
||||
machines, err := h.ListMachinesInNamespace(namespace.Name)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Str("namespace", ns.Name).Msg("Error listing machines in namespace")
|
||||
log.Error().
|
||||
Err(err).
|
||||
Str("namespace", namespace.Name).
|
||||
Msg("Error listing machines in namespace")
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
for _, m := range machines {
|
||||
if m.AuthKey != nil && m.LastSeen != nil && m.AuthKey.Ephemeral &&
|
||||
time.Now().After(m.LastSeen.Add(h.cfg.EphemeralNodeInactivityTimeout)) {
|
||||
log.Info().Str("machine", m.Name).Msg("Ephemeral client removed from database")
|
||||
for _, machine := range machines {
|
||||
if machine.AuthKey != nil && machine.LastSeen != nil &&
|
||||
machine.AuthKey.Ephemeral &&
|
||||
time.Now().
|
||||
After(machine.LastSeen.Add(h.cfg.EphemeralNodeInactivityTimeout)) {
|
||||
log.Info().
|
||||
Str("machine", machine.Name).
|
||||
Msg("Ephemeral client removed from database")
|
||||
|
||||
err = h.db.Unscoped().Delete(m).Error
|
||||
err = h.db.Unscoped().Delete(machine).Error
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Err(err).
|
||||
Str("machine", m.Name).
|
||||
Str("machine", machine.Name).
|
||||
Msg("🤮 Cannot delete ephemeral machine from the database")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
h.setLastStateChangeToNow(ns.Name)
|
||||
h.setLastStateChangeToNow(namespace.Name)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -266,36 +286,56 @@ func (h *Headscale) grpcAuthenticationInterceptor(ctx context.Context,
|
|||
req interface{},
|
||||
info *grpc.UnaryServerInfo,
|
||||
handler grpc.UnaryHandler) (interface{}, error) {
|
||||
|
||||
// Check if the request is coming from the on-server client.
|
||||
// This is not secure, but it is to maintain maintainability
|
||||
// with the "legacy" database-based client
|
||||
// It is also neede for grpc-gateway to be able to connect to
|
||||
// the server
|
||||
p, _ := peer.FromContext(ctx)
|
||||
client, _ := peer.FromContext(ctx)
|
||||
|
||||
log.Trace().Caller().Str("client_address", p.Addr.String()).Msg("Client is trying to authenticate")
|
||||
log.Trace().
|
||||
Caller().
|
||||
Str("client_address", client.Addr.String()).
|
||||
Msg("Client is trying to authenticate")
|
||||
|
||||
md, ok := metadata.FromIncomingContext(ctx)
|
||||
meta, ok := metadata.FromIncomingContext(ctx)
|
||||
if !ok {
|
||||
log.Error().Caller().Str("client_address", p.Addr.String()).Msg("Retrieving metadata is failed")
|
||||
return ctx, status.Errorf(codes.InvalidArgument, "Retrieving metadata is failed")
|
||||
log.Error().
|
||||
Caller().
|
||||
Str("client_address", client.Addr.String()).
|
||||
Msg("Retrieving metadata is failed")
|
||||
|
||||
return ctx, status.Errorf(
|
||||
codes.InvalidArgument,
|
||||
"Retrieving metadata is failed",
|
||||
)
|
||||
}
|
||||
|
||||
authHeader, ok := md["authorization"]
|
||||
authHeader, ok := meta["authorization"]
|
||||
if !ok {
|
||||
log.Error().Caller().Str("client_address", p.Addr.String()).Msg("Authorization token is not supplied")
|
||||
return ctx, status.Errorf(codes.Unauthenticated, "Authorization token is not supplied")
|
||||
log.Error().
|
||||
Caller().
|
||||
Str("client_address", client.Addr.String()).
|
||||
Msg("Authorization token is not supplied")
|
||||
|
||||
return ctx, status.Errorf(
|
||||
codes.Unauthenticated,
|
||||
"Authorization token is not supplied",
|
||||
)
|
||||
}
|
||||
|
||||
token := authHeader[0]
|
||||
|
||||
if !strings.HasPrefix(token, AUTH_PREFIX) {
|
||||
if !strings.HasPrefix(token, AuthPrefix) {
|
||||
log.Error().
|
||||
Caller().
|
||||
Str("client_address", p.Addr.String()).
|
||||
Str("client_address", client.Addr.String()).
|
||||
Msg(`missing "Bearer " prefix in "Authorization" header`)
|
||||
return ctx, status.Error(codes.Unauthenticated, `missing "Bearer " prefix in "Authorization" header`)
|
||||
|
||||
return ctx, status.Error(
|
||||
codes.Unauthenticated,
|
||||
`missing "Bearer " prefix in "Authorization" header`,
|
||||
)
|
||||
}
|
||||
|
||||
// TODO(kradalby): Implement API key backend:
|
||||
|
@ -307,35 +347,38 @@ func (h *Headscale) grpcAuthenticationInterceptor(ctx context.Context,
|
|||
// Currently all other than localhost traffic is unauthorized, this is intentional to allow
|
||||
// us to make use of gRPC for our CLI, but not having to implement any of the remote capabilities
|
||||
// and API key auth
|
||||
return ctx, status.Error(codes.Unauthenticated, "Authentication is not implemented yet")
|
||||
return ctx, status.Error(
|
||||
codes.Unauthenticated,
|
||||
"Authentication is not implemented yet",
|
||||
)
|
||||
|
||||
//if strings.TrimPrefix(token, AUTH_PREFIX) != a.Token {
|
||||
// log.Error().Caller().Str("client_address", p.Addr.String()).Msg("invalid token")
|
||||
// return ctx, status.Error(codes.Unauthenticated, "invalid token")
|
||||
//}
|
||||
// if strings.TrimPrefix(token, AUTH_PREFIX) != a.Token {
|
||||
// log.Error().Caller().Str("client_address", p.Addr.String()).Msg("invalid token")
|
||||
// return ctx, status.Error(codes.Unauthenticated, "invalid token")
|
||||
// }
|
||||
|
||||
// return handler(ctx, req)
|
||||
}
|
||||
|
||||
func (h *Headscale) httpAuthenticationMiddleware(c *gin.Context) {
|
||||
func (h *Headscale) httpAuthenticationMiddleware(ctx *gin.Context) {
|
||||
log.Trace().
|
||||
Caller().
|
||||
Str("client_address", c.ClientIP()).
|
||||
Str("client_address", ctx.ClientIP()).
|
||||
Msg("HTTP authentication invoked")
|
||||
|
||||
authHeader := c.GetHeader("authorization")
|
||||
authHeader := ctx.GetHeader("authorization")
|
||||
|
||||
if !strings.HasPrefix(authHeader, AUTH_PREFIX) {
|
||||
if !strings.HasPrefix(authHeader, AuthPrefix) {
|
||||
log.Error().
|
||||
Caller().
|
||||
Str("client_address", c.ClientIP()).
|
||||
Str("client_address", ctx.ClientIP()).
|
||||
Msg(`missing "Bearer " prefix in "Authorization" header`)
|
||||
c.AbortWithStatus(http.StatusUnauthorized)
|
||||
ctx.AbortWithStatus(http.StatusUnauthorized)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
c.AbortWithStatus(http.StatusUnauthorized)
|
||||
ctx.AbortWithStatus(http.StatusUnauthorized)
|
||||
|
||||
// TODO(kradalby): Implement API key backend
|
||||
// Currently all traffic is unauthorized, this is intentional to allow
|
||||
|
@ -359,6 +402,7 @@ func (h *Headscale) ensureUnixSocketIsAbsent() error {
|
|||
if _, err := os.Stat(h.cfg.UnixSocket); errors.Is(err, os.ErrNotExist) {
|
||||
return nil
|
||||
}
|
||||
|
||||
return os.Remove(h.cfg.UnixSocket)
|
||||
}
|
||||
|
||||
|
@ -401,14 +445,17 @@ func (h *Headscale) Serve() error {
|
|||
|
||||
// Create the cmux object that will multiplex 2 protocols on the same port.
|
||||
// The two following listeners will be served on the same port below gracefully.
|
||||
m := cmux.New(networkListener)
|
||||
networkMutex := cmux.New(networkListener)
|
||||
// Match gRPC requests here
|
||||
grpcListener := m.MatchWithWriters(
|
||||
grpcListener := networkMutex.MatchWithWriters(
|
||||
cmux.HTTP2MatchHeaderFieldSendSettings("content-type", "application/grpc"),
|
||||
cmux.HTTP2MatchHeaderFieldSendSettings("content-type", "application/grpc+proto"),
|
||||
cmux.HTTP2MatchHeaderFieldSendSettings(
|
||||
"content-type",
|
||||
"application/grpc+proto",
|
||||
),
|
||||
)
|
||||
// Otherwise match regular http requests.
|
||||
httpListener := m.Match(cmux.Any())
|
||||
httpListener := networkMutex.Match(cmux.Any())
|
||||
|
||||
grpcGatewayMux := runtime.NewServeMux()
|
||||
|
||||
|
@ -431,30 +478,33 @@ func (h *Headscale) Serve() error {
|
|||
return err
|
||||
}
|
||||
|
||||
r := gin.Default()
|
||||
router := gin.Default()
|
||||
|
||||
p := ginprometheus.NewPrometheus("gin")
|
||||
p.Use(r)
|
||||
prometheus := ginprometheus.NewPrometheus("gin")
|
||||
prometheus.Use(router)
|
||||
|
||||
r.GET("/health", func(c *gin.Context) { c.JSON(http.StatusOK, gin.H{"healthy": "ok"}) })
|
||||
r.GET("/key", h.KeyHandler)
|
||||
r.GET("/register", h.RegisterWebAPI)
|
||||
r.POST("/machine/:id/map", h.PollNetMapHandler)
|
||||
r.POST("/machine/:id", h.RegistrationHandler)
|
||||
r.GET("/oidc/register/:mkey", h.RegisterOIDC)
|
||||
r.GET("/oidc/callback", h.OIDCCallback)
|
||||
r.GET("/apple", h.AppleMobileConfig)
|
||||
r.GET("/apple/:platform", h.ApplePlatformConfig)
|
||||
r.GET("/swagger", SwaggerUI)
|
||||
r.GET("/swagger/v1/openapiv2.json", SwaggerAPIv1)
|
||||
router.GET(
|
||||
"/health",
|
||||
func(c *gin.Context) { c.JSON(http.StatusOK, gin.H{"healthy": "ok"}) },
|
||||
)
|
||||
router.GET("/key", h.KeyHandler)
|
||||
router.GET("/register", h.RegisterWebAPI)
|
||||
router.POST("/machine/:id/map", h.PollNetMapHandler)
|
||||
router.POST("/machine/:id", h.RegistrationHandler)
|
||||
router.GET("/oidc/register/:mkey", h.RegisterOIDC)
|
||||
router.GET("/oidc/callback", h.OIDCCallback)
|
||||
router.GET("/apple", h.AppleMobileConfig)
|
||||
router.GET("/apple/:platform", h.ApplePlatformConfig)
|
||||
router.GET("/swagger", SwaggerUI)
|
||||
router.GET("/swagger/v1/openapiv2.json", SwaggerAPIv1)
|
||||
|
||||
api := r.Group("/api")
|
||||
api := router.Group("/api")
|
||||
api.Use(h.httpAuthenticationMiddleware)
|
||||
{
|
||||
api.Any("/v1/*any", gin.WrapF(grpcGatewayMux.ServeHTTP))
|
||||
}
|
||||
|
||||
r.NoRoute(stdoutHandler)
|
||||
router.NoRoute(stdoutHandler)
|
||||
|
||||
// Fetch an initial DERP Map before we start serving
|
||||
h.DERPMap = GetDERPMap(h.cfg.DERP)
|
||||
|
@ -466,14 +516,13 @@ func (h *Headscale) Serve() error {
|
|||
}
|
||||
|
||||
// I HATE THIS
|
||||
updateMillisecondsWait := int64(5000)
|
||||
go h.watchForKVUpdates(updateMillisecondsWait)
|
||||
go h.expireEphemeralNodes(updateMillisecondsWait)
|
||||
go h.watchForKVUpdates(updateInterval)
|
||||
go h.expireEphemeralNodes(updateInterval)
|
||||
|
||||
httpServer := &http.Server{
|
||||
Addr: h.cfg.Addr,
|
||||
Handler: r,
|
||||
ReadTimeout: 30 * time.Second,
|
||||
Handler: router,
|
||||
ReadTimeout: HTTPReadTimeout,
|
||||
// Go does not handle timeouts in HTTP very well, and there is
|
||||
// no good way to handle streaming timeouts, therefore we need to
|
||||
// keep this at unlimited and be careful to clean up connections
|
||||
|
@ -519,36 +568,40 @@ func (h *Headscale) Serve() error {
|
|||
reflection.Register(grpcServer)
|
||||
reflection.Register(grpcSocket)
|
||||
|
||||
g := new(errgroup.Group)
|
||||
errorGroup := new(errgroup.Group)
|
||||
|
||||
g.Go(func() error { return grpcSocket.Serve(socketListener) })
|
||||
errorGroup.Go(func() error { return grpcSocket.Serve(socketListener) })
|
||||
|
||||
// TODO(kradalby): Verify if we need the same TLS setup for gRPC as HTTP
|
||||
g.Go(func() error { return grpcServer.Serve(grpcListener) })
|
||||
errorGroup.Go(func() error { return grpcServer.Serve(grpcListener) })
|
||||
|
||||
if tlsConfig != nil {
|
||||
g.Go(func() error {
|
||||
errorGroup.Go(func() error {
|
||||
tlsl := tls.NewListener(httpListener, tlsConfig)
|
||||
|
||||
return httpServer.Serve(tlsl)
|
||||
})
|
||||
} else {
|
||||
g.Go(func() error { return httpServer.Serve(httpListener) })
|
||||
errorGroup.Go(func() error { return httpServer.Serve(httpListener) })
|
||||
}
|
||||
|
||||
g.Go(func() error { return m.Serve() })
|
||||
errorGroup.Go(func() error { return networkMutex.Serve() })
|
||||
|
||||
log.Info().Msgf("listening and serving (multiplexed HTTP and gRPC) on: %s", h.cfg.Addr)
|
||||
log.Info().
|
||||
Msgf("listening and serving (multiplexed HTTP and gRPC) on: %s", h.cfg.Addr)
|
||||
|
||||
return g.Wait()
|
||||
return errorGroup.Wait()
|
||||
}
|
||||
|
||||
func (h *Headscale) getTLSSettings() (*tls.Config, error) {
|
||||
var err error
|
||||
if h.cfg.TLSLetsEncryptHostname != "" {
|
||||
if !strings.HasPrefix(h.cfg.ServerURL, "https://") {
|
||||
log.Warn().Msg("Listening with TLS but ServerURL does not start with https://")
|
||||
log.Warn().
|
||||
Msg("Listening with TLS but ServerURL does not start with https://")
|
||||
}
|
||||
|
||||
m := autocert.Manager{
|
||||
certManager := autocert.Manager{
|
||||
Prompt: autocert.AcceptTOS,
|
||||
HostPolicy: autocert.HostWhitelist(h.cfg.TLSLetsEncryptHostname),
|
||||
Cache: autocert.DirCache(h.cfg.TLSLetsEncryptCacheDir),
|
||||
|
@ -558,40 +611,44 @@ func (h *Headscale) getTLSSettings() (*tls.Config, error) {
|
|||
Email: h.cfg.ACMEEmail,
|
||||
}
|
||||
|
||||
if h.cfg.TLSLetsEncryptChallengeType == "TLS-ALPN-01" {
|
||||
switch h.cfg.TLSLetsEncryptChallengeType {
|
||||
case "TLS-ALPN-01":
|
||||
// Configuration via autocert with TLS-ALPN-01 (https://tools.ietf.org/html/rfc8737)
|
||||
// The RFC requires that the validation is done on port 443; in other words, headscale
|
||||
// must be reachable on port 443.
|
||||
return m.TLSConfig(), nil
|
||||
} else if h.cfg.TLSLetsEncryptChallengeType == "HTTP-01" {
|
||||
return certManager.TLSConfig(), nil
|
||||
|
||||
case "HTTP-01":
|
||||
// Configuration via autocert with HTTP-01. This requires listening on
|
||||
// port 80 for the certificate validation in addition to the headscale
|
||||
// service, which can be configured to run on any other port.
|
||||
go func() {
|
||||
log.Fatal().
|
||||
Err(http.ListenAndServe(h.cfg.TLSLetsEncryptListen, m.HTTPHandler(http.HandlerFunc(h.redirect)))).
|
||||
Err(http.ListenAndServe(h.cfg.TLSLetsEncryptListen, certManager.HTTPHandler(http.HandlerFunc(h.redirect)))).
|
||||
Msg("failed to set up a HTTP server")
|
||||
}()
|
||||
|
||||
return m.TLSConfig(), nil
|
||||
} else {
|
||||
return nil, errors.New("unknown value for TLSLetsEncryptChallengeType")
|
||||
return certManager.TLSConfig(), nil
|
||||
|
||||
default:
|
||||
return nil, errUnsupportedLetsEncryptChallengeType
|
||||
}
|
||||
} else if h.cfg.TLSCertPath == "" {
|
||||
if !strings.HasPrefix(h.cfg.ServerURL, "http://") {
|
||||
log.Warn().Msg("Listening without TLS but ServerURL does not start with http://")
|
||||
}
|
||||
|
||||
return nil, nil
|
||||
return nil, err
|
||||
} else {
|
||||
if !strings.HasPrefix(h.cfg.ServerURL, "https://") {
|
||||
log.Warn().Msg("Listening with TLS but ServerURL does not start with https://")
|
||||
}
|
||||
var err error
|
||||
tlsConfig := &tls.Config{}
|
||||
tlsConfig.ClientAuth = tls.RequireAnyClientCert
|
||||
tlsConfig.NextProtos = []string{"http/1.1"}
|
||||
tlsConfig.Certificates = make([]tls.Certificate, 1)
|
||||
tlsConfig := &tls.Config{
|
||||
ClientAuth: tls.RequireAnyClientCert,
|
||||
NextProtos: []string{"http/1.1"},
|
||||
Certificates: make([]tls.Certificate, 1),
|
||||
MinVersion: tls.VersionTLS12,
|
||||
}
|
||||
tlsConfig.Certificates[0], err = tls.LoadX509KeyPair(h.cfg.TLSCertPath, h.cfg.TLSKeyPath)
|
||||
|
||||
return tlsConfig, err
|
||||
|
@ -628,13 +685,13 @@ func (h *Headscale) getLastStateChange(namespaces ...string) time.Time {
|
|||
}
|
||||
}
|
||||
|
||||
func stdoutHandler(c *gin.Context) {
|
||||
b, _ := io.ReadAll(c.Request.Body)
|
||||
func stdoutHandler(ctx *gin.Context) {
|
||||
body, _ := io.ReadAll(ctx.Request.Body)
|
||||
|
||||
log.Trace().
|
||||
Interface("header", c.Request.Header).
|
||||
Interface("proto", c.Request.Proto).
|
||||
Interface("url", c.Request.URL).
|
||||
Bytes("body", b).
|
||||
Interface("header", ctx.Request.Header).
|
||||
Interface("proto", ctx.Request.Proto).
|
||||
Interface("url", ctx.Request.URL).
|
||||
Bytes("body", body).
|
||||
Msg("Request did not match")
|
||||
}
|
||||
|
|
14
app_test.go
14
app_test.go
|
@ -17,8 +17,10 @@ var _ = check.Suite(&Suite{})
|
|||
|
||||
type Suite struct{}
|
||||
|
||||
var tmpDir string
|
||||
var h Headscale
|
||||
var (
|
||||
tmpDir string
|
||||
app Headscale
|
||||
)
|
||||
|
||||
func (s *Suite) SetUpTest(c *check.C) {
|
||||
s.ResetDB(c)
|
||||
|
@ -41,18 +43,18 @@ func (s *Suite) ResetDB(c *check.C) {
|
|||
IPPrefix: netaddr.MustParseIPPrefix("10.27.0.0/23"),
|
||||
}
|
||||
|
||||
h = Headscale{
|
||||
app = Headscale{
|
||||
cfg: cfg,
|
||||
dbType: "sqlite3",
|
||||
dbString: tmpDir + "/headscale_test.db",
|
||||
}
|
||||
err = h.initDB()
|
||||
err = app.initDB()
|
||||
if err != nil {
|
||||
c.Fatal(err)
|
||||
}
|
||||
db, err := h.openDB()
|
||||
db, err := app.openDB()
|
||||
if err != nil {
|
||||
c.Fatal(err)
|
||||
}
|
||||
h.db = db
|
||||
app.db = db
|
||||
}
|
||||
|
|
|
@ -5,16 +5,15 @@ import (
|
|||
"net/http"
|
||||
"text/template"
|
||||
|
||||
"github.com/rs/zerolog/log"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/gofrs/uuid"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
// AppleMobileConfig shows a simple message in the browser to point to the CLI
|
||||
// Listens in /register
|
||||
func (h *Headscale) AppleMobileConfig(c *gin.Context) {
|
||||
t := template.Must(template.New("apple").Parse(`
|
||||
// Listens in /register.
|
||||
func (h *Headscale) AppleMobileConfig(ctx *gin.Context) {
|
||||
appleTemplate := template.Must(template.New("apple").Parse(`
|
||||
<html>
|
||||
<body>
|
||||
<h1>Apple configuration profiles</h1>
|
||||
|
@ -56,7 +55,7 @@ func (h *Headscale) AppleMobileConfig(c *gin.Context) {
|
|||
|
||||
<p>Or</p>
|
||||
<p>Use your terminal to configure the default setting for Tailscale by issuing:</p>
|
||||
<code>defaults write io.tailscale.ipn.macos ControlURL {{.Url}}</code>
|
||||
<code>defaults write io.tailscale.ipn.macos ControlURL {{.URL}}</code>
|
||||
|
||||
<p>Restart Tailscale.app and log in.</p>
|
||||
|
||||
|
@ -64,24 +63,29 @@ func (h *Headscale) AppleMobileConfig(c *gin.Context) {
|
|||
</html>`))
|
||||
|
||||
config := map[string]interface{}{
|
||||
"Url": h.cfg.ServerURL,
|
||||
"URL": h.cfg.ServerURL,
|
||||
}
|
||||
|
||||
var payload bytes.Buffer
|
||||
if err := t.Execute(&payload, config); err != nil {
|
||||
if err := appleTemplate.Execute(&payload, config); err != nil {
|
||||
log.Error().
|
||||
Str("handler", "AppleMobileConfig").
|
||||
Err(err).
|
||||
Msg("Could not render Apple index template")
|
||||
c.Data(http.StatusInternalServerError, "text/html; charset=utf-8", []byte("Could not render Apple index template"))
|
||||
ctx.Data(
|
||||
http.StatusInternalServerError,
|
||||
"text/html; charset=utf-8",
|
||||
[]byte("Could not render Apple index template"),
|
||||
)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
c.Data(http.StatusOK, "text/html; charset=utf-8", payload.Bytes())
|
||||
ctx.Data(http.StatusOK, "text/html; charset=utf-8", payload.Bytes())
|
||||
}
|
||||
|
||||
func (h *Headscale) ApplePlatformConfig(c *gin.Context) {
|
||||
platform := c.Param("platform")
|
||||
func (h *Headscale) ApplePlatformConfig(ctx *gin.Context) {
|
||||
platform := ctx.Param("platform")
|
||||
|
||||
id, err := uuid.NewV4()
|
||||
if err != nil {
|
||||
|
@ -89,23 +93,33 @@ func (h *Headscale) ApplePlatformConfig(c *gin.Context) {
|
|||
Str("handler", "ApplePlatformConfig").
|
||||
Err(err).
|
||||
Msg("Failed not create UUID")
|
||||
c.Data(http.StatusInternalServerError, "text/html; charset=utf-8", []byte("Failed to create UUID"))
|
||||
ctx.Data(
|
||||
http.StatusInternalServerError,
|
||||
"text/html; charset=utf-8",
|
||||
[]byte("Failed to create UUID"),
|
||||
)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
contentId, err := uuid.NewV4()
|
||||
contentID, err := uuid.NewV4()
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Str("handler", "ApplePlatformConfig").
|
||||
Err(err).
|
||||
Msg("Failed not create UUID")
|
||||
c.Data(http.StatusInternalServerError, "text/html; charset=utf-8", []byte("Failed to create UUID"))
|
||||
ctx.Data(
|
||||
http.StatusInternalServerError,
|
||||
"text/html; charset=utf-8",
|
||||
[]byte("Failed to create UUID"),
|
||||
)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
platformConfig := AppleMobilePlatformConfig{
|
||||
UUID: contentId,
|
||||
Url: h.cfg.ServerURL,
|
||||
UUID: contentID,
|
||||
URL: h.cfg.ServerURL,
|
||||
}
|
||||
|
||||
var payload bytes.Buffer
|
||||
|
@ -117,7 +131,12 @@ func (h *Headscale) ApplePlatformConfig(c *gin.Context) {
|
|||
Str("handler", "ApplePlatformConfig").
|
||||
Err(err).
|
||||
Msg("Could not render Apple macOS template")
|
||||
c.Data(http.StatusInternalServerError, "text/html; charset=utf-8", []byte("Could not render Apple macOS template"))
|
||||
ctx.Data(
|
||||
http.StatusInternalServerError,
|
||||
"text/html; charset=utf-8",
|
||||
[]byte("Could not render Apple macOS template"),
|
||||
)
|
||||
|
||||
return
|
||||
}
|
||||
case "ios":
|
||||
|
@ -126,17 +145,27 @@ func (h *Headscale) ApplePlatformConfig(c *gin.Context) {
|
|||
Str("handler", "ApplePlatformConfig").
|
||||
Err(err).
|
||||
Msg("Could not render Apple iOS template")
|
||||
c.Data(http.StatusInternalServerError, "text/html; charset=utf-8", []byte("Could not render Apple iOS template"))
|
||||
ctx.Data(
|
||||
http.StatusInternalServerError,
|
||||
"text/html; charset=utf-8",
|
||||
[]byte("Could not render Apple iOS template"),
|
||||
)
|
||||
|
||||
return
|
||||
}
|
||||
default:
|
||||
c.Data(http.StatusOK, "text/html; charset=utf-8", []byte("Invalid platform, only ios and macos is supported"))
|
||||
ctx.Data(
|
||||
http.StatusOK,
|
||||
"text/html; charset=utf-8",
|
||||
[]byte("Invalid platform, only ios and macos is supported"),
|
||||
)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
config := AppleMobileConfig{
|
||||
UUID: id,
|
||||
Url: h.cfg.ServerURL,
|
||||
URL: h.cfg.ServerURL,
|
||||
Payload: payload.String(),
|
||||
}
|
||||
|
||||
|
@ -146,25 +175,35 @@ func (h *Headscale) ApplePlatformConfig(c *gin.Context) {
|
|||
Str("handler", "ApplePlatformConfig").
|
||||
Err(err).
|
||||
Msg("Could not render Apple platform template")
|
||||
c.Data(http.StatusInternalServerError, "text/html; charset=utf-8", []byte("Could not render Apple platform template"))
|
||||
ctx.Data(
|
||||
http.StatusInternalServerError,
|
||||
"text/html; charset=utf-8",
|
||||
[]byte("Could not render Apple platform template"),
|
||||
)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
c.Data(http.StatusOK, "application/x-apple-aspen-config; charset=utf-8", content.Bytes())
|
||||
ctx.Data(
|
||||
http.StatusOK,
|
||||
"application/x-apple-aspen-config; charset=utf-8",
|
||||
content.Bytes(),
|
||||
)
|
||||
}
|
||||
|
||||
type AppleMobileConfig struct {
|
||||
UUID uuid.UUID
|
||||
Url string
|
||||
URL string
|
||||
Payload string
|
||||
}
|
||||
|
||||
type AppleMobilePlatformConfig struct {
|
||||
UUID uuid.UUID
|
||||
Url string
|
||||
URL string
|
||||
}
|
||||
|
||||
var commonTemplate = template.Must(template.New("mobileconfig").Parse(`<?xml version="1.0" encoding="UTF-8"?>
|
||||
var commonTemplate = template.Must(
|
||||
template.New("mobileconfig").Parse(`<?xml version="1.0" encoding="UTF-8"?>
|
||||
<!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd">
|
||||
<plist version="1.0">
|
||||
<dict>
|
||||
|
@ -173,7 +212,7 @@ var commonTemplate = template.Must(template.New("mobileconfig").Parse(`<?xml ver
|
|||
<key>PayloadDisplayName</key>
|
||||
<string>Headscale</string>
|
||||
<key>PayloadDescription</key>
|
||||
<string>Configure Tailscale login server to: {{.Url}}</string>
|
||||
<string>Configure Tailscale login server to: {{.URL}}</string>
|
||||
<key>PayloadIdentifier</key>
|
||||
<string>com.github.juanfont.headscale</string>
|
||||
<key>PayloadRemovalDisallowed</key>
|
||||
|
@ -187,7 +226,8 @@ var commonTemplate = template.Must(template.New("mobileconfig").Parse(`<?xml ver
|
|||
{{.Payload}}
|
||||
</array>
|
||||
</dict>
|
||||
</plist>`))
|
||||
</plist>`),
|
||||
)
|
||||
|
||||
var iosTemplate = template.Must(template.New("iosTemplate").Parse(`
|
||||
<dict>
|
||||
|
@ -203,7 +243,7 @@ var iosTemplate = template.Must(template.New("iosTemplate").Parse(`
|
|||
<true/>
|
||||
|
||||
<key>ControlURL</key>
|
||||
<string>{{.Url}}</string>
|
||||
<string>{{.URL}}</string>
|
||||
</dict>
|
||||
`))
|
||||
|
||||
|
@ -221,6 +261,6 @@ var macosTemplate = template.Must(template.New("macosTemplate").Parse(`
|
|||
<true/>
|
||||
|
||||
<key>ControlURL</key>
|
||||
<string>{{.Url}}</string>
|
||||
<string>{{.URL}}</string>
|
||||
</dict>
|
||||
`))
|
||||
|
|
19
cli_test.go
19
cli_test.go
|
@ -7,31 +7,34 @@ import (
|
|||
)
|
||||
|
||||
func (s *Suite) TestRegisterMachine(c *check.C) {
|
||||
n, err := h.CreateNamespace("test")
|
||||
namespace, err := app.CreateNamespace("test")
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
now := time.Now().UTC()
|
||||
|
||||
m := Machine{
|
||||
machine := Machine{
|
||||
ID: 0,
|
||||
MachineKey: "8ce002a935f8c394e55e78fbbb410576575ff8ec5cfa2e627e4b807f1be15b0e",
|
||||
NodeKey: "bar",
|
||||
DiscoKey: "faa",
|
||||
Name: "testmachine",
|
||||
NamespaceID: n.ID,
|
||||
NamespaceID: namespace.ID,
|
||||
IPAddress: "10.0.0.1",
|
||||
Expiry: &now,
|
||||
RequestedExpiry: &now,
|
||||
}
|
||||
h.db.Save(&m)
|
||||
app.db.Save(&machine)
|
||||
|
||||
_, err = h.GetMachine("test", "testmachine")
|
||||
_, err = app.GetMachine("test", "testmachine")
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
m2, err := h.RegisterMachine("8ce002a935f8c394e55e78fbbb410576575ff8ec5cfa2e627e4b807f1be15b0e", n.Name)
|
||||
machineAfterRegistering, err := app.RegisterMachine(
|
||||
"8ce002a935f8c394e55e78fbbb410576575ff8ec5cfa2e627e4b807f1be15b0e",
|
||||
namespace.Name,
|
||||
)
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(m2.Registered, check.Equals, true)
|
||||
c.Assert(machineAfterRegistering.Registered, check.Equals, true)
|
||||
|
||||
_, err = m2.GetHostInfo()
|
||||
_, err = machineAfterRegistering.GetHostInfo()
|
||||
c.Assert(err, check.IsNil)
|
||||
}
|
||||
|
|
|
@ -27,7 +27,8 @@ func init() {
|
|||
if err != nil {
|
||||
log.Fatal().Err(err).Msg("")
|
||||
}
|
||||
createNodeCmd.Flags().StringSliceP("route", "r", []string{}, "List (or repeated flags) of routes to advertise")
|
||||
createNodeCmd.Flags().
|
||||
StringSliceP("route", "r", []string{}, "List (or repeated flags) of routes to advertise")
|
||||
|
||||
debugCmd.AddCommand(createNodeCmd)
|
||||
}
|
||||
|
@ -47,6 +48,7 @@ var createNodeCmd = &cobra.Command{
|
|||
namespace, err := cmd.Flags().GetString("namespace")
|
||||
if err != nil {
|
||||
ErrorOutput(err, fmt.Sprintf("Error getting namespace: %s", err), output)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -56,19 +58,34 @@ var createNodeCmd = &cobra.Command{
|
|||
|
||||
name, err := cmd.Flags().GetString("name")
|
||||
if err != nil {
|
||||
ErrorOutput(err, fmt.Sprintf("Error getting node from flag: %s", err), output)
|
||||
ErrorOutput(
|
||||
err,
|
||||
fmt.Sprintf("Error getting node from flag: %s", err),
|
||||
output,
|
||||
)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
machineKey, err := cmd.Flags().GetString("key")
|
||||
if err != nil {
|
||||
ErrorOutput(err, fmt.Sprintf("Error getting key from flag: %s", err), output)
|
||||
ErrorOutput(
|
||||
err,
|
||||
fmt.Sprintf("Error getting key from flag: %s", err),
|
||||
output,
|
||||
)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
routes, err := cmd.Flags().GetStringSlice("route")
|
||||
if err != nil {
|
||||
ErrorOutput(err, fmt.Sprintf("Error getting routes from flag: %s", err), output)
|
||||
ErrorOutput(
|
||||
err,
|
||||
fmt.Sprintf("Error getting routes from flag: %s", err),
|
||||
output,
|
||||
)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -81,7 +98,12 @@ var createNodeCmd = &cobra.Command{
|
|||
|
||||
response, err := client.DebugCreateMachine(ctx, request)
|
||||
if err != nil {
|
||||
ErrorOutput(err, fmt.Sprintf("Cannot create machine: %s", status.Convert(err).Message()), output)
|
||||
ErrorOutput(
|
||||
err,
|
||||
fmt.Sprintf("Cannot create machine: %s", status.Convert(err).Message()),
|
||||
output,
|
||||
)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
|
|
|
@ -4,6 +4,7 @@ import (
|
|||
"fmt"
|
||||
|
||||
survey "github.com/AlecAivazis/survey/v2"
|
||||
"github.com/juanfont/headscale"
|
||||
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
|
||||
"github.com/pterm/pterm"
|
||||
"github.com/rs/zerolog/log"
|
||||
|
@ -19,6 +20,10 @@ func init() {
|
|||
namespaceCmd.AddCommand(renameNamespaceCmd)
|
||||
}
|
||||
|
||||
const (
|
||||
errMissingParameter = headscale.Error("missing parameters")
|
||||
)
|
||||
|
||||
var namespaceCmd = &cobra.Command{
|
||||
Use: "namespaces",
|
||||
Short: "Manage the namespaces of Headscale",
|
||||
|
@ -29,8 +34,9 @@ var createNamespaceCmd = &cobra.Command{
|
|||
Short: "Creates a new namespace",
|
||||
Args: func(cmd *cobra.Command, args []string) error {
|
||||
if len(args) < 1 {
|
||||
return fmt.Errorf("Missing parameters")
|
||||
return errMissingParameter
|
||||
}
|
||||
|
||||
return nil
|
||||
},
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
|
@ -49,7 +55,15 @@ var createNamespaceCmd = &cobra.Command{
|
|||
log.Trace().Interface("request", request).Msg("Sending CreateNamespace request")
|
||||
response, err := client.CreateNamespace(ctx, request)
|
||||
if err != nil {
|
||||
ErrorOutput(err, fmt.Sprintf("Cannot create namespace: %s", status.Convert(err).Message()), output)
|
||||
ErrorOutput(
|
||||
err,
|
||||
fmt.Sprintf(
|
||||
"Cannot create namespace: %s",
|
||||
status.Convert(err).Message(),
|
||||
),
|
||||
output,
|
||||
)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -62,8 +76,9 @@ var destroyNamespaceCmd = &cobra.Command{
|
|||
Short: "Destroys a namespace",
|
||||
Args: func(cmd *cobra.Command, args []string) error {
|
||||
if len(args) < 1 {
|
||||
return fmt.Errorf("Missing parameters")
|
||||
return errMissingParameter
|
||||
}
|
||||
|
||||
return nil
|
||||
},
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
|
@ -81,7 +96,12 @@ var destroyNamespaceCmd = &cobra.Command{
|
|||
|
||||
_, err := client.GetNamespace(ctx, request)
|
||||
if err != nil {
|
||||
ErrorOutput(err, fmt.Sprintf("Error: %s", status.Convert(err).Message()), output)
|
||||
ErrorOutput(
|
||||
err,
|
||||
fmt.Sprintf("Error: %s", status.Convert(err).Message()),
|
||||
output,
|
||||
)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -89,7 +109,10 @@ var destroyNamespaceCmd = &cobra.Command{
|
|||
force, _ := cmd.Flags().GetBool("force")
|
||||
if !force {
|
||||
prompt := &survey.Confirm{
|
||||
Message: fmt.Sprintf("Do you want to remove the namespace '%s' and any associated preauthkeys?", namespaceName),
|
||||
Message: fmt.Sprintf(
|
||||
"Do you want to remove the namespace '%s' and any associated preauthkeys?",
|
||||
namespaceName,
|
||||
),
|
||||
}
|
||||
err := survey.AskOne(prompt, &confirm)
|
||||
if err != nil {
|
||||
|
@ -102,7 +125,15 @@ var destroyNamespaceCmd = &cobra.Command{
|
|||
|
||||
response, err := client.DeleteNamespace(ctx, request)
|
||||
if err != nil {
|
||||
ErrorOutput(err, fmt.Sprintf("Cannot destroy namespace: %s", status.Convert(err).Message()), output)
|
||||
ErrorOutput(
|
||||
err,
|
||||
fmt.Sprintf(
|
||||
"Cannot destroy namespace: %s",
|
||||
status.Convert(err).Message(),
|
||||
),
|
||||
output,
|
||||
)
|
||||
|
||||
return
|
||||
}
|
||||
SuccessOutput(response, "Namespace destroyed", output)
|
||||
|
@ -126,19 +157,25 @@ var listNamespacesCmd = &cobra.Command{
|
|||
|
||||
response, err := client.ListNamespaces(ctx, request)
|
||||
if err != nil {
|
||||
ErrorOutput(err, fmt.Sprintf("Cannot get namespaces: %s", status.Convert(err).Message()), output)
|
||||
ErrorOutput(
|
||||
err,
|
||||
fmt.Sprintf("Cannot get namespaces: %s", status.Convert(err).Message()),
|
||||
output,
|
||||
)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if output != "" {
|
||||
SuccessOutput(response.Namespaces, "", output)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
d := pterm.TableData{{"ID", "Name", "Created"}}
|
||||
tableData := pterm.TableData{{"ID", "Name", "Created"}}
|
||||
for _, namespace := range response.GetNamespaces() {
|
||||
d = append(
|
||||
d,
|
||||
tableData = append(
|
||||
tableData,
|
||||
[]string{
|
||||
namespace.GetId(),
|
||||
namespace.GetName(),
|
||||
|
@ -146,9 +183,14 @@ var listNamespacesCmd = &cobra.Command{
|
|||
},
|
||||
)
|
||||
}
|
||||
err = pterm.DefaultTable.WithHasHeader().WithData(d).Render()
|
||||
err = pterm.DefaultTable.WithHasHeader().WithData(tableData).Render()
|
||||
if err != nil {
|
||||
ErrorOutput(err, fmt.Sprintf("Failed to render pterm table: %s", err), output)
|
||||
ErrorOutput(
|
||||
err,
|
||||
fmt.Sprintf("Failed to render pterm table: %s", err),
|
||||
output,
|
||||
)
|
||||
|
||||
return
|
||||
}
|
||||
},
|
||||
|
@ -158,9 +200,11 @@ var renameNamespaceCmd = &cobra.Command{
|
|||
Use: "rename OLD_NAME NEW_NAME",
|
||||
Short: "Renames a namespace",
|
||||
Args: func(cmd *cobra.Command, args []string) error {
|
||||
if len(args) < 2 {
|
||||
return fmt.Errorf("Missing parameters")
|
||||
expectedArguments := 2
|
||||
if len(args) < expectedArguments {
|
||||
return errMissingParameter
|
||||
}
|
||||
|
||||
return nil
|
||||
},
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
|
@ -177,7 +221,15 @@ var renameNamespaceCmd = &cobra.Command{
|
|||
|
||||
response, err := client.RenameNamespace(ctx, request)
|
||||
if err != nil {
|
||||
ErrorOutput(err, fmt.Sprintf("Cannot rename namespace: %s", status.Convert(err).Message()), output)
|
||||
ErrorOutput(
|
||||
err,
|
||||
fmt.Sprintf(
|
||||
"Cannot rename namespace: %s",
|
||||
status.Convert(err).Message(),
|
||||
),
|
||||
output,
|
||||
)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
|
|
|
@ -7,6 +7,7 @@ import (
|
|||
"time"
|
||||
|
||||
survey "github.com/AlecAivazis/survey/v2"
|
||||
"github.com/juanfont/headscale"
|
||||
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
|
||||
"github.com/pterm/pterm"
|
||||
"github.com/spf13/cobra"
|
||||
|
@ -77,6 +78,7 @@ var registerNodeCmd = &cobra.Command{
|
|||
namespace, err := cmd.Flags().GetString("namespace")
|
||||
if err != nil {
|
||||
ErrorOutput(err, fmt.Sprintf("Error getting namespace: %s", err), output)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -86,7 +88,12 @@ var registerNodeCmd = &cobra.Command{
|
|||
|
||||
machineKey, err := cmd.Flags().GetString("key")
|
||||
if err != nil {
|
||||
ErrorOutput(err, fmt.Sprintf("Error getting machine key from flag: %s", err), output)
|
||||
ErrorOutput(
|
||||
err,
|
||||
fmt.Sprintf("Error getting machine key from flag: %s", err),
|
||||
output,
|
||||
)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -97,7 +104,15 @@ var registerNodeCmd = &cobra.Command{
|
|||
|
||||
response, err := client.RegisterMachine(ctx, request)
|
||||
if err != nil {
|
||||
ErrorOutput(err, fmt.Sprintf("Cannot register machine: %s\n", status.Convert(err).Message()), output)
|
||||
ErrorOutput(
|
||||
err,
|
||||
fmt.Sprintf(
|
||||
"Cannot register machine: %s\n",
|
||||
status.Convert(err).Message(),
|
||||
),
|
||||
output,
|
||||
)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -113,6 +128,7 @@ var listNodesCmd = &cobra.Command{
|
|||
namespace, err := cmd.Flags().GetString("namespace")
|
||||
if err != nil {
|
||||
ErrorOutput(err, fmt.Sprintf("Error getting namespace: %s", err), output)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -126,24 +142,36 @@ var listNodesCmd = &cobra.Command{
|
|||
|
||||
response, err := client.ListMachines(ctx, request)
|
||||
if err != nil {
|
||||
ErrorOutput(err, fmt.Sprintf("Cannot get nodes: %s", status.Convert(err).Message()), output)
|
||||
ErrorOutput(
|
||||
err,
|
||||
fmt.Sprintf("Cannot get nodes: %s", status.Convert(err).Message()),
|
||||
output,
|
||||
)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if output != "" {
|
||||
SuccessOutput(response.Machines, "", output)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
d, err := nodesToPtables(namespace, response.Machines)
|
||||
tableData, err := nodesToPtables(namespace, response.Machines)
|
||||
if err != nil {
|
||||
ErrorOutput(err, fmt.Sprintf("Error converting to table: %s", err), output)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
err = pterm.DefaultTable.WithHasHeader().WithData(d).Render()
|
||||
err = pterm.DefaultTable.WithHasHeader().WithData(tableData).Render()
|
||||
if err != nil {
|
||||
ErrorOutput(err, fmt.Sprintf("Failed to render pterm table: %s", err), output)
|
||||
ErrorOutput(
|
||||
err,
|
||||
fmt.Sprintf("Failed to render pterm table: %s", err),
|
||||
output,
|
||||
)
|
||||
|
||||
return
|
||||
}
|
||||
},
|
||||
|
@ -155,9 +183,14 @@ var deleteNodeCmd = &cobra.Command{
|
|||
Run: func(cmd *cobra.Command, args []string) {
|
||||
output, _ := cmd.Flags().GetString("output")
|
||||
|
||||
id, err := cmd.Flags().GetInt("identifier")
|
||||
identifier, err := cmd.Flags().GetInt("identifier")
|
||||
if err != nil {
|
||||
ErrorOutput(err, fmt.Sprintf("Error converting ID to integer: %s", err), output)
|
||||
ErrorOutput(
|
||||
err,
|
||||
fmt.Sprintf("Error converting ID to integer: %s", err),
|
||||
output,
|
||||
)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -166,24 +199,35 @@ var deleteNodeCmd = &cobra.Command{
|
|||
defer conn.Close()
|
||||
|
||||
getRequest := &v1.GetMachineRequest{
|
||||
MachineId: uint64(id),
|
||||
MachineId: uint64(identifier),
|
||||
}
|
||||
|
||||
getResponse, err := client.GetMachine(ctx, getRequest)
|
||||
if err != nil {
|
||||
ErrorOutput(err, fmt.Sprintf("Error getting node node: %s", status.Convert(err).Message()), output)
|
||||
ErrorOutput(
|
||||
err,
|
||||
fmt.Sprintf(
|
||||
"Error getting node node: %s",
|
||||
status.Convert(err).Message(),
|
||||
),
|
||||
output,
|
||||
)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
deleteRequest := &v1.DeleteMachineRequest{
|
||||
MachineId: uint64(id),
|
||||
MachineId: uint64(identifier),
|
||||
}
|
||||
|
||||
confirm := false
|
||||
force, _ := cmd.Flags().GetBool("force")
|
||||
if !force {
|
||||
prompt := &survey.Confirm{
|
||||
Message: fmt.Sprintf("Do you want to remove the node %s?", getResponse.GetMachine().Name),
|
||||
Message: fmt.Sprintf(
|
||||
"Do you want to remove the node %s?",
|
||||
getResponse.GetMachine().Name,
|
||||
),
|
||||
}
|
||||
err = survey.AskOne(prompt, &confirm)
|
||||
if err != nil {
|
||||
|
@ -195,13 +239,26 @@ var deleteNodeCmd = &cobra.Command{
|
|||
response, err := client.DeleteMachine(ctx, deleteRequest)
|
||||
if output != "" {
|
||||
SuccessOutput(response, "", output)
|
||||
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
ErrorOutput(err, fmt.Sprintf("Error deleting node: %s", status.Convert(err).Message()), output)
|
||||
ErrorOutput(
|
||||
err,
|
||||
fmt.Sprintf(
|
||||
"Error deleting node: %s",
|
||||
status.Convert(err).Message(),
|
||||
),
|
||||
output,
|
||||
)
|
||||
|
||||
return
|
||||
}
|
||||
SuccessOutput(map[string]string{"Result": "Node deleted"}, "Node deleted", output)
|
||||
SuccessOutput(
|
||||
map[string]string{"Result": "Node deleted"},
|
||||
"Node deleted",
|
||||
output,
|
||||
)
|
||||
} else {
|
||||
SuccessOutput(map[string]string{"Result": "Node not deleted"}, "Node not deleted", output)
|
||||
}
|
||||
|
@ -210,12 +267,12 @@ var deleteNodeCmd = &cobra.Command{
|
|||
|
||||
func sharingWorker(
|
||||
cmd *cobra.Command,
|
||||
args []string,
|
||||
) (string, *v1.Machine, *v1.Namespace, error) {
|
||||
output, _ := cmd.Flags().GetString("output")
|
||||
namespaceStr, err := cmd.Flags().GetString("namespace")
|
||||
if err != nil {
|
||||
ErrorOutput(err, fmt.Sprintf("Error getting namespace: %s", err), output)
|
||||
|
||||
return "", nil, nil, err
|
||||
}
|
||||
|
||||
|
@ -223,19 +280,25 @@ func sharingWorker(
|
|||
defer cancel()
|
||||
defer conn.Close()
|
||||
|
||||
id, err := cmd.Flags().GetInt("identifier")
|
||||
identifier, err := cmd.Flags().GetInt("identifier")
|
||||
if err != nil {
|
||||
ErrorOutput(err, fmt.Sprintf("Error converting ID to integer: %s", err), output)
|
||||
|
||||
return "", nil, nil, err
|
||||
}
|
||||
|
||||
machineRequest := &v1.GetMachineRequest{
|
||||
MachineId: uint64(id),
|
||||
MachineId: uint64(identifier),
|
||||
}
|
||||
|
||||
machineResponse, err := client.GetMachine(ctx, machineRequest)
|
||||
if err != nil {
|
||||
ErrorOutput(err, fmt.Sprintf("Error getting node node: %s", status.Convert(err).Message()), output)
|
||||
ErrorOutput(
|
||||
err,
|
||||
fmt.Sprintf("Error getting node node: %s", status.Convert(err).Message()),
|
||||
output,
|
||||
)
|
||||
|
||||
return "", nil, nil, err
|
||||
}
|
||||
|
||||
|
@ -245,7 +308,12 @@ func sharingWorker(
|
|||
|
||||
namespaceResponse, err := client.GetNamespace(ctx, namespaceRequest)
|
||||
if err != nil {
|
||||
ErrorOutput(err, fmt.Sprintf("Error getting node node: %s", status.Convert(err).Message()), output)
|
||||
ErrorOutput(
|
||||
err,
|
||||
fmt.Sprintf("Error getting node node: %s", status.Convert(err).Message()),
|
||||
output,
|
||||
)
|
||||
|
||||
return "", nil, nil, err
|
||||
}
|
||||
|
||||
|
@ -256,9 +324,14 @@ var shareMachineCmd = &cobra.Command{
|
|||
Use: "share",
|
||||
Short: "Shares a node from the current namespace to the specified one",
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
output, machine, namespace, err := sharingWorker(cmd, args)
|
||||
output, machine, namespace, err := sharingWorker(cmd)
|
||||
if err != nil {
|
||||
ErrorOutput(err, fmt.Sprintf("Failed to fetch namespace or machine: %s", err), output)
|
||||
ErrorOutput(
|
||||
err,
|
||||
fmt.Sprintf("Failed to fetch namespace or machine: %s", err),
|
||||
output,
|
||||
)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -273,7 +346,12 @@ var shareMachineCmd = &cobra.Command{
|
|||
|
||||
response, err := client.ShareMachine(ctx, request)
|
||||
if err != nil {
|
||||
ErrorOutput(err, fmt.Sprintf("Error sharing node: %s", status.Convert(err).Message()), output)
|
||||
ErrorOutput(
|
||||
err,
|
||||
fmt.Sprintf("Error sharing node: %s", status.Convert(err).Message()),
|
||||
output,
|
||||
)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -285,9 +363,14 @@ var unshareMachineCmd = &cobra.Command{
|
|||
Use: "unshare",
|
||||
Short: "Unshares a node from the specified namespace",
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
output, machine, namespace, err := sharingWorker(cmd, args)
|
||||
output, machine, namespace, err := sharingWorker(cmd)
|
||||
if err != nil {
|
||||
ErrorOutput(err, fmt.Sprintf("Failed to fetch namespace or machine: %s", err), output)
|
||||
ErrorOutput(
|
||||
err,
|
||||
fmt.Sprintf("Failed to fetch namespace or machine: %s", err),
|
||||
output,
|
||||
)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -302,7 +385,12 @@ var unshareMachineCmd = &cobra.Command{
|
|||
|
||||
response, err := client.UnshareMachine(ctx, request)
|
||||
if err != nil {
|
||||
ErrorOutput(err, fmt.Sprintf("Error unsharing node: %s", status.Convert(err).Message()), output)
|
||||
ErrorOutput(
|
||||
err,
|
||||
fmt.Sprintf("Error unsharing node: %s", status.Convert(err).Message()),
|
||||
output,
|
||||
)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -310,8 +398,22 @@ var unshareMachineCmd = &cobra.Command{
|
|||
},
|
||||
}
|
||||
|
||||
func nodesToPtables(currentNamespace string, machines []*v1.Machine) (pterm.TableData, error) {
|
||||
d := pterm.TableData{{"ID", "Name", "NodeKey", "Namespace", "IP address", "Ephemeral", "Last seen", "Online"}}
|
||||
func nodesToPtables(
|
||||
currentNamespace string,
|
||||
machines []*v1.Machine,
|
||||
) (pterm.TableData, error) {
|
||||
tableData := pterm.TableData{
|
||||
{
|
||||
"ID",
|
||||
"Name",
|
||||
"NodeKey",
|
||||
"Namespace",
|
||||
"IP address",
|
||||
"Ephemeral",
|
||||
"Last seen",
|
||||
"Online",
|
||||
},
|
||||
}
|
||||
|
||||
for _, machine := range machines {
|
||||
var ephemeral bool
|
||||
|
@ -331,7 +433,9 @@ func nodesToPtables(currentNamespace string, machines []*v1.Machine) (pterm.Tabl
|
|||
nodeKey := tailcfg.NodeKey(nKey)
|
||||
|
||||
var online string
|
||||
if lastSeen.After(time.Now().Add(-5 * time.Minute)) { // TODO: Find a better way to reliably show if online
|
||||
if lastSeen.After(
|
||||
time.Now().Add(-5 * time.Minute),
|
||||
) { // TODO: Find a better way to reliably show if online
|
||||
online = pterm.LightGreen("true")
|
||||
} else {
|
||||
online = pterm.LightRed("false")
|
||||
|
@ -344,10 +448,10 @@ func nodesToPtables(currentNamespace string, machines []*v1.Machine) (pterm.Tabl
|
|||
// Shared into this namespace
|
||||
namespace = pterm.LightYellow(machine.Namespace.Name)
|
||||
}
|
||||
d = append(
|
||||
d,
|
||||
tableData = append(
|
||||
tableData,
|
||||
[]string{
|
||||
strconv.FormatUint(machine.Id, 10),
|
||||
strconv.FormatUint(machine.Id, headscale.Base10),
|
||||
machine.Name,
|
||||
nodeKey.ShortString(),
|
||||
namespace,
|
||||
|
@ -358,5 +462,6 @@ func nodesToPtables(currentNamespace string, machines []*v1.Machine) (pterm.Tabl
|
|||
},
|
||||
)
|
||||
}
|
||||
return d, nil
|
||||
|
||||
return tableData, nil
|
||||
}
|
||||
|
|
|
@ -12,6 +12,10 @@ import (
|
|||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
)
|
||||
|
||||
const (
|
||||
DefaultPreAuthKeyExpiry = 24 * time.Hour
|
||||
)
|
||||
|
||||
func init() {
|
||||
rootCmd.AddCommand(preauthkeysCmd)
|
||||
preauthkeysCmd.PersistentFlags().StringP("namespace", "n", "", "Namespace")
|
||||
|
@ -22,10 +26,12 @@ func init() {
|
|||
preauthkeysCmd.AddCommand(listPreAuthKeys)
|
||||
preauthkeysCmd.AddCommand(createPreAuthKeyCmd)
|
||||
preauthkeysCmd.AddCommand(expirePreAuthKeyCmd)
|
||||
createPreAuthKeyCmd.PersistentFlags().Bool("reusable", false, "Make the preauthkey reusable")
|
||||
createPreAuthKeyCmd.PersistentFlags().Bool("ephemeral", false, "Preauthkey for ephemeral nodes")
|
||||
createPreAuthKeyCmd.PersistentFlags().
|
||||
Bool("reusable", false, "Make the preauthkey reusable")
|
||||
createPreAuthKeyCmd.PersistentFlags().
|
||||
Bool("ephemeral", false, "Preauthkey for ephemeral nodes")
|
||||
createPreAuthKeyCmd.Flags().
|
||||
DurationP("expiration", "e", 24*time.Hour, "Human-readable expiration of the key (30m, 24h, 365d...)")
|
||||
DurationP("expiration", "e", DefaultPreAuthKeyExpiry, "Human-readable expiration of the key (30m, 24h, 365d...)")
|
||||
}
|
||||
|
||||
var preauthkeysCmd = &cobra.Command{
|
||||
|
@ -39,9 +45,10 @@ var listPreAuthKeys = &cobra.Command{
|
|||
Run: func(cmd *cobra.Command, args []string) {
|
||||
output, _ := cmd.Flags().GetString("output")
|
||||
|
||||
n, err := cmd.Flags().GetString("namespace")
|
||||
namespace, err := cmd.Flags().GetString("namespace")
|
||||
if err != nil {
|
||||
ErrorOutput(err, fmt.Sprintf("Error getting namespace: %s", err), output)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -50,48 +57,61 @@ var listPreAuthKeys = &cobra.Command{
|
|||
defer conn.Close()
|
||||
|
||||
request := &v1.ListPreAuthKeysRequest{
|
||||
Namespace: n,
|
||||
Namespace: namespace,
|
||||
}
|
||||
|
||||
response, err := client.ListPreAuthKeys(ctx, request)
|
||||
if err != nil {
|
||||
ErrorOutput(err, fmt.Sprintf("Error getting the list of keys: %s", err), output)
|
||||
ErrorOutput(
|
||||
err,
|
||||
fmt.Sprintf("Error getting the list of keys: %s", err),
|
||||
output,
|
||||
)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if output != "" {
|
||||
SuccessOutput(response.PreAuthKeys, "", output)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
d := pterm.TableData{{"ID", "Key", "Reusable", "Ephemeral", "Used", "Expiration", "Created"}}
|
||||
for _, k := range response.PreAuthKeys {
|
||||
tableData := pterm.TableData{
|
||||
{"ID", "Key", "Reusable", "Ephemeral", "Used", "Expiration", "Created"},
|
||||
}
|
||||
for _, key := range response.PreAuthKeys {
|
||||
expiration := "-"
|
||||
if k.GetExpiration() != nil {
|
||||
expiration = k.Expiration.AsTime().Format("2006-01-02 15:04:05")
|
||||
if key.GetExpiration() != nil {
|
||||
expiration = key.Expiration.AsTime().Format("2006-01-02 15:04:05")
|
||||
}
|
||||
|
||||
var reusable string
|
||||
if k.GetEphemeral() {
|
||||
if key.GetEphemeral() {
|
||||
reusable = "N/A"
|
||||
} else {
|
||||
reusable = fmt.Sprintf("%v", k.GetReusable())
|
||||
reusable = fmt.Sprintf("%v", key.GetReusable())
|
||||
}
|
||||
|
||||
d = append(d, []string{
|
||||
k.GetId(),
|
||||
k.GetKey(),
|
||||
tableData = append(tableData, []string{
|
||||
key.GetId(),
|
||||
key.GetKey(),
|
||||
reusable,
|
||||
strconv.FormatBool(k.GetEphemeral()),
|
||||
strconv.FormatBool(k.GetUsed()),
|
||||
strconv.FormatBool(key.GetEphemeral()),
|
||||
strconv.FormatBool(key.GetUsed()),
|
||||
expiration,
|
||||
k.GetCreatedAt().AsTime().Format("2006-01-02 15:04:05"),
|
||||
key.GetCreatedAt().AsTime().Format("2006-01-02 15:04:05"),
|
||||
})
|
||||
|
||||
}
|
||||
err = pterm.DefaultTable.WithHasHeader().WithData(d).Render()
|
||||
err = pterm.DefaultTable.WithHasHeader().WithData(tableData).Render()
|
||||
if err != nil {
|
||||
ErrorOutput(err, fmt.Sprintf("Failed to render pterm table: %s", err), output)
|
||||
ErrorOutput(
|
||||
err,
|
||||
fmt.Sprintf("Failed to render pterm table: %s", err),
|
||||
output,
|
||||
)
|
||||
|
||||
return
|
||||
}
|
||||
},
|
||||
|
@ -106,6 +126,7 @@ var createPreAuthKeyCmd = &cobra.Command{
|
|||
namespace, err := cmd.Flags().GetString("namespace")
|
||||
if err != nil {
|
||||
ErrorOutput(err, fmt.Sprintf("Error getting namespace: %s", err), output)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -139,7 +160,12 @@ var createPreAuthKeyCmd = &cobra.Command{
|
|||
|
||||
response, err := client.CreatePreAuthKey(ctx, request)
|
||||
if err != nil {
|
||||
ErrorOutput(err, fmt.Sprintf("Cannot create Pre Auth Key: %s\n", err), output)
|
||||
ErrorOutput(
|
||||
err,
|
||||
fmt.Sprintf("Cannot create Pre Auth Key: %s\n", err),
|
||||
output,
|
||||
)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -152,8 +178,9 @@ var expirePreAuthKeyCmd = &cobra.Command{
|
|||
Short: "Expire a preauthkey",
|
||||
Args: func(cmd *cobra.Command, args []string) error {
|
||||
if len(args) < 1 {
|
||||
return fmt.Errorf("missing parameters")
|
||||
return errMissingParameter
|
||||
}
|
||||
|
||||
return nil
|
||||
},
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
|
@ -161,6 +188,7 @@ var expirePreAuthKeyCmd = &cobra.Command{
|
|||
namespace, err := cmd.Flags().GetString("namespace")
|
||||
if err != nil {
|
||||
ErrorOutput(err, fmt.Sprintf("Error getting namespace: %s", err), output)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -175,7 +203,12 @@ var expirePreAuthKeyCmd = &cobra.Command{
|
|||
|
||||
response, err := client.ExpirePreAuthKey(ctx, request)
|
||||
if err != nil {
|
||||
ErrorOutput(err, fmt.Sprintf("Cannot expire Pre Auth Key: %s\n", err), output)
|
||||
ErrorOutput(
|
||||
err,
|
||||
fmt.Sprintf("Cannot expire Pre Auth Key: %s\n", err),
|
||||
output,
|
||||
)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
|
|
|
@ -10,7 +10,8 @@ import (
|
|||
func init() {
|
||||
rootCmd.PersistentFlags().
|
||||
StringP("output", "o", "", "Output format. Empty for human-readable, 'json', 'json-line' or 'yaml'")
|
||||
rootCmd.PersistentFlags().Bool("force", false, "Disable prompts and forces the execution")
|
||||
rootCmd.PersistentFlags().
|
||||
Bool("force", false, "Disable prompts and forces the execution")
|
||||
}
|
||||
|
||||
var rootCmd = &cobra.Command{
|
||||
|
|
|
@ -21,7 +21,8 @@ func init() {
|
|||
}
|
||||
routesCmd.AddCommand(listRoutesCmd)
|
||||
|
||||
enableRouteCmd.Flags().StringSliceP("route", "r", []string{}, "List (or repeated flags) of routes to enable")
|
||||
enableRouteCmd.Flags().
|
||||
StringSliceP("route", "r", []string{}, "List (or repeated flags) of routes to enable")
|
||||
enableRouteCmd.Flags().Uint64P("identifier", "i", 0, "Node identifier (ID)")
|
||||
err = enableRouteCmd.MarkFlagRequired("identifier")
|
||||
if err != nil {
|
||||
|
@ -44,9 +45,14 @@ var listRoutesCmd = &cobra.Command{
|
|||
Run: func(cmd *cobra.Command, args []string) {
|
||||
output, _ := cmd.Flags().GetString("output")
|
||||
|
||||
machineId, err := cmd.Flags().GetUint64("identifier")
|
||||
machineID, err := cmd.Flags().GetUint64("identifier")
|
||||
if err != nil {
|
||||
ErrorOutput(err, fmt.Sprintf("Error getting machine id from flag: %s", err), output)
|
||||
ErrorOutput(
|
||||
err,
|
||||
fmt.Sprintf("Error getting machine id from flag: %s", err),
|
||||
output,
|
||||
)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -55,29 +61,41 @@ var listRoutesCmd = &cobra.Command{
|
|||
defer conn.Close()
|
||||
|
||||
request := &v1.GetMachineRouteRequest{
|
||||
MachineId: machineId,
|
||||
MachineId: machineID,
|
||||
}
|
||||
|
||||
response, err := client.GetMachineRoute(ctx, request)
|
||||
if err != nil {
|
||||
ErrorOutput(err, fmt.Sprintf("Cannot get nodes: %s", status.Convert(err).Message()), output)
|
||||
ErrorOutput(
|
||||
err,
|
||||
fmt.Sprintf("Cannot get nodes: %s", status.Convert(err).Message()),
|
||||
output,
|
||||
)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if output != "" {
|
||||
SuccessOutput(response.Routes, "", output)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
d := routesToPtables(response.Routes)
|
||||
tableData := routesToPtables(response.Routes)
|
||||
if err != nil {
|
||||
ErrorOutput(err, fmt.Sprintf("Error converting to table: %s", err), output)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
err = pterm.DefaultTable.WithHasHeader().WithData(d).Render()
|
||||
err = pterm.DefaultTable.WithHasHeader().WithData(tableData).Render()
|
||||
if err != nil {
|
||||
ErrorOutput(err, fmt.Sprintf("Failed to render pterm table: %s", err), output)
|
||||
ErrorOutput(
|
||||
err,
|
||||
fmt.Sprintf("Failed to render pterm table: %s", err),
|
||||
output,
|
||||
)
|
||||
|
||||
return
|
||||
}
|
||||
},
|
||||
|
@ -93,15 +111,26 @@ omit the route you do not want to enable.
|
|||
`,
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
output, _ := cmd.Flags().GetString("output")
|
||||
machineId, err := cmd.Flags().GetUint64("identifier")
|
||||
|
||||
machineID, err := cmd.Flags().GetUint64("identifier")
|
||||
if err != nil {
|
||||
ErrorOutput(err, fmt.Sprintf("Error getting machine id from flag: %s", err), output)
|
||||
ErrorOutput(
|
||||
err,
|
||||
fmt.Sprintf("Error getting machine id from flag: %s", err),
|
||||
output,
|
||||
)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
routes, err := cmd.Flags().GetStringSlice("route")
|
||||
if err != nil {
|
||||
ErrorOutput(err, fmt.Sprintf("Error getting routes from flag: %s", err), output)
|
||||
ErrorOutput(
|
||||
err,
|
||||
fmt.Sprintf("Error getting routes from flag: %s", err),
|
||||
output,
|
||||
)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -110,45 +139,61 @@ omit the route you do not want to enable.
|
|||
defer conn.Close()
|
||||
|
||||
request := &v1.EnableMachineRoutesRequest{
|
||||
MachineId: machineId,
|
||||
MachineId: machineID,
|
||||
Routes: routes,
|
||||
}
|
||||
|
||||
response, err := client.EnableMachineRoutes(ctx, request)
|
||||
if err != nil {
|
||||
ErrorOutput(err, fmt.Sprintf("Cannot register machine: %s\n", status.Convert(err).Message()), output)
|
||||
ErrorOutput(
|
||||
err,
|
||||
fmt.Sprintf(
|
||||
"Cannot register machine: %s\n",
|
||||
status.Convert(err).Message(),
|
||||
),
|
||||
output,
|
||||
)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if output != "" {
|
||||
SuccessOutput(response.Routes, "", output)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
d := routesToPtables(response.Routes)
|
||||
tableData := routesToPtables(response.Routes)
|
||||
if err != nil {
|
||||
ErrorOutput(err, fmt.Sprintf("Error converting to table: %s", err), output)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
err = pterm.DefaultTable.WithHasHeader().WithData(d).Render()
|
||||
err = pterm.DefaultTable.WithHasHeader().WithData(tableData).Render()
|
||||
if err != nil {
|
||||
ErrorOutput(err, fmt.Sprintf("Failed to render pterm table: %s", err), output)
|
||||
ErrorOutput(
|
||||
err,
|
||||
fmt.Sprintf("Failed to render pterm table: %s", err),
|
||||
output,
|
||||
)
|
||||
|
||||
return
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
// routesToPtables converts the list of routes to a nice table
|
||||
// routesToPtables converts the list of routes to a nice table.
|
||||
func routesToPtables(routes *v1.Routes) pterm.TableData {
|
||||
d := pterm.TableData{{"Route", "Enabled"}}
|
||||
tableData := pterm.TableData{{"Route", "Enabled"}}
|
||||
|
||||
for _, route := range routes.GetAdvertisedRoutes() {
|
||||
enabled := isStringInSlice(routes.EnabledRoutes, route)
|
||||
|
||||
d = append(d, []string{route, strconv.FormatBool(enabled)})
|
||||
tableData = append(tableData, []string{route, strconv.FormatBool(enabled)})
|
||||
}
|
||||
return d
|
||||
|
||||
return tableData
|
||||
}
|
||||
|
||||
func isStringInSlice(strs []string, s string) bool {
|
||||
|
|
|
@ -52,9 +52,8 @@ func LoadConfig(path string) error {
|
|||
viper.SetDefault("cli.insecure", false)
|
||||
viper.SetDefault("cli.timeout", "5s")
|
||||
|
||||
err := viper.ReadInConfig()
|
||||
if err != nil {
|
||||
return fmt.Errorf("Fatal error reading config file: %s \n", err)
|
||||
if err := viper.ReadInConfig(); err != nil {
|
||||
return fmt.Errorf("fatal error reading config file: %w", err)
|
||||
}
|
||||
|
||||
// Collect any validation errors and return them all at once
|
||||
|
@ -82,6 +81,7 @@ func LoadConfig(path string) error {
|
|||
errorText += "Fatal config error: server_url must start with https:// or http://\n"
|
||||
}
|
||||
if errorText != "" {
|
||||
//nolint
|
||||
return errors.New(strings.TrimSuffix(errorText, "\n"))
|
||||
} else {
|
||||
return nil
|
||||
|
@ -149,9 +149,14 @@ func GetDNSConfig() (*tailcfg.DNSConfig, string) {
|
|||
if viper.IsSet("dns_config.restricted_nameservers") {
|
||||
if len(dnsConfig.Nameservers) > 0 {
|
||||
dnsConfig.Routes = make(map[string][]dnstype.Resolver)
|
||||
restrictedDNS := viper.GetStringMapStringSlice("dns_config.restricted_nameservers")
|
||||
restrictedDNS := viper.GetStringMapStringSlice(
|
||||
"dns_config.restricted_nameservers",
|
||||
)
|
||||
for domain, restrictedNameservers := range restrictedDNS {
|
||||
restrictedResolvers := make([]dnstype.Resolver, len(restrictedNameservers))
|
||||
restrictedResolvers := make(
|
||||
[]dnstype.Resolver,
|
||||
len(restrictedNameservers),
|
||||
)
|
||||
for index, nameserverStr := range restrictedNameservers {
|
||||
nameserver, err := netaddr.ParseIP(nameserverStr)
|
||||
if err != nil {
|
||||
|
@ -208,6 +213,7 @@ func absPath(path string) string {
|
|||
path = filepath.Join(dir, path)
|
||||
}
|
||||
}
|
||||
|
||||
return path
|
||||
}
|
||||
|
||||
|
@ -219,7 +225,9 @@ func getHeadscaleConfig() headscale.Config {
|
|||
"10h",
|
||||
) // use 10h here because it is the length of a standard business day plus a small amount of leeway
|
||||
if viper.GetDuration("max_machine_registration_duration") >= time.Second {
|
||||
maxMachineRegistrationDuration = viper.GetDuration("max_machine_registration_duration")
|
||||
maxMachineRegistrationDuration = viper.GetDuration(
|
||||
"max_machine_registration_duration",
|
||||
)
|
||||
}
|
||||
|
||||
// defaultMachineRegistrationDuration is the default time assigned to a machine registration if one is not
|
||||
|
@ -229,7 +237,9 @@ func getHeadscaleConfig() headscale.Config {
|
|||
"8h",
|
||||
) // use 8h here because it's the length of a standard business day
|
||||
if viper.GetDuration("default_machine_registration_duration") >= time.Second {
|
||||
defaultMachineRegistrationDuration = viper.GetDuration("default_machine_registration_duration")
|
||||
defaultMachineRegistrationDuration = viper.GetDuration(
|
||||
"default_machine_registration_duration",
|
||||
)
|
||||
}
|
||||
|
||||
dnsConfig, baseDomain := GetDNSConfig()
|
||||
|
@ -244,7 +254,9 @@ func getHeadscaleConfig() headscale.Config {
|
|||
|
||||
DERP: derpConfig,
|
||||
|
||||
EphemeralNodeInactivityTimeout: viper.GetDuration("ephemeral_node_inactivity_timeout"),
|
||||
EphemeralNodeInactivityTimeout: viper.GetDuration(
|
||||
"ephemeral_node_inactivity_timeout",
|
||||
),
|
||||
|
||||
DBtype: viper.GetString("db_type"),
|
||||
DBpath: absPath(viper.GetString("db_path")),
|
||||
|
@ -254,9 +266,11 @@ func getHeadscaleConfig() headscale.Config {
|
|||
DBuser: viper.GetString("db_user"),
|
||||
DBpass: viper.GetString("db_pass"),
|
||||
|
||||
TLSLetsEncryptHostname: viper.GetString("tls_letsencrypt_hostname"),
|
||||
TLSLetsEncryptListen: viper.GetString("tls_letsencrypt_listen"),
|
||||
TLSLetsEncryptCacheDir: absPath(viper.GetString("tls_letsencrypt_cache_dir")),
|
||||
TLSLetsEncryptHostname: viper.GetString("tls_letsencrypt_hostname"),
|
||||
TLSLetsEncryptListen: viper.GetString("tls_letsencrypt_listen"),
|
||||
TLSLetsEncryptCacheDir: absPath(
|
||||
viper.GetString("tls_letsencrypt_cache_dir"),
|
||||
),
|
||||
TLSLetsEncryptChallengeType: viper.GetString("tls_letsencrypt_challenge_type"),
|
||||
|
||||
TLSCertPath: absPath(viper.GetString("tls_cert_path")),
|
||||
|
@ -292,11 +306,14 @@ func getHeadscaleApp() (*headscale.Headscale, error) {
|
|||
// to avoid races
|
||||
minInactivityTimeout, _ := time.ParseDuration("65s")
|
||||
if viper.GetDuration("ephemeral_node_inactivity_timeout") <= minInactivityTimeout {
|
||||
// TODO: Find a better way to return this text
|
||||
//nolint
|
||||
err := fmt.Errorf(
|
||||
"ephemeral_node_inactivity_timeout (%s) is set too low, must be more than %s\n",
|
||||
"ephemeral_node_inactivity_timeout (%s) is set too low, must be more than %s",
|
||||
viper.GetString("ephemeral_node_inactivity_timeout"),
|
||||
minInactivityTimeout,
|
||||
)
|
||||
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
@ -304,7 +321,7 @@ func getHeadscaleApp() (*headscale.Headscale, error) {
|
|||
|
||||
cfg.OIDC.MatchMap = loadOIDCMatchMap()
|
||||
|
||||
h, err := headscale.NewHeadscale(cfg)
|
||||
app, err := headscale.NewHeadscale(cfg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -313,7 +330,7 @@ func getHeadscaleApp() (*headscale.Headscale, error) {
|
|||
|
||||
if viper.GetString("acl_policy_path") != "" {
|
||||
aclPath := absPath(viper.GetString("acl_policy_path"))
|
||||
err = h.LoadACLPolicy(aclPath)
|
||||
err = app.LoadACLPolicy(aclPath)
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Str("path", aclPath).
|
||||
|
@ -322,7 +339,7 @@ func getHeadscaleApp() (*headscale.Headscale, error) {
|
|||
}
|
||||
}
|
||||
|
||||
return h, nil
|
||||
return app, nil
|
||||
}
|
||||
|
||||
func getHeadscaleCLIClient() (context.Context, v1.HeadscaleServiceClient, *grpc.ClientConn, context.CancelFunc) {
|
||||
|
@ -342,7 +359,6 @@ func getHeadscaleCLIClient() (context.Context, v1.HeadscaleServiceClient, *grpc.
|
|||
|
||||
// If the address is not set, we assume that we are on the server hosting headscale.
|
||||
if address == "" {
|
||||
|
||||
log.Debug().
|
||||
Str("socket", cfg.UnixSocket).
|
||||
Msgf("HEADSCALE_CLI_ADDRESS environment is not set, connecting to unix socket.")
|
||||
|
@ -402,10 +418,13 @@ func SuccessOutput(result interface{}, override string, outputFormat string) {
|
|||
log.Fatal().Err(err)
|
||||
}
|
||||
default:
|
||||
//nolint
|
||||
fmt.Println(override)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
//nolint
|
||||
fmt.Println(string(j))
|
||||
}
|
||||
|
||||
|
@ -423,6 +442,7 @@ func HasMachineOutputFlag() bool {
|
|||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
|
@ -431,7 +451,10 @@ type tokenAuth struct {
|
|||
}
|
||||
|
||||
// Return value is mapped to request headers.
|
||||
func (t tokenAuth) GetRequestMetadata(ctx context.Context, in ...string) (map[string]string, error) {
|
||||
func (t tokenAuth) GetRequestMetadata(
|
||||
ctx context.Context,
|
||||
in ...string,
|
||||
) (map[string]string, error) {
|
||||
return map[string]string{
|
||||
"authorization": "Bearer " + t.token,
|
||||
}, nil
|
||||
|
|
|
@ -23,6 +23,8 @@ func main() {
|
|||
colors = true
|
||||
case termcolor.LevelBasic:
|
||||
colors = true
|
||||
case termcolor.LevelNone:
|
||||
colors = false
|
||||
default:
|
||||
// no color, return text as is.
|
||||
colors = false
|
||||
|
@ -41,8 +43,7 @@ func main() {
|
|||
NoColor: !colors,
|
||||
})
|
||||
|
||||
err := cli.LoadConfig("")
|
||||
if err != nil {
|
||||
if err := cli.LoadConfig(""); err != nil {
|
||||
log.Fatal().Err(err)
|
||||
}
|
||||
|
||||
|
@ -63,13 +64,15 @@ func main() {
|
|||
}
|
||||
|
||||
if !viper.GetBool("disable_check_updates") && !machineOutput {
|
||||
if (runtime.GOOS == "linux" || runtime.GOOS == "darwin") && cli.Version != "dev" {
|
||||
if (runtime.GOOS == "linux" || runtime.GOOS == "darwin") &&
|
||||
cli.Version != "dev" {
|
||||
githubTag := &latest.GithubTag{
|
||||
Owner: "juanfont",
|
||||
Repository: "headscale",
|
||||
}
|
||||
res, err := latest.Check(githubTag, cli.Version)
|
||||
if err == nil && res.Outdated {
|
||||
//nolint
|
||||
fmt.Printf(
|
||||
"An updated version of Headscale has been found (%s vs. your current %s). Check it out https://github.com/juanfont/headscale/releases\n",
|
||||
res.Current,
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
@ -40,7 +39,10 @@ func (*Suite) TestConfigLoading(c *check.C) {
|
|||
}
|
||||
|
||||
// Symlink the example config file
|
||||
err = os.Symlink(filepath.Clean(path+"/../../config-example.yaml"), filepath.Join(tmpDir, "config.yaml"))
|
||||
err = os.Symlink(
|
||||
filepath.Clean(path+"/../../config-example.yaml"),
|
||||
filepath.Join(tmpDir, "config.yaml"),
|
||||
)
|
||||
if err != nil {
|
||||
c.Fatal(err)
|
||||
}
|
||||
|
@ -74,7 +76,10 @@ func (*Suite) TestDNSConfigLoading(c *check.C) {
|
|||
}
|
||||
|
||||
// Symlink the example config file
|
||||
err = os.Symlink(filepath.Clean(path+"/../../config-example.yaml"), filepath.Join(tmpDir, "config.yaml"))
|
||||
err = os.Symlink(
|
||||
filepath.Clean(path+"/../../config-example.yaml"),
|
||||
filepath.Join(tmpDir, "config.yaml"),
|
||||
)
|
||||
if err != nil {
|
||||
c.Fatal(err)
|
||||
}
|
||||
|
@ -94,7 +99,7 @@ func (*Suite) TestDNSConfigLoading(c *check.C) {
|
|||
func writeConfig(c *check.C, tmpDir string, configYaml []byte) {
|
||||
// Populate a custom config file
|
||||
configFile := filepath.Join(tmpDir, "config.yaml")
|
||||
err := ioutil.WriteFile(configFile, configYaml, 0o644)
|
||||
err := ioutil.WriteFile(configFile, configYaml, 0o600)
|
||||
if err != nil {
|
||||
c.Fatalf("Couldn't write file %s", configFile)
|
||||
}
|
||||
|
@ -106,7 +111,6 @@ func (*Suite) TestTLSConfigValidation(c *check.C) {
|
|||
c.Fatal(err)
|
||||
}
|
||||
// defer os.RemoveAll(tmpDir)
|
||||
fmt.Println(tmpDir)
|
||||
|
||||
configYaml := []byte(
|
||||
"---\ntls_letsencrypt_hostname: \"example.com\"\ntls_letsencrypt_challenge_type: \"\"\ntls_cert_path: \"abc.pem\"",
|
||||
|
@ -128,8 +132,11 @@ func (*Suite) TestTLSConfigValidation(c *check.C) {
|
|||
check.Matches,
|
||||
".*Fatal config error: the only supported values for tls_letsencrypt_challenge_type are.*",
|
||||
)
|
||||
c.Assert(tmp, check.Matches, ".*Fatal config error: server_url must start with https:// or http://.*")
|
||||
fmt.Println(tmp)
|
||||
c.Assert(
|
||||
tmp,
|
||||
check.Matches,
|
||||
".*Fatal config error: server_url must start with https:// or http://.*",
|
||||
)
|
||||
|
||||
// Check configuration validation errors (2)
|
||||
configYaml = []byte(
|
||||
|
|
35
db.go
35
db.go
|
@ -9,7 +9,10 @@ import (
|
|||
"gorm.io/gorm/logger"
|
||||
)
|
||||
|
||||
const dbVersion = "1"
|
||||
const (
|
||||
dbVersion = "1"
|
||||
errValueNotFound = Error("not found")
|
||||
)
|
||||
|
||||
// KV is a key-value store in a psql table. For future use...
|
||||
type KV struct {
|
||||
|
@ -24,7 +27,7 @@ func (h *Headscale) initDB() error {
|
|||
}
|
||||
h.db = db
|
||||
|
||||
if h.dbType == "postgres" {
|
||||
if h.dbType == Postgres {
|
||||
db.Exec("create extension if not exists \"uuid-ossp\";")
|
||||
}
|
||||
err = db.AutoMigrate(&Machine{})
|
||||
|
@ -50,6 +53,7 @@ func (h *Headscale) initDB() error {
|
|||
}
|
||||
|
||||
err = h.setValue("db_version", dbVersion)
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
|
@ -65,12 +69,12 @@ func (h *Headscale) openDB() (*gorm.DB, error) {
|
|||
}
|
||||
|
||||
switch h.dbType {
|
||||
case "sqlite3":
|
||||
case Sqlite:
|
||||
db, err = gorm.Open(sqlite.Open(h.dbString), &gorm.Config{
|
||||
DisableForeignKeyConstraintWhenMigrating: true,
|
||||
Logger: log,
|
||||
})
|
||||
case "postgres":
|
||||
case Postgres:
|
||||
db, err = gorm.Open(postgres.Open(h.dbString), &gorm.Config{
|
||||
DisableForeignKeyConstraintWhenMigrating: true,
|
||||
Logger: log,
|
||||
|
@ -84,28 +88,33 @@ func (h *Headscale) openDB() (*gorm.DB, error) {
|
|||
return db, nil
|
||||
}
|
||||
|
||||
// getValue returns the value for the given key in KV
|
||||
// getValue returns the value for the given key in KV.
|
||||
func (h *Headscale) getValue(key string) (string, error) {
|
||||
var row KV
|
||||
if result := h.db.First(&row, "key = ?", key); errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return "", errors.New("not found")
|
||||
if result := h.db.First(&row, "key = ?", key); errors.Is(
|
||||
result.Error,
|
||||
gorm.ErrRecordNotFound,
|
||||
) {
|
||||
return "", errValueNotFound
|
||||
}
|
||||
|
||||
return row.Value, nil
|
||||
}
|
||||
|
||||
// setValue sets value for the given key in KV
|
||||
// setValue sets value for the given key in KV.
|
||||
func (h *Headscale) setValue(key string, value string) error {
|
||||
kv := KV{
|
||||
keyValue := KV{
|
||||
Key: key,
|
||||
Value: value,
|
||||
}
|
||||
|
||||
_, err := h.getValue(key)
|
||||
if err == nil {
|
||||
h.db.Model(&kv).Where("key = ?", key).Update("value", value)
|
||||
if _, err := h.getValue(key); err == nil {
|
||||
h.db.Model(&keyValue).Where("key = ?", key).Update("value", value)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
h.db.Create(kv)
|
||||
h.db.Create(keyValue)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -1,15 +1,15 @@
|
|||
# If you plan to somehow use headscale, please deploy your own DERP infra: https://tailscale.com/kb/1118/custom-derp-servers/
|
||||
regions:
|
||||
regions:
|
||||
900:
|
||||
regionid: 900
|
||||
regioncode: custom
|
||||
regionname: My Region
|
||||
nodes:
|
||||
- name: 1a
|
||||
regionid: 1
|
||||
hostname: myderp.mydomain.no
|
||||
ipv4: 123.123.123.123
|
||||
ipv6: "2604:a880:400:d1::828:b001"
|
||||
stunport: 0
|
||||
stunonly: false
|
||||
derptestport: 0
|
||||
- name: 1a
|
||||
regionid: 1
|
||||
hostname: myderp.mydomain.no
|
||||
ipv4: 123.123.123.123
|
||||
ipv6: "2604:a880:400:d1::828:b001"
|
||||
stunport: 0
|
||||
stunonly: false
|
||||
derptestport: 0
|
||||
|
|
24
derp.go
24
derp.go
|
@ -1,6 +1,7 @@
|
|||
package headscale
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
|
@ -10,9 +11,7 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/rs/zerolog/log"
|
||||
|
||||
"gopkg.in/yaml.v2"
|
||||
|
||||
"tailscale.com/tailcfg"
|
||||
)
|
||||
|
||||
|
@ -28,14 +27,24 @@ func loadDERPMapFromPath(path string) (*tailcfg.DERPMap, error) {
|
|||
return nil, err
|
||||
}
|
||||
err = yaml.Unmarshal(b, &derpMap)
|
||||
|
||||
return &derpMap, err
|
||||
}
|
||||
|
||||
func loadDERPMapFromURL(addr url.URL) (*tailcfg.DERPMap, error) {
|
||||
client := http.Client{
|
||||
Timeout: 10 * time.Second,
|
||||
ctx, cancel := context.WithTimeout(context.Background(), HTTPReadTimeout)
|
||||
defer cancel()
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", addr.String(), nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resp, err := client.Get(addr.String())
|
||||
|
||||
client := http.Client{
|
||||
Timeout: HTTPReadTimeout,
|
||||
}
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -48,6 +57,7 @@ func loadDERPMapFromURL(addr url.URL) (*tailcfg.DERPMap, error) {
|
|||
|
||||
var derpMap tailcfg.DERPMap
|
||||
err = json.Unmarshal(body, &derpMap)
|
||||
|
||||
return &derpMap, err
|
||||
}
|
||||
|
||||
|
@ -55,7 +65,7 @@ func loadDERPMapFromURL(addr url.URL) (*tailcfg.DERPMap, error) {
|
|||
// DERPMap, it will _only_ look at the Regions, an integer.
|
||||
// If a region exists in two of the given DERPMaps, the region
|
||||
// form the _last_ DERPMap will be preserved.
|
||||
// An empty DERPMap list will result in a DERPMap with no regions
|
||||
// An empty DERPMap list will result in a DERPMap with no regions.
|
||||
func mergeDERPMaps(derpMaps []*tailcfg.DERPMap) *tailcfg.DERPMap {
|
||||
result := tailcfg.DERPMap{
|
||||
OmitDefaultRegions: false,
|
||||
|
@ -86,6 +96,7 @@ func GetDERPMap(cfg DERPConfig) *tailcfg.DERPMap {
|
|||
Str("path", path).
|
||||
Err(err).
|
||||
Msg("Could not load DERP map from path")
|
||||
|
||||
break
|
||||
}
|
||||
|
||||
|
@ -104,6 +115,7 @@ func GetDERPMap(cfg DERPConfig) *tailcfg.DERPMap {
|
|||
Str("url", addr.String()).
|
||||
Err(err).
|
||||
Msg("Could not load DERP map from path")
|
||||
|
||||
break
|
||||
}
|
||||
|
||||
|
|
34
dns.go
34
dns.go
|
@ -10,6 +10,10 @@ import (
|
|||
"tailscale.com/util/dnsname"
|
||||
)
|
||||
|
||||
const (
|
||||
ByteSize = 8
|
||||
)
|
||||
|
||||
// generateMagicDNSRootDomains generates a list of DNS entries to be included in `Routes` in `MapResponse`.
|
||||
// This list of reverse DNS entries instructs the OS on what subnets and domains the Tailscale embedded DNS
|
||||
// server (listening in 100.100.100.100 udp/53) should be used for.
|
||||
|
@ -30,7 +34,9 @@ import (
|
|||
|
||||
// From the netmask we can find out the wildcard bits (the bits that are not set in the netmask).
|
||||
// This allows us to then calculate the subnets included in the subsequent class block and generate the entries.
|
||||
func generateMagicDNSRootDomains(ipPrefix netaddr.IPPrefix, baseDomain string) ([]dnsname.FQDN, error) {
|
||||
func generateMagicDNSRootDomains(
|
||||
ipPrefix netaddr.IPPrefix,
|
||||
) []dnsname.FQDN {
|
||||
// TODO(juanfont): we are not handing out IPv6 addresses yet
|
||||
// and in fact this is Tailscale.com's range (note the fd7a:115c:a1e0: range in the fc00::/7 network)
|
||||
ipv6base := dnsname.FQDN("0.e.1.a.c.5.1.1.a.7.d.f.ip6.arpa.")
|
||||
|
@ -41,15 +47,15 @@ func generateMagicDNSRootDomains(ipPrefix netaddr.IPPrefix, baseDomain string) (
|
|||
maskBits, _ := netRange.Mask.Size()
|
||||
|
||||
// lastOctet is the last IP byte covered by the mask
|
||||
lastOctet := maskBits / 8
|
||||
lastOctet := maskBits / ByteSize
|
||||
|
||||
// wildcardBits is the number of bits not under the mask in the lastOctet
|
||||
wildcardBits := 8 - maskBits%8
|
||||
wildcardBits := ByteSize - maskBits%ByteSize
|
||||
|
||||
// min is the value in the lastOctet byte of the IP
|
||||
// max is basically 2^wildcardBits - i.e., the value when all the wildcardBits are set to 1
|
||||
min := uint(netRange.IP[lastOctet])
|
||||
max := uint((min + 1<<uint(wildcardBits)) - 1)
|
||||
max := (min + 1<<uint(wildcardBits)) - 1
|
||||
|
||||
// here we generate the base domain (e.g., 100.in-addr.arpa., 16.172.in-addr.arpa., etc.)
|
||||
rdnsSlice := []string{}
|
||||
|
@ -66,18 +72,27 @@ func generateMagicDNSRootDomains(ipPrefix netaddr.IPPrefix, baseDomain string) (
|
|||
}
|
||||
fqdns = append(fqdns, fqdn)
|
||||
}
|
||||
return fqdns, nil
|
||||
|
||||
return fqdns
|
||||
}
|
||||
|
||||
func getMapResponseDNSConfig(dnsConfigOrig *tailcfg.DNSConfig, baseDomain string, m Machine, peers Machines) (*tailcfg.DNSConfig, error) {
|
||||
func getMapResponseDNSConfig(
|
||||
dnsConfigOrig *tailcfg.DNSConfig,
|
||||
baseDomain string,
|
||||
machine Machine,
|
||||
peers Machines,
|
||||
) *tailcfg.DNSConfig {
|
||||
var dnsConfig *tailcfg.DNSConfig
|
||||
if dnsConfigOrig != nil && dnsConfigOrig.Proxied { // if MagicDNS is enabled
|
||||
// Only inject the Search Domain of the current namespace - shared nodes should use their full FQDN
|
||||
dnsConfig = dnsConfigOrig.Clone()
|
||||
dnsConfig.Domains = append(dnsConfig.Domains, fmt.Sprintf("%s.%s", m.Namespace.Name, baseDomain))
|
||||
dnsConfig.Domains = append(
|
||||
dnsConfig.Domains,
|
||||
fmt.Sprintf("%s.%s", machine.Namespace.Name, baseDomain),
|
||||
)
|
||||
|
||||
namespaceSet := set.New(set.ThreadSafe)
|
||||
namespaceSet.Add(m.Namespace)
|
||||
namespaceSet.Add(machine.Namespace)
|
||||
for _, p := range peers {
|
||||
namespaceSet.Add(p.Namespace)
|
||||
}
|
||||
|
@ -88,5 +103,6 @@ func getMapResponseDNSConfig(dnsConfigOrig *tailcfg.DNSConfig, baseDomain string
|
|||
} else {
|
||||
dnsConfig = dnsConfigOrig
|
||||
}
|
||||
return dnsConfig, nil
|
||||
|
||||
return dnsConfig
|
||||
}
|
||||
|
|
211
dns_test.go
211
dns_test.go
|
@ -11,13 +11,13 @@ import (
|
|||
|
||||
func (s *Suite) TestMagicDNSRootDomains100(c *check.C) {
|
||||
prefix := netaddr.MustParseIPPrefix("100.64.0.0/10")
|
||||
domains, err := generateMagicDNSRootDomains(prefix, "foobar.headscale.net")
|
||||
c.Assert(err, check.IsNil)
|
||||
domains := generateMagicDNSRootDomains(prefix)
|
||||
|
||||
found := false
|
||||
for _, domain := range domains {
|
||||
if domain == "64.100.in-addr.arpa." {
|
||||
found = true
|
||||
|
||||
break
|
||||
}
|
||||
}
|
||||
|
@ -27,6 +27,7 @@ func (s *Suite) TestMagicDNSRootDomains100(c *check.C) {
|
|||
for _, domain := range domains {
|
||||
if domain == "100.100.in-addr.arpa." {
|
||||
found = true
|
||||
|
||||
break
|
||||
}
|
||||
}
|
||||
|
@ -36,6 +37,7 @@ func (s *Suite) TestMagicDNSRootDomains100(c *check.C) {
|
|||
for _, domain := range domains {
|
||||
if domain == "127.100.in-addr.arpa." {
|
||||
found = true
|
||||
|
||||
break
|
||||
}
|
||||
}
|
||||
|
@ -44,13 +46,13 @@ func (s *Suite) TestMagicDNSRootDomains100(c *check.C) {
|
|||
|
||||
func (s *Suite) TestMagicDNSRootDomains172(c *check.C) {
|
||||
prefix := netaddr.MustParseIPPrefix("172.16.0.0/16")
|
||||
domains, err := generateMagicDNSRootDomains(prefix, "headscale.net")
|
||||
c.Assert(err, check.IsNil)
|
||||
domains := generateMagicDNSRootDomains(prefix)
|
||||
|
||||
found := false
|
||||
for _, domain := range domains {
|
||||
if domain == "0.16.172.in-addr.arpa." {
|
||||
found = true
|
||||
|
||||
break
|
||||
}
|
||||
}
|
||||
|
@ -60,6 +62,7 @@ func (s *Suite) TestMagicDNSRootDomains172(c *check.C) {
|
|||
for _, domain := range domains {
|
||||
if domain == "255.16.172.in-addr.arpa." {
|
||||
found = true
|
||||
|
||||
break
|
||||
}
|
||||
}
|
||||
|
@ -67,100 +70,120 @@ func (s *Suite) TestMagicDNSRootDomains172(c *check.C) {
|
|||
}
|
||||
|
||||
func (s *Suite) TestDNSConfigMapResponseWithMagicDNS(c *check.C) {
|
||||
n1, err := h.CreateNamespace("shared1")
|
||||
namespaceShared1, err := app.CreateNamespace("shared1")
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
n2, err := h.CreateNamespace("shared2")
|
||||
namespaceShared2, err := app.CreateNamespace("shared2")
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
n3, err := h.CreateNamespace("shared3")
|
||||
namespaceShared3, err := app.CreateNamespace("shared3")
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
pak1n1, err := h.CreatePreAuthKey(n1.Name, false, false, nil)
|
||||
preAuthKeyInShared1, err := app.CreatePreAuthKey(
|
||||
namespaceShared1.Name,
|
||||
false,
|
||||
false,
|
||||
nil,
|
||||
)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
pak2n2, err := h.CreatePreAuthKey(n2.Name, false, false, nil)
|
||||
preAuthKeyInShared2, err := app.CreatePreAuthKey(
|
||||
namespaceShared2.Name,
|
||||
false,
|
||||
false,
|
||||
nil,
|
||||
)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
pak3n3, err := h.CreatePreAuthKey(n3.Name, false, false, nil)
|
||||
preAuthKeyInShared3, err := app.CreatePreAuthKey(
|
||||
namespaceShared3.Name,
|
||||
false,
|
||||
false,
|
||||
nil,
|
||||
)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
pak4n1, err := h.CreatePreAuthKey(n1.Name, false, false, nil)
|
||||
PreAuthKey2InShared1, err := app.CreatePreAuthKey(
|
||||
namespaceShared1.Name,
|
||||
false,
|
||||
false,
|
||||
nil,
|
||||
)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
_, err = h.GetMachine(n1.Name, "test_get_shared_nodes_1")
|
||||
_, err = app.GetMachine(namespaceShared1.Name, "test_get_shared_nodes_1")
|
||||
c.Assert(err, check.NotNil)
|
||||
|
||||
m1 := &Machine{
|
||||
machineInShared1 := &Machine{
|
||||
ID: 1,
|
||||
MachineKey: "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66",
|
||||
NodeKey: "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66",
|
||||
DiscoKey: "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66",
|
||||
Name: "test_get_shared_nodes_1",
|
||||
NamespaceID: n1.ID,
|
||||
Namespace: *n1,
|
||||
NamespaceID: namespaceShared1.ID,
|
||||
Namespace: *namespaceShared1,
|
||||
Registered: true,
|
||||
RegisterMethod: "authKey",
|
||||
IPAddress: "100.64.0.1",
|
||||
AuthKeyID: uint(pak1n1.ID),
|
||||
AuthKeyID: uint(preAuthKeyInShared1.ID),
|
||||
}
|
||||
h.db.Save(m1)
|
||||
app.db.Save(machineInShared1)
|
||||
|
||||
_, err = h.GetMachine(n1.Name, m1.Name)
|
||||
_, err = app.GetMachine(namespaceShared1.Name, machineInShared1.Name)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
m2 := &Machine{
|
||||
machineInShared2 := &Machine{
|
||||
ID: 2,
|
||||
MachineKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
||||
NodeKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
||||
DiscoKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
||||
Name: "test_get_shared_nodes_2",
|
||||
NamespaceID: n2.ID,
|
||||
Namespace: *n2,
|
||||
NamespaceID: namespaceShared2.ID,
|
||||
Namespace: *namespaceShared2,
|
||||
Registered: true,
|
||||
RegisterMethod: "authKey",
|
||||
IPAddress: "100.64.0.2",
|
||||
AuthKeyID: uint(pak2n2.ID),
|
||||
AuthKeyID: uint(preAuthKeyInShared2.ID),
|
||||
}
|
||||
h.db.Save(m2)
|
||||
app.db.Save(machineInShared2)
|
||||
|
||||
_, err = h.GetMachine(n2.Name, m2.Name)
|
||||
_, err = app.GetMachine(namespaceShared2.Name, machineInShared2.Name)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
m3 := &Machine{
|
||||
machineInShared3 := &Machine{
|
||||
ID: 3,
|
||||
MachineKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
||||
NodeKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
||||
DiscoKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
||||
Name: "test_get_shared_nodes_3",
|
||||
NamespaceID: n3.ID,
|
||||
Namespace: *n3,
|
||||
NamespaceID: namespaceShared3.ID,
|
||||
Namespace: *namespaceShared3,
|
||||
Registered: true,
|
||||
RegisterMethod: "authKey",
|
||||
IPAddress: "100.64.0.3",
|
||||
AuthKeyID: uint(pak3n3.ID),
|
||||
AuthKeyID: uint(preAuthKeyInShared3.ID),
|
||||
}
|
||||
h.db.Save(m3)
|
||||
app.db.Save(machineInShared3)
|
||||
|
||||
_, err = h.GetMachine(n3.Name, m3.Name)
|
||||
_, err = app.GetMachine(namespaceShared3.Name, machineInShared3.Name)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
m4 := &Machine{
|
||||
machine2InShared1 := &Machine{
|
||||
ID: 4,
|
||||
MachineKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
||||
NodeKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
||||
DiscoKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
||||
Name: "test_get_shared_nodes_4",
|
||||
NamespaceID: n1.ID,
|
||||
Namespace: *n1,
|
||||
NamespaceID: namespaceShared1.ID,
|
||||
Namespace: *namespaceShared1,
|
||||
Registered: true,
|
||||
RegisterMethod: "authKey",
|
||||
IPAddress: "100.64.0.4",
|
||||
AuthKeyID: uint(pak4n1.ID),
|
||||
AuthKeyID: uint(PreAuthKey2InShared1.ID),
|
||||
}
|
||||
h.db.Save(m4)
|
||||
app.db.Save(machine2InShared1)
|
||||
|
||||
err = h.AddSharedMachineToNamespace(m2, n1)
|
||||
err = app.AddSharedMachineToNamespace(machineInShared2, namespaceShared1)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
baseDomain := "foobar.headscale.net"
|
||||
|
@ -170,122 +193,146 @@ func (s *Suite) TestDNSConfigMapResponseWithMagicDNS(c *check.C) {
|
|||
Proxied: true,
|
||||
}
|
||||
|
||||
m1peers, err := h.getPeers(m1)
|
||||
peersOfMachineInShared1, err := app.getPeers(machineInShared1)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
dnsConfig, err := getMapResponseDNSConfig(&dnsConfigOrig, baseDomain, *m1, m1peers)
|
||||
c.Assert(err, check.IsNil)
|
||||
dnsConfig := getMapResponseDNSConfig(
|
||||
&dnsConfigOrig,
|
||||
baseDomain,
|
||||
*machineInShared1,
|
||||
peersOfMachineInShared1,
|
||||
)
|
||||
c.Assert(dnsConfig, check.NotNil)
|
||||
c.Assert(len(dnsConfig.Routes), check.Equals, 2)
|
||||
|
||||
routeN1 := fmt.Sprintf("%s.%s", n1.Name, baseDomain)
|
||||
_, ok := dnsConfig.Routes[routeN1]
|
||||
domainRouteShared1 := fmt.Sprintf("%s.%s", namespaceShared1.Name, baseDomain)
|
||||
_, ok := dnsConfig.Routes[domainRouteShared1]
|
||||
c.Assert(ok, check.Equals, true)
|
||||
|
||||
routeN2 := fmt.Sprintf("%s.%s", n2.Name, baseDomain)
|
||||
_, ok = dnsConfig.Routes[routeN2]
|
||||
domainRouteShared2 := fmt.Sprintf("%s.%s", namespaceShared2.Name, baseDomain)
|
||||
_, ok = dnsConfig.Routes[domainRouteShared2]
|
||||
c.Assert(ok, check.Equals, true)
|
||||
|
||||
routeN3 := fmt.Sprintf("%s.%s", n3.Name, baseDomain)
|
||||
_, ok = dnsConfig.Routes[routeN3]
|
||||
domainRouteShared3 := fmt.Sprintf("%s.%s", namespaceShared3.Name, baseDomain)
|
||||
_, ok = dnsConfig.Routes[domainRouteShared3]
|
||||
c.Assert(ok, check.Equals, false)
|
||||
}
|
||||
|
||||
func (s *Suite) TestDNSConfigMapResponseWithoutMagicDNS(c *check.C) {
|
||||
n1, err := h.CreateNamespace("shared1")
|
||||
namespaceShared1, err := app.CreateNamespace("shared1")
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
n2, err := h.CreateNamespace("shared2")
|
||||
namespaceShared2, err := app.CreateNamespace("shared2")
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
n3, err := h.CreateNamespace("shared3")
|
||||
namespaceShared3, err := app.CreateNamespace("shared3")
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
pak1n1, err := h.CreatePreAuthKey(n1.Name, false, false, nil)
|
||||
preAuthKeyInShared1, err := app.CreatePreAuthKey(
|
||||
namespaceShared1.Name,
|
||||
false,
|
||||
false,
|
||||
nil,
|
||||
)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
pak2n2, err := h.CreatePreAuthKey(n2.Name, false, false, nil)
|
||||
preAuthKeyInShared2, err := app.CreatePreAuthKey(
|
||||
namespaceShared2.Name,
|
||||
false,
|
||||
false,
|
||||
nil,
|
||||
)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
pak3n3, err := h.CreatePreAuthKey(n3.Name, false, false, nil)
|
||||
preAuthKeyInShared3, err := app.CreatePreAuthKey(
|
||||
namespaceShared3.Name,
|
||||
false,
|
||||
false,
|
||||
nil,
|
||||
)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
pak4n1, err := h.CreatePreAuthKey(n1.Name, false, false, nil)
|
||||
preAuthKey2InShared1, err := app.CreatePreAuthKey(
|
||||
namespaceShared1.Name,
|
||||
false,
|
||||
false,
|
||||
nil,
|
||||
)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
_, err = h.GetMachine(n1.Name, "test_get_shared_nodes_1")
|
||||
_, err = app.GetMachine(namespaceShared1.Name, "test_get_shared_nodes_1")
|
||||
c.Assert(err, check.NotNil)
|
||||
|
||||
m1 := &Machine{
|
||||
machineInShared1 := &Machine{
|
||||
ID: 1,
|
||||
MachineKey: "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66",
|
||||
NodeKey: "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66",
|
||||
DiscoKey: "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66",
|
||||
Name: "test_get_shared_nodes_1",
|
||||
NamespaceID: n1.ID,
|
||||
Namespace: *n1,
|
||||
NamespaceID: namespaceShared1.ID,
|
||||
Namespace: *namespaceShared1,
|
||||
Registered: true,
|
||||
RegisterMethod: "authKey",
|
||||
IPAddress: "100.64.0.1",
|
||||
AuthKeyID: uint(pak1n1.ID),
|
||||
AuthKeyID: uint(preAuthKeyInShared1.ID),
|
||||
}
|
||||
h.db.Save(m1)
|
||||
app.db.Save(machineInShared1)
|
||||
|
||||
_, err = h.GetMachine(n1.Name, m1.Name)
|
||||
_, err = app.GetMachine(namespaceShared1.Name, machineInShared1.Name)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
m2 := &Machine{
|
||||
machineInShared2 := &Machine{
|
||||
ID: 2,
|
||||
MachineKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
||||
NodeKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
||||
DiscoKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
||||
Name: "test_get_shared_nodes_2",
|
||||
NamespaceID: n2.ID,
|
||||
Namespace: *n2,
|
||||
NamespaceID: namespaceShared2.ID,
|
||||
Namespace: *namespaceShared2,
|
||||
Registered: true,
|
||||
RegisterMethod: "authKey",
|
||||
IPAddress: "100.64.0.2",
|
||||
AuthKeyID: uint(pak2n2.ID),
|
||||
AuthKeyID: uint(preAuthKeyInShared2.ID),
|
||||
}
|
||||
h.db.Save(m2)
|
||||
app.db.Save(machineInShared2)
|
||||
|
||||
_, err = h.GetMachine(n2.Name, m2.Name)
|
||||
_, err = app.GetMachine(namespaceShared2.Name, machineInShared2.Name)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
m3 := &Machine{
|
||||
machineInShared3 := &Machine{
|
||||
ID: 3,
|
||||
MachineKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
||||
NodeKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
||||
DiscoKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
||||
Name: "test_get_shared_nodes_3",
|
||||
NamespaceID: n3.ID,
|
||||
Namespace: *n3,
|
||||
NamespaceID: namespaceShared3.ID,
|
||||
Namespace: *namespaceShared3,
|
||||
Registered: true,
|
||||
RegisterMethod: "authKey",
|
||||
IPAddress: "100.64.0.3",
|
||||
AuthKeyID: uint(pak3n3.ID),
|
||||
AuthKeyID: uint(preAuthKeyInShared3.ID),
|
||||
}
|
||||
h.db.Save(m3)
|
||||
app.db.Save(machineInShared3)
|
||||
|
||||
_, err = h.GetMachine(n3.Name, m3.Name)
|
||||
_, err = app.GetMachine(namespaceShared3.Name, machineInShared3.Name)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
m4 := &Machine{
|
||||
machine2InShared1 := &Machine{
|
||||
ID: 4,
|
||||
MachineKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
||||
NodeKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
||||
DiscoKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
||||
Name: "test_get_shared_nodes_4",
|
||||
NamespaceID: n1.ID,
|
||||
Namespace: *n1,
|
||||
NamespaceID: namespaceShared1.ID,
|
||||
Namespace: *namespaceShared1,
|
||||
Registered: true,
|
||||
RegisterMethod: "authKey",
|
||||
IPAddress: "100.64.0.4",
|
||||
AuthKeyID: uint(pak4n1.ID),
|
||||
AuthKeyID: uint(preAuthKey2InShared1.ID),
|
||||
}
|
||||
h.db.Save(m4)
|
||||
app.db.Save(machine2InShared1)
|
||||
|
||||
err = h.AddSharedMachineToNamespace(m2, n1)
|
||||
err = app.AddSharedMachineToNamespace(machineInShared2, namespaceShared1)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
baseDomain := "foobar.headscale.net"
|
||||
|
@ -295,11 +342,15 @@ func (s *Suite) TestDNSConfigMapResponseWithoutMagicDNS(c *check.C) {
|
|||
Proxied: false,
|
||||
}
|
||||
|
||||
m1peers, err := h.getPeers(m1)
|
||||
peersOfMachine1Shared1, err := app.getPeers(machineInShared1)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
dnsConfig, err := getMapResponseDNSConfig(&dnsConfigOrig, baseDomain, *m1, m1peers)
|
||||
c.Assert(err, check.IsNil)
|
||||
dnsConfig := getMapResponseDNSConfig(
|
||||
&dnsConfigOrig,
|
||||
baseDomain,
|
||||
*machineInShared1,
|
||||
peersOfMachine1Shared1,
|
||||
)
|
||||
c.Assert(dnsConfig, check.NotNil)
|
||||
c.Assert(len(dnsConfig.Routes), check.Equals, 0)
|
||||
c.Assert(len(dnsConfig.Domains), check.Equals, 1)
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
# Running headscale
|
||||
|
||||
## Server configuration
|
||||
|
||||
1. Download the headscale binary https://github.com/juanfont/headscale/releases, and place it somewhere in your $PATH or use the docker container
|
||||
|
||||
```shell
|
||||
|
@ -18,10 +19,11 @@
|
|||
```shell
|
||||
mkdir config
|
||||
```
|
||||
|
||||
|
||||
3. Get yourself a DB
|
||||
|
||||
a) Get a Postgres DB running in docker
|
||||
|
||||
```shell
|
||||
docker run --name headscale \
|
||||
-e POSTGRES_DB=headscale
|
||||
|
@ -30,7 +32,9 @@
|
|||
-p 5432:5432 \
|
||||
-d postgres
|
||||
```
|
||||
|
||||
or b) Prepare a SQLite DB file
|
||||
|
||||
```shell
|
||||
touch config/db.sqlite
|
||||
```
|
||||
|
@ -41,7 +45,7 @@
|
|||
wg genkey > config/private.key
|
||||
|
||||
cp config.yaml.[sqlite|postgres].example config/config.yaml
|
||||
|
||||
|
||||
cp derp-example.yaml config/derp.yaml
|
||||
```
|
||||
|
||||
|
@ -81,16 +85,19 @@
|
|||
-p 127.0.0.1:8080:8080 \
|
||||
headscale/headscale:x.x.x headscale serve
|
||||
```
|
||||
|
||||
## Nodes configuration
|
||||
|
||||
If you used tailscale.com before in your nodes, make sure you clear the tailscaled data folder
|
||||
|
||||
```shell
|
||||
systemctl stop tailscaled
|
||||
rm -fr /var/lib/tailscale
|
||||
systemctl start tailscaled
|
||||
```
|
||||
```shell
|
||||
systemctl stop tailscaled
|
||||
rm -fr /var/lib/tailscale
|
||||
systemctl start tailscaled
|
||||
```
|
||||
|
||||
### Adding node based on MACHINEKEY
|
||||
|
||||
1. Add your first machine
|
||||
|
||||
```shell
|
||||
|
|
17
grpcv1.go
17
grpcv1.go
|
@ -155,7 +155,10 @@ func (api headscaleV1APIServer) RegisterMachine(
|
|||
ctx context.Context,
|
||||
request *v1.RegisterMachineRequest,
|
||||
) (*v1.RegisterMachineResponse, error) {
|
||||
log.Trace().Str("namespace", request.GetNamespace()).Str("machine_key", request.GetKey()).Msg("Registering machine")
|
||||
log.Trace().
|
||||
Str("namespace", request.GetNamespace()).
|
||||
Str("machine_key", request.GetKey()).
|
||||
Msg("Registering machine")
|
||||
machine, err := api.h.RegisterMachine(
|
||||
request.GetKey(),
|
||||
request.GetNamespace(),
|
||||
|
@ -208,7 +211,9 @@ func (api headscaleV1APIServer) ListMachines(
|
|||
return nil, err
|
||||
}
|
||||
|
||||
sharedMachines, err := api.h.ListSharedMachinesInNamespace(request.GetNamespace())
|
||||
sharedMachines, err := api.h.ListSharedMachinesInNamespace(
|
||||
request.GetNamespace(),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -333,12 +338,16 @@ func (api headscaleV1APIServer) DebugCreateMachine(
|
|||
return nil, err
|
||||
}
|
||||
|
||||
routes, err := stringToIpPrefix(request.GetRoutes())
|
||||
routes, err := stringToIPPrefix(request.GetRoutes())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
log.Trace().Caller().Interface("route-prefix", routes).Interface("route-str", request.GetRoutes()).Msg("")
|
||||
log.Trace().
|
||||
Caller().
|
||||
Interface("route-prefix", routes).
|
||||
Interface("route-str", request.GetRoutes()).
|
||||
Msg("")
|
||||
|
||||
hostinfo := tailcfg.Hostinfo{
|
||||
RoutableIPs: routes,
|
||||
|
|
|
@ -88,6 +88,7 @@ func (s *IntegrationCLITestSuite) SetupTest() {
|
|||
if resp.StatusCode != http.StatusOK {
|
||||
return fmt.Errorf("status code not OK")
|
||||
}
|
||||
|
||||
return nil
|
||||
}); err != nil {
|
||||
// TODO(kradalby): If we cannot access headscale, or any other fatal error during
|
||||
|
@ -109,7 +110,10 @@ func (s *IntegrationCLITestSuite) TearDownTest() {
|
|||
}
|
||||
}
|
||||
|
||||
func (s *IntegrationCLITestSuite) HandleStats(suiteName string, stats *suite.SuiteInformation) {
|
||||
func (s *IntegrationCLITestSuite) HandleStats(
|
||||
suiteName string,
|
||||
stats *suite.SuiteInformation,
|
||||
) {
|
||||
s.stats = stats
|
||||
}
|
||||
|
||||
|
@ -144,7 +148,6 @@ func (s *IntegrationCLITestSuite) TestNamespaceCommand() {
|
|||
namespaces := make([]*v1.Namespace, len(names))
|
||||
|
||||
for index, namespaceName := range names {
|
||||
|
||||
namespace, err := s.createNamespace(namespaceName)
|
||||
assert.Nil(s.T(), err)
|
||||
|
||||
|
@ -298,11 +301,26 @@ func (s *IntegrationCLITestSuite) TestPreAuthKeyCommand() {
|
|||
assert.True(s.T(), listedPreAuthKeys[3].Expiration.AsTime().After(time.Now()))
|
||||
assert.True(s.T(), listedPreAuthKeys[4].Expiration.AsTime().After(time.Now()))
|
||||
|
||||
assert.True(s.T(), listedPreAuthKeys[0].Expiration.AsTime().Before(time.Now().Add(time.Hour*26)))
|
||||
assert.True(s.T(), listedPreAuthKeys[1].Expiration.AsTime().Before(time.Now().Add(time.Hour*26)))
|
||||
assert.True(s.T(), listedPreAuthKeys[2].Expiration.AsTime().Before(time.Now().Add(time.Hour*26)))
|
||||
assert.True(s.T(), listedPreAuthKeys[3].Expiration.AsTime().Before(time.Now().Add(time.Hour*26)))
|
||||
assert.True(s.T(), listedPreAuthKeys[4].Expiration.AsTime().Before(time.Now().Add(time.Hour*26)))
|
||||
assert.True(
|
||||
s.T(),
|
||||
listedPreAuthKeys[0].Expiration.AsTime().Before(time.Now().Add(time.Hour*26)),
|
||||
)
|
||||
assert.True(
|
||||
s.T(),
|
||||
listedPreAuthKeys[1].Expiration.AsTime().Before(time.Now().Add(time.Hour*26)),
|
||||
)
|
||||
assert.True(
|
||||
s.T(),
|
||||
listedPreAuthKeys[2].Expiration.AsTime().Before(time.Now().Add(time.Hour*26)),
|
||||
)
|
||||
assert.True(
|
||||
s.T(),
|
||||
listedPreAuthKeys[3].Expiration.AsTime().Before(time.Now().Add(time.Hour*26)),
|
||||
)
|
||||
assert.True(
|
||||
s.T(),
|
||||
listedPreAuthKeys[4].Expiration.AsTime().Before(time.Now().Add(time.Hour*26)),
|
||||
)
|
||||
|
||||
// Expire three keys
|
||||
for i := 0; i < 3; i++ {
|
||||
|
@ -341,11 +359,26 @@ func (s *IntegrationCLITestSuite) TestPreAuthKeyCommand() {
|
|||
err = json.Unmarshal([]byte(listAfterExpireResult), &listedAfterExpirePreAuthKeys)
|
||||
assert.Nil(s.T(), err)
|
||||
|
||||
assert.True(s.T(), listedAfterExpirePreAuthKeys[0].Expiration.AsTime().Before(time.Now()))
|
||||
assert.True(s.T(), listedAfterExpirePreAuthKeys[1].Expiration.AsTime().Before(time.Now()))
|
||||
assert.True(s.T(), listedAfterExpirePreAuthKeys[2].Expiration.AsTime().Before(time.Now()))
|
||||
assert.True(s.T(), listedAfterExpirePreAuthKeys[3].Expiration.AsTime().After(time.Now()))
|
||||
assert.True(s.T(), listedAfterExpirePreAuthKeys[4].Expiration.AsTime().After(time.Now()))
|
||||
assert.True(
|
||||
s.T(),
|
||||
listedAfterExpirePreAuthKeys[0].Expiration.AsTime().Before(time.Now()),
|
||||
)
|
||||
assert.True(
|
||||
s.T(),
|
||||
listedAfterExpirePreAuthKeys[1].Expiration.AsTime().Before(time.Now()),
|
||||
)
|
||||
assert.True(
|
||||
s.T(),
|
||||
listedAfterExpirePreAuthKeys[2].Expiration.AsTime().Before(time.Now()),
|
||||
)
|
||||
assert.True(
|
||||
s.T(),
|
||||
listedAfterExpirePreAuthKeys[3].Expiration.AsTime().After(time.Now()),
|
||||
)
|
||||
assert.True(
|
||||
s.T(),
|
||||
listedAfterExpirePreAuthKeys[4].Expiration.AsTime().After(time.Now()),
|
||||
)
|
||||
}
|
||||
|
||||
func (s *IntegrationCLITestSuite) TestPreAuthKeyCommandWithoutExpiry() {
|
||||
|
@ -689,7 +722,10 @@ func (s *IntegrationCLITestSuite) TestNodeCommand() {
|
|||
assert.Nil(s.T(), err)
|
||||
|
||||
var listOnlySharedMachineNamespace []v1.Machine
|
||||
err = json.Unmarshal([]byte(listOnlySharedMachineNamespaceResult), &listOnlySharedMachineNamespace)
|
||||
err = json.Unmarshal(
|
||||
[]byte(listOnlySharedMachineNamespaceResult),
|
||||
&listOnlySharedMachineNamespace,
|
||||
)
|
||||
assert.Nil(s.T(), err)
|
||||
|
||||
assert.Len(s.T(), listOnlySharedMachineNamespace, 2)
|
||||
|
@ -738,7 +774,10 @@ func (s *IntegrationCLITestSuite) TestNodeCommand() {
|
|||
assert.Nil(s.T(), err)
|
||||
|
||||
var listOnlyMachineNamespaceAfterDelete []v1.Machine
|
||||
err = json.Unmarshal([]byte(listOnlyMachineNamespaceAfterDeleteResult), &listOnlyMachineNamespaceAfterDelete)
|
||||
err = json.Unmarshal(
|
||||
[]byte(listOnlyMachineNamespaceAfterDeleteResult),
|
||||
&listOnlyMachineNamespaceAfterDelete,
|
||||
)
|
||||
assert.Nil(s.T(), err)
|
||||
|
||||
assert.Len(s.T(), listOnlyMachineNamespaceAfterDelete, 4)
|
||||
|
@ -789,7 +828,10 @@ func (s *IntegrationCLITestSuite) TestNodeCommand() {
|
|||
assert.Nil(s.T(), err)
|
||||
|
||||
var listOnlyMachineNamespaceAfterShare []v1.Machine
|
||||
err = json.Unmarshal([]byte(listOnlyMachineNamespaceAfterShareResult), &listOnlyMachineNamespaceAfterShare)
|
||||
err = json.Unmarshal(
|
||||
[]byte(listOnlyMachineNamespaceAfterShareResult),
|
||||
&listOnlyMachineNamespaceAfterShare,
|
||||
)
|
||||
assert.Nil(s.T(), err)
|
||||
|
||||
assert.Len(s.T(), listOnlyMachineNamespaceAfterShare, 5)
|
||||
|
@ -846,7 +888,10 @@ func (s *IntegrationCLITestSuite) TestNodeCommand() {
|
|||
assert.Nil(s.T(), err)
|
||||
|
||||
var listOnlyMachineNamespaceAfterUnshare []v1.Machine
|
||||
err = json.Unmarshal([]byte(listOnlyMachineNamespaceAfterUnshareResult), &listOnlyMachineNamespaceAfterUnshare)
|
||||
err = json.Unmarshal(
|
||||
[]byte(listOnlyMachineNamespaceAfterUnshareResult),
|
||||
&listOnlyMachineNamespaceAfterUnshare,
|
||||
)
|
||||
assert.Nil(s.T(), err)
|
||||
|
||||
assert.Len(s.T(), listOnlyMachineNamespaceAfterUnshare, 4)
|
||||
|
@ -1010,5 +1055,9 @@ func (s *IntegrationCLITestSuite) TestRouteCommand() {
|
|||
)
|
||||
assert.Nil(s.T(), err)
|
||||
|
||||
assert.Contains(s.T(), string(failEnableNonAdvertisedRoute), "route (route-machine) is not available on node")
|
||||
assert.Contains(
|
||||
s.T(),
|
||||
string(failEnableNonAdvertisedRoute),
|
||||
"route (route-machine) is not available on node",
|
||||
)
|
||||
}
|
||||
|
|
|
@ -12,12 +12,18 @@ import (
|
|||
"github.com/ory/dockertest/v3/docker"
|
||||
)
|
||||
|
||||
func ExecuteCommand(resource *dockertest.Resource, cmd []string, env []string) (string, error) {
|
||||
const DOCKER_EXECUTE_TIMEOUT = 10 * time.Second
|
||||
|
||||
func ExecuteCommand(
|
||||
resource *dockertest.Resource,
|
||||
cmd []string,
|
||||
env []string,
|
||||
) (string, error) {
|
||||
var stdout bytes.Buffer
|
||||
var stderr bytes.Buffer
|
||||
|
||||
// TODO(kradalby): Make configurable
|
||||
timeout := 10 * time.Second
|
||||
timeout := DOCKER_EXECUTE_TIMEOUT
|
||||
|
||||
type result struct {
|
||||
exitCode int
|
||||
|
@ -51,11 +57,13 @@ func ExecuteCommand(resource *dockertest.Resource, cmd []string, env []string) (
|
|||
fmt.Println("Command: ", cmd)
|
||||
fmt.Println("stdout: ", stdout.String())
|
||||
fmt.Println("stderr: ", stderr.String())
|
||||
|
||||
return "", fmt.Errorf("command failed with: %s", stderr.String())
|
||||
}
|
||||
|
||||
return stdout.String(), nil
|
||||
case <-time.After(timeout):
|
||||
|
||||
return "", fmt.Errorf("command timed out after %s", timeout)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -23,10 +23,9 @@ import (
|
|||
"github.com/ory/dockertest/v3/docker"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/suite"
|
||||
"inet.af/netaddr"
|
||||
"tailscale.com/client/tailscale/apitype"
|
||||
"tailscale.com/ipn/ipnstate"
|
||||
|
||||
"inet.af/netaddr"
|
||||
)
|
||||
|
||||
var tailscaleVersions = []string{"1.16.2", "1.14.3", "1.12.3"}
|
||||
|
@ -89,7 +88,10 @@ func TestIntegrationTestSuite(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func (s *IntegrationTestSuite) saveLog(resource *dockertest.Resource, basePath string) error {
|
||||
func (s *IntegrationTestSuite) saveLog(
|
||||
resource *dockertest.Resource,
|
||||
basePath string,
|
||||
) error {
|
||||
err := os.MkdirAll(basePath, os.ModePerm)
|
||||
if err != nil {
|
||||
return err
|
||||
|
@ -118,12 +120,20 @@ func (s *IntegrationTestSuite) saveLog(resource *dockertest.Resource, basePath s
|
|||
|
||||
fmt.Printf("Saving logs for %s to %s\n", resource.Container.Name, basePath)
|
||||
|
||||
err = ioutil.WriteFile(path.Join(basePath, resource.Container.Name+".stdout.log"), []byte(stdout.String()), 0o644)
|
||||
err = ioutil.WriteFile(
|
||||
path.Join(basePath, resource.Container.Name+".stdout.log"),
|
||||
[]byte(stdout.String()),
|
||||
0o644,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = ioutil.WriteFile(path.Join(basePath, resource.Container.Name+".stderr.log"), []byte(stdout.String()), 0o644)
|
||||
err = ioutil.WriteFile(
|
||||
path.Join(basePath, resource.Container.Name+".stderr.log"),
|
||||
[]byte(stdout.String()),
|
||||
0o644,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -144,24 +154,38 @@ func (s *IntegrationTestSuite) tailscaleContainer(
|
|||
},
|
||||
},
|
||||
}
|
||||
hostname := fmt.Sprintf("%s-tailscale-%s-%s", namespace, strings.Replace(version, ".", "-", -1), identifier)
|
||||
hostname := fmt.Sprintf(
|
||||
"%s-tailscale-%s-%s",
|
||||
namespace,
|
||||
strings.Replace(version, ".", "-", -1),
|
||||
identifier,
|
||||
)
|
||||
tailscaleOptions := &dockertest.RunOptions{
|
||||
Name: hostname,
|
||||
Networks: []*dockertest.Network{&s.network},
|
||||
Cmd: []string{"tailscaled", "--tun=userspace-networking", "--socks5-server=localhost:1055"},
|
||||
Cmd: []string{
|
||||
"tailscaled",
|
||||
"--tun=userspace-networking",
|
||||
"--socks5-server=localhost:1055",
|
||||
},
|
||||
}
|
||||
|
||||
pts, err := s.pool.BuildAndRunWithBuildOptions(tailscaleBuildOptions, tailscaleOptions, DockerRestartPolicy)
|
||||
pts, err := s.pool.BuildAndRunWithBuildOptions(
|
||||
tailscaleBuildOptions,
|
||||
tailscaleOptions,
|
||||
DockerRestartPolicy,
|
||||
)
|
||||
if err != nil {
|
||||
log.Fatalf("Could not start resource: %s", err)
|
||||
}
|
||||
fmt.Printf("Created %s container\n", hostname)
|
||||
|
||||
return hostname, pts
|
||||
}
|
||||
|
||||
func (s *IntegrationTestSuite) SetupSuite() {
|
||||
var err error
|
||||
h = Headscale{
|
||||
app = Headscale{
|
||||
dbType: "sqlite3",
|
||||
dbString: "integration_test_db.sqlite3",
|
||||
}
|
||||
|
@ -210,7 +234,11 @@ func (s *IntegrationTestSuite) SetupSuite() {
|
|||
for i := 0; i < scales.count; i++ {
|
||||
version := tailscaleVersions[i%len(tailscaleVersions)]
|
||||
|
||||
hostname, container := s.tailscaleContainer(namespace, fmt.Sprint(i), version)
|
||||
hostname, container := s.tailscaleContainer(
|
||||
namespace,
|
||||
fmt.Sprint(i),
|
||||
version,
|
||||
)
|
||||
scales.tailscales[hostname] = *container
|
||||
}
|
||||
}
|
||||
|
@ -220,13 +248,16 @@ func (s *IntegrationTestSuite) SetupSuite() {
|
|||
|
||||
if err := s.pool.Retry(func() error {
|
||||
url := fmt.Sprintf("http://%s/health", hostEndpoint)
|
||||
|
||||
resp, err := http.Get(url)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return fmt.Errorf("status code not OK")
|
||||
}
|
||||
|
||||
return nil
|
||||
}); err != nil {
|
||||
// TODO(kradalby): If we cannot access headscale, or any other fatal error during
|
||||
|
@ -273,7 +304,10 @@ func (s *IntegrationTestSuite) SetupSuite() {
|
|||
|
||||
headscaleEndpoint := "http://headscale:8080"
|
||||
|
||||
fmt.Printf("Joining tailscale containers to headscale at %s\n", headscaleEndpoint)
|
||||
fmt.Printf(
|
||||
"Joining tailscale containers to headscale at %s\n",
|
||||
headscaleEndpoint,
|
||||
)
|
||||
for hostname, tailscale := range scales.tailscales {
|
||||
command := []string{
|
||||
"tailscale",
|
||||
|
@ -307,7 +341,10 @@ func (s *IntegrationTestSuite) SetupSuite() {
|
|||
func (s *IntegrationTestSuite) TearDownSuite() {
|
||||
}
|
||||
|
||||
func (s *IntegrationTestSuite) HandleStats(suiteName string, stats *suite.SuiteInformation) {
|
||||
func (s *IntegrationTestSuite) HandleStats(
|
||||
suiteName string,
|
||||
stats *suite.SuiteInformation,
|
||||
) {
|
||||
s.stats = stats
|
||||
}
|
||||
|
||||
|
@ -427,7 +464,13 @@ func (s *IntegrationTestSuite) TestPingAllPeers() {
|
|||
ip.String(),
|
||||
}
|
||||
|
||||
fmt.Printf("Pinging from %s (%s) to %s (%s)\n", hostname, ips[hostname], peername, ip)
|
||||
fmt.Printf(
|
||||
"Pinging from %s (%s) to %s (%s)\n",
|
||||
hostname,
|
||||
ips[hostname],
|
||||
peername,
|
||||
ip,
|
||||
)
|
||||
result, err := ExecuteCommand(
|
||||
&tailscale,
|
||||
command,
|
||||
|
@ -449,7 +492,15 @@ func (s *IntegrationTestSuite) TestSharedNodes() {
|
|||
|
||||
result, err := ExecuteCommand(
|
||||
&s.headscale,
|
||||
[]string{"headscale", "nodes", "list", "--output", "json", "--namespace", "shared"},
|
||||
[]string{
|
||||
"headscale",
|
||||
"nodes",
|
||||
"list",
|
||||
"--output",
|
||||
"json",
|
||||
"--namespace",
|
||||
"shared",
|
||||
},
|
||||
[]string{},
|
||||
)
|
||||
assert.Nil(s.T(), err)
|
||||
|
@ -459,7 +510,6 @@ func (s *IntegrationTestSuite) TestSharedNodes() {
|
|||
assert.Nil(s.T(), err)
|
||||
|
||||
for _, machine := range machineList {
|
||||
|
||||
result, err := ExecuteCommand(
|
||||
&s.headscale,
|
||||
[]string{
|
||||
|
@ -520,7 +570,13 @@ func (s *IntegrationTestSuite) TestSharedNodes() {
|
|||
ip.String(),
|
||||
}
|
||||
|
||||
fmt.Printf("Pinging from %s (%s) to %s (%s)\n", hostname, mainIps[hostname], peername, ip)
|
||||
fmt.Printf(
|
||||
"Pinging from %s (%s) to %s (%s)\n",
|
||||
hostname,
|
||||
mainIps[hostname],
|
||||
peername,
|
||||
ip,
|
||||
)
|
||||
result, err := ExecuteCommand(
|
||||
&tailscale,
|
||||
command,
|
||||
|
@ -553,7 +609,6 @@ func (s *IntegrationTestSuite) TestTailDrop() {
|
|||
for peername, ip := range ips {
|
||||
s.T().Run(fmt.Sprintf("%s-%s", hostname, peername), func(t *testing.T) {
|
||||
if peername != hostname {
|
||||
|
||||
// Under normal circumstances, we should be able to send a file
|
||||
// using `tailscale file cp` - but not in userspace networking mode
|
||||
// So curl!
|
||||
|
@ -578,9 +633,19 @@ func (s *IntegrationTestSuite) TestTailDrop() {
|
|||
"PUT",
|
||||
"--upload-file",
|
||||
fmt.Sprintf("/tmp/file_from_%s", hostname),
|
||||
fmt.Sprintf("%s/v0/put/file_from_%s", peerAPI, hostname),
|
||||
fmt.Sprintf(
|
||||
"%s/v0/put/file_from_%s",
|
||||
peerAPI,
|
||||
hostname,
|
||||
),
|
||||
}
|
||||
fmt.Printf("Sending file from %s (%s) to %s (%s)\n", hostname, ips[hostname], peername, ip)
|
||||
fmt.Printf(
|
||||
"Sending file from %s (%s) to %s (%s)\n",
|
||||
hostname,
|
||||
ips[hostname],
|
||||
peername,
|
||||
ip,
|
||||
)
|
||||
_, err = ExecuteCommand(
|
||||
&tailscale,
|
||||
command,
|
||||
|
@ -621,7 +686,13 @@ func (s *IntegrationTestSuite) TestTailDrop() {
|
|||
"ls",
|
||||
fmt.Sprintf("/tmp/file_from_%s", peername),
|
||||
}
|
||||
fmt.Printf("Checking file in %s (%s) from %s (%s)\n", hostname, ips[hostname], peername, ip)
|
||||
fmt.Printf(
|
||||
"Checking file in %s (%s) from %s (%s)\n",
|
||||
hostname,
|
||||
ips[hostname],
|
||||
peername,
|
||||
ip,
|
||||
)
|
||||
result, err := ExecuteCommand(
|
||||
&tailscale,
|
||||
command,
|
||||
|
@ -629,7 +700,11 @@ func (s *IntegrationTestSuite) TestTailDrop() {
|
|||
)
|
||||
assert.Nil(t, err)
|
||||
fmt.Printf("Result for %s: %s\n", peername, result)
|
||||
assert.Equal(t, result, fmt.Sprintf("/tmp/file_from_%s\n", peername))
|
||||
assert.Equal(
|
||||
t,
|
||||
result,
|
||||
fmt.Sprintf("/tmp/file_from_%s\n", peername),
|
||||
)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
@ -696,10 +771,13 @@ func getIPs(tailscales map[string]dockertest.Resource) (map[string]netaddr.IP, e
|
|||
|
||||
ips[hostname] = ip
|
||||
}
|
||||
|
||||
return ips, nil
|
||||
}
|
||||
|
||||
func getAPIURLs(tailscales map[string]dockertest.Resource) (map[netaddr.IP]string, error) {
|
||||
func getAPIURLs(
|
||||
tailscales map[string]dockertest.Resource,
|
||||
) (map[netaddr.IP]string, error) {
|
||||
fts := make(map[netaddr.IP]string)
|
||||
for _, tailscale := range tailscales {
|
||||
command := []string{
|
||||
|
@ -733,5 +811,6 @@ func getAPIURLs(tailscales map[string]dockertest.Resource) (map[netaddr.IP]strin
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
return fts, nil
|
||||
}
|
||||
|
|
|
@ -24,6 +24,7 @@ Configure DERP servers by editing `base/site/derp.yaml` if needed.
|
|||
You'll somehow need to get `headscale:latest` into your cluster image registry.
|
||||
|
||||
An easy way to do this with k3s:
|
||||
|
||||
- Reconfigure k3s to use docker instead of containerd (`k3s server --docker`)
|
||||
- `docker build -t headscale:latest ..` from here
|
||||
|
||||
|
@ -61,7 +62,7 @@ Use the wrapper script to remotely operate headscale to perform administrative
|
|||
tasks like creating namespaces, authkeys, etc.
|
||||
|
||||
```
|
||||
[c@nix-slate:~/Projects/headscale/k8s]$ ./headscale.bash
|
||||
[c@nix-slate:~/Projects/headscale/k8s]$ ./headscale.bash
|
||||
|
||||
headscale is an open source implementation of the Tailscale control server
|
||||
|
||||
|
|
|
@ -6,13 +6,13 @@ metadata:
|
|||
kubernetes.io/ingress.class: traefik
|
||||
spec:
|
||||
rules:
|
||||
- host: $(PUBLIC_HOSTNAME)
|
||||
http:
|
||||
paths:
|
||||
- backend:
|
||||
service:
|
||||
name: headscale
|
||||
port:
|
||||
number: 8080
|
||||
path: /
|
||||
pathType: Prefix
|
||||
- host: $(PUBLIC_HOSTNAME)
|
||||
http:
|
||||
paths:
|
||||
- backend:
|
||||
service:
|
||||
name: headscale
|
||||
port:
|
||||
number: 8080
|
||||
path: /
|
||||
pathType: Prefix
|
||||
|
|
|
@ -1,42 +1,42 @@
|
|||
namespace: headscale
|
||||
resources:
|
||||
- configmap.yaml
|
||||
- ingress.yaml
|
||||
- service.yaml
|
||||
- configmap.yaml
|
||||
- ingress.yaml
|
||||
- service.yaml
|
||||
generatorOptions:
|
||||
disableNameSuffixHash: true
|
||||
configMapGenerator:
|
||||
- name: headscale-site
|
||||
files:
|
||||
- derp.yaml=site/derp.yaml
|
||||
envs:
|
||||
- site/public.env
|
||||
- name: headscale-etc
|
||||
literals:
|
||||
- config.json={}
|
||||
- name: headscale-site
|
||||
files:
|
||||
- derp.yaml=site/derp.yaml
|
||||
envs:
|
||||
- site/public.env
|
||||
- name: headscale-etc
|
||||
literals:
|
||||
- config.json={}
|
||||
secretGenerator:
|
||||
- name: headscale
|
||||
files:
|
||||
- secrets/private-key
|
||||
- name: headscale
|
||||
files:
|
||||
- secrets/private-key
|
||||
vars:
|
||||
- name: PUBLIC_PROTO
|
||||
objRef:
|
||||
kind: ConfigMap
|
||||
name: headscale-site
|
||||
apiVersion: v1
|
||||
fieldRef:
|
||||
fieldPath: data.public-proto
|
||||
- name: PUBLIC_HOSTNAME
|
||||
objRef:
|
||||
kind: ConfigMap
|
||||
name: headscale-site
|
||||
apiVersion: v1
|
||||
fieldRef:
|
||||
fieldPath: data.public-hostname
|
||||
- name: CONTACT_EMAIL
|
||||
objRef:
|
||||
kind: ConfigMap
|
||||
name: headscale-site
|
||||
apiVersion: v1
|
||||
fieldRef:
|
||||
fieldPath: data.contact-email
|
||||
- name: PUBLIC_PROTO
|
||||
objRef:
|
||||
kind: ConfigMap
|
||||
name: headscale-site
|
||||
apiVersion: v1
|
||||
fieldRef:
|
||||
fieldPath: data.public-proto
|
||||
- name: PUBLIC_HOSTNAME
|
||||
objRef:
|
||||
kind: ConfigMap
|
||||
name: headscale-site
|
||||
apiVersion: v1
|
||||
fieldRef:
|
||||
fieldPath: data.public-hostname
|
||||
- name: CONTACT_EMAIL
|
||||
objRef:
|
||||
kind: ConfigMap
|
||||
name: headscale-site
|
||||
apiVersion: v1
|
||||
fieldRef:
|
||||
fieldPath: data.contact-email
|
||||
|
|
|
@ -8,6 +8,6 @@ spec:
|
|||
selector:
|
||||
app: headscale
|
||||
ports:
|
||||
- name: http
|
||||
targetPort: http
|
||||
port: 8080
|
||||
- name: http
|
||||
targetPort: http
|
||||
port: 8080
|
||||
|
|
|
@ -13,66 +13,66 @@ spec:
|
|||
app: headscale
|
||||
spec:
|
||||
containers:
|
||||
- name: headscale
|
||||
image: "headscale:latest"
|
||||
imagePullPolicy: IfNotPresent
|
||||
command: ["/go/bin/headscale", "serve"]
|
||||
env:
|
||||
- name: SERVER_URL
|
||||
value: $(PUBLIC_PROTO)://$(PUBLIC_HOSTNAME)
|
||||
- name: LISTEN_ADDR
|
||||
valueFrom:
|
||||
configMapKeyRef:
|
||||
name: headscale-config
|
||||
key: listen_addr
|
||||
- name: PRIVATE_KEY_PATH
|
||||
value: /vol/secret/private-key
|
||||
- name: DERP_MAP_PATH
|
||||
value: /vol/config/derp.yaml
|
||||
- name: EPHEMERAL_NODE_INACTIVITY_TIMEOUT
|
||||
valueFrom:
|
||||
configMapKeyRef:
|
||||
name: headscale-config
|
||||
key: ephemeral_node_inactivity_timeout
|
||||
- name: DB_TYPE
|
||||
value: postgres
|
||||
- name: DB_HOST
|
||||
value: postgres.headscale.svc.cluster.local
|
||||
- name: DB_PORT
|
||||
value: "5432"
|
||||
- name: DB_USER
|
||||
value: headscale
|
||||
- name: DB_PASS
|
||||
valueFrom:
|
||||
secretKeyRef:
|
||||
name: postgresql
|
||||
key: password
|
||||
- name: DB_NAME
|
||||
value: headscale
|
||||
ports:
|
||||
- name: http
|
||||
protocol: TCP
|
||||
containerPort: 8080
|
||||
livenessProbe:
|
||||
tcpSocket:
|
||||
port: http
|
||||
initialDelaySeconds: 30
|
||||
timeoutSeconds: 5
|
||||
periodSeconds: 15
|
||||
volumeMounts:
|
||||
- name: config
|
||||
mountPath: /vol/config
|
||||
- name: secret
|
||||
mountPath: /vol/secret
|
||||
- name: etc
|
||||
mountPath: /etc/headscale
|
||||
- name: headscale
|
||||
image: "headscale:latest"
|
||||
imagePullPolicy: IfNotPresent
|
||||
command: ["/go/bin/headscale", "serve"]
|
||||
env:
|
||||
- name: SERVER_URL
|
||||
value: $(PUBLIC_PROTO)://$(PUBLIC_HOSTNAME)
|
||||
- name: LISTEN_ADDR
|
||||
valueFrom:
|
||||
configMapKeyRef:
|
||||
name: headscale-config
|
||||
key: listen_addr
|
||||
- name: PRIVATE_KEY_PATH
|
||||
value: /vol/secret/private-key
|
||||
- name: DERP_MAP_PATH
|
||||
value: /vol/config/derp.yaml
|
||||
- name: EPHEMERAL_NODE_INACTIVITY_TIMEOUT
|
||||
valueFrom:
|
||||
configMapKeyRef:
|
||||
name: headscale-config
|
||||
key: ephemeral_node_inactivity_timeout
|
||||
- name: DB_TYPE
|
||||
value: postgres
|
||||
- name: DB_HOST
|
||||
value: postgres.headscale.svc.cluster.local
|
||||
- name: DB_PORT
|
||||
value: "5432"
|
||||
- name: DB_USER
|
||||
value: headscale
|
||||
- name: DB_PASS
|
||||
valueFrom:
|
||||
secretKeyRef:
|
||||
name: postgresql
|
||||
key: password
|
||||
- name: DB_NAME
|
||||
value: headscale
|
||||
ports:
|
||||
- name: http
|
||||
protocol: TCP
|
||||
containerPort: 8080
|
||||
livenessProbe:
|
||||
tcpSocket:
|
||||
port: http
|
||||
initialDelaySeconds: 30
|
||||
timeoutSeconds: 5
|
||||
periodSeconds: 15
|
||||
volumeMounts:
|
||||
- name: config
|
||||
mountPath: /vol/config
|
||||
- name: secret
|
||||
mountPath: /vol/secret
|
||||
- name: etc
|
||||
mountPath: /etc/headscale
|
||||
volumes:
|
||||
- name: config
|
||||
configMap:
|
||||
name: headscale-site
|
||||
- name: etc
|
||||
configMap:
|
||||
name: headscale-etc
|
||||
- name: secret
|
||||
secret:
|
||||
secretName: headscale
|
||||
- name: config
|
||||
configMap:
|
||||
name: headscale-site
|
||||
- name: etc
|
||||
configMap:
|
||||
name: headscale-etc
|
||||
- name: secret
|
||||
secret:
|
||||
secretName: headscale
|
||||
|
|
|
@ -1,13 +1,13 @@
|
|||
namespace: headscale
|
||||
bases:
|
||||
- ../base
|
||||
- ../base
|
||||
resources:
|
||||
- deployment.yaml
|
||||
- postgres-service.yaml
|
||||
- postgres-statefulset.yaml
|
||||
- deployment.yaml
|
||||
- postgres-service.yaml
|
||||
- postgres-statefulset.yaml
|
||||
generatorOptions:
|
||||
disableNameSuffixHash: true
|
||||
secretGenerator:
|
||||
- name: postgresql
|
||||
files:
|
||||
- secrets/password
|
||||
- name: postgresql
|
||||
files:
|
||||
- secrets/password
|
||||
|
|
|
@ -8,6 +8,6 @@ spec:
|
|||
selector:
|
||||
app: postgres
|
||||
ports:
|
||||
- name: postgres
|
||||
targetPort: postgres
|
||||
port: 5432
|
||||
- name: postgres
|
||||
targetPort: postgres
|
||||
port: 5432
|
||||
|
|
|
@ -14,36 +14,36 @@ spec:
|
|||
app: postgres
|
||||
spec:
|
||||
containers:
|
||||
- name: postgres
|
||||
image: "postgres:13"
|
||||
imagePullPolicy: IfNotPresent
|
||||
env:
|
||||
- name: POSTGRES_PASSWORD
|
||||
valueFrom:
|
||||
secretKeyRef:
|
||||
name: postgresql
|
||||
key: password
|
||||
- name: POSTGRES_USER
|
||||
value: headscale
|
||||
ports:
|
||||
- name: postgres
|
||||
protocol: TCP
|
||||
containerPort: 5432
|
||||
livenessProbe:
|
||||
tcpSocket:
|
||||
port: 5432
|
||||
initialDelaySeconds: 30
|
||||
timeoutSeconds: 5
|
||||
periodSeconds: 15
|
||||
volumeMounts:
|
||||
- name: pgdata
|
||||
mountPath: /var/lib/postgresql/data
|
||||
image: "postgres:13"
|
||||
imagePullPolicy: IfNotPresent
|
||||
env:
|
||||
- name: POSTGRES_PASSWORD
|
||||
valueFrom:
|
||||
secretKeyRef:
|
||||
name: postgresql
|
||||
key: password
|
||||
- name: POSTGRES_USER
|
||||
value: headscale
|
||||
ports:
|
||||
- name: postgres
|
||||
protocol: TCP
|
||||
containerPort: 5432
|
||||
livenessProbe:
|
||||
tcpSocket:
|
||||
port: 5432
|
||||
initialDelaySeconds: 30
|
||||
timeoutSeconds: 5
|
||||
periodSeconds: 15
|
||||
volumeMounts:
|
||||
- name: pgdata
|
||||
mountPath: /var/lib/postgresql/data
|
||||
volumeClaimTemplates:
|
||||
- metadata:
|
||||
name: pgdata
|
||||
spec:
|
||||
storageClassName: local-path
|
||||
accessModes: ["ReadWriteOnce"]
|
||||
resources:
|
||||
requests:
|
||||
storage: 1Gi
|
||||
- metadata:
|
||||
name: pgdata
|
||||
spec:
|
||||
storageClassName: local-path
|
||||
accessModes: ["ReadWriteOnce"]
|
||||
resources:
|
||||
requests:
|
||||
storage: 1Gi
|
||||
|
|
|
@ -6,6 +6,6 @@ metadata:
|
|||
traefik.ingress.kubernetes.io/router.tls: "true"
|
||||
spec:
|
||||
tls:
|
||||
- hosts:
|
||||
- $(PUBLIC_HOSTNAME)
|
||||
secretName: production-cert
|
||||
- hosts:
|
||||
- $(PUBLIC_HOSTNAME)
|
||||
secretName: production-cert
|
||||
|
|
|
@ -1,9 +1,9 @@
|
|||
namespace: headscale
|
||||
bases:
|
||||
- ../base
|
||||
- ../base
|
||||
resources:
|
||||
- production-issuer.yaml
|
||||
- production-issuer.yaml
|
||||
patches:
|
||||
- path: ingress-patch.yaml
|
||||
target:
|
||||
kind: Ingress
|
||||
- path: ingress-patch.yaml
|
||||
target:
|
||||
kind: Ingress
|
||||
|
|
|
@ -11,6 +11,6 @@ spec:
|
|||
# Secret resource used to store the account's private key.
|
||||
name: letsencrypt-production-acc-key
|
||||
solvers:
|
||||
- http01:
|
||||
ingress:
|
||||
class: traefik
|
||||
- http01:
|
||||
ingress:
|
||||
class: traefik
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
namespace: headscale
|
||||
bases:
|
||||
- ../base
|
||||
- ../base
|
||||
resources:
|
||||
- statefulset.yaml
|
||||
- statefulset.yaml
|
||||
|
|
|
@ -14,66 +14,66 @@ spec:
|
|||
app: headscale
|
||||
spec:
|
||||
containers:
|
||||
- name: headscale
|
||||
image: "headscale:latest"
|
||||
imagePullPolicy: IfNotPresent
|
||||
command: ["/go/bin/headscale", "serve"]
|
||||
env:
|
||||
- name: SERVER_URL
|
||||
value: $(PUBLIC_PROTO)://$(PUBLIC_HOSTNAME)
|
||||
- name: LISTEN_ADDR
|
||||
valueFrom:
|
||||
configMapKeyRef:
|
||||
name: headscale-config
|
||||
key: listen_addr
|
||||
- name: PRIVATE_KEY_PATH
|
||||
value: /vol/secret/private-key
|
||||
- name: DERP_MAP_PATH
|
||||
value: /vol/config/derp.yaml
|
||||
- name: EPHEMERAL_NODE_INACTIVITY_TIMEOUT
|
||||
valueFrom:
|
||||
configMapKeyRef:
|
||||
name: headscale-config
|
||||
key: ephemeral_node_inactivity_timeout
|
||||
- name: DB_TYPE
|
||||
value: sqlite3
|
||||
- name: DB_PATH
|
||||
value: /vol/data/db.sqlite
|
||||
ports:
|
||||
- name: http
|
||||
protocol: TCP
|
||||
containerPort: 8080
|
||||
livenessProbe:
|
||||
tcpSocket:
|
||||
port: http
|
||||
initialDelaySeconds: 30
|
||||
timeoutSeconds: 5
|
||||
periodSeconds: 15
|
||||
volumeMounts:
|
||||
- name: config
|
||||
mountPath: /vol/config
|
||||
- name: data
|
||||
mountPath: /vol/data
|
||||
- name: secret
|
||||
mountPath: /vol/secret
|
||||
- name: etc
|
||||
mountPath: /etc/headscale
|
||||
- name: headscale
|
||||
image: "headscale:latest"
|
||||
imagePullPolicy: IfNotPresent
|
||||
command: ["/go/bin/headscale", "serve"]
|
||||
env:
|
||||
- name: SERVER_URL
|
||||
value: $(PUBLIC_PROTO)://$(PUBLIC_HOSTNAME)
|
||||
- name: LISTEN_ADDR
|
||||
valueFrom:
|
||||
configMapKeyRef:
|
||||
name: headscale-config
|
||||
key: listen_addr
|
||||
- name: PRIVATE_KEY_PATH
|
||||
value: /vol/secret/private-key
|
||||
- name: DERP_MAP_PATH
|
||||
value: /vol/config/derp.yaml
|
||||
- name: EPHEMERAL_NODE_INACTIVITY_TIMEOUT
|
||||
valueFrom:
|
||||
configMapKeyRef:
|
||||
name: headscale-config
|
||||
key: ephemeral_node_inactivity_timeout
|
||||
- name: DB_TYPE
|
||||
value: sqlite3
|
||||
- name: DB_PATH
|
||||
value: /vol/data/db.sqlite
|
||||
ports:
|
||||
- name: http
|
||||
protocol: TCP
|
||||
containerPort: 8080
|
||||
livenessProbe:
|
||||
tcpSocket:
|
||||
port: http
|
||||
initialDelaySeconds: 30
|
||||
timeoutSeconds: 5
|
||||
periodSeconds: 15
|
||||
volumeMounts:
|
||||
- name: config
|
||||
mountPath: /vol/config
|
||||
- name: data
|
||||
mountPath: /vol/data
|
||||
- name: secret
|
||||
mountPath: /vol/secret
|
||||
- name: etc
|
||||
mountPath: /etc/headscale
|
||||
volumes:
|
||||
- name: config
|
||||
configMap:
|
||||
name: headscale-site
|
||||
- name: etc
|
||||
configMap:
|
||||
name: headscale-etc
|
||||
- name: secret
|
||||
secret:
|
||||
secretName: headscale
|
||||
- name: config
|
||||
configMap:
|
||||
name: headscale-site
|
||||
- name: etc
|
||||
configMap:
|
||||
name: headscale-etc
|
||||
- name: secret
|
||||
secret:
|
||||
secretName: headscale
|
||||
volumeClaimTemplates:
|
||||
- metadata:
|
||||
name: data
|
||||
spec:
|
||||
storageClassName: local-path
|
||||
accessModes: ["ReadWriteOnce"]
|
||||
resources:
|
||||
requests:
|
||||
storage: 1Gi
|
||||
- metadata:
|
||||
name: data
|
||||
spec:
|
||||
storageClassName: local-path
|
||||
accessModes: ["ReadWriteOnce"]
|
||||
resources:
|
||||
requests:
|
||||
storage: 1Gi
|
||||
|
|
|
@ -6,6 +6,6 @@ metadata:
|
|||
traefik.ingress.kubernetes.io/router.tls: "true"
|
||||
spec:
|
||||
tls:
|
||||
- hosts:
|
||||
- $(PUBLIC_HOSTNAME)
|
||||
secretName: staging-cert
|
||||
- hosts:
|
||||
- $(PUBLIC_HOSTNAME)
|
||||
secretName: staging-cert
|
||||
|
|
|
@ -1,9 +1,9 @@
|
|||
namespace: headscale
|
||||
bases:
|
||||
- ../base
|
||||
- ../base
|
||||
resources:
|
||||
- staging-issuer.yaml
|
||||
- staging-issuer.yaml
|
||||
patches:
|
||||
- path: ingress-patch.yaml
|
||||
target:
|
||||
kind: Ingress
|
||||
- path: ingress-patch.yaml
|
||||
target:
|
||||
kind: Ingress
|
||||
|
|
|
@ -11,6 +11,6 @@ spec:
|
|||
# Secret resource used to store the account's private key.
|
||||
name: letsencrypt-staging-acc-key
|
||||
solvers:
|
||||
- http01:
|
||||
ingress:
|
||||
class: traefik
|
||||
- http01:
|
||||
ingress:
|
||||
class: traefik
|
||||
|
|
402
machine.go
402
machine.go
|
@ -10,10 +10,9 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/fatih/set"
|
||||
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
|
||||
"github.com/rs/zerolog/log"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
|
||||
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
|
||||
"gorm.io/datatypes"
|
||||
"gorm.io/gorm"
|
||||
"inet.af/netaddr"
|
||||
|
@ -21,7 +20,13 @@ import (
|
|||
"tailscale.com/types/wgkey"
|
||||
)
|
||||
|
||||
// Machine is a Headscale client
|
||||
const (
|
||||
errMachineNotFound = Error("machine not found")
|
||||
errMachineAlreadyRegistered = Error("machine already registered")
|
||||
errMachineRouteIsNotAvailable = Error("route is not available on machine")
|
||||
)
|
||||
|
||||
// Machine is a Headscale client.
|
||||
type Machine struct {
|
||||
ID uint64 `gorm:"primary_key"`
|
||||
MachineKey string `gorm:"type:varchar(64);unique_index"`
|
||||
|
@ -56,53 +61,58 @@ type (
|
|||
MachinesP []*Machine
|
||||
)
|
||||
|
||||
// For the time being this method is rather naive
|
||||
func (m Machine) isAlreadyRegistered() bool {
|
||||
return m.Registered
|
||||
// For the time being this method is rather naive.
|
||||
func (machine Machine) isAlreadyRegistered() bool {
|
||||
return machine.Registered
|
||||
}
|
||||
|
||||
// isExpired returns whether the machine registration has expired
|
||||
func (m Machine) isExpired() bool {
|
||||
return time.Now().UTC().After(*m.Expiry)
|
||||
// isExpired returns whether the machine registration has expired.
|
||||
func (machine Machine) isExpired() bool {
|
||||
return time.Now().UTC().After(*machine.Expiry)
|
||||
}
|
||||
|
||||
// If the Machine is expired, updateMachineExpiry updates the Machine Expiry time to the maximum allowed duration,
|
||||
// or the default duration if no Expiry time was requested by the client. The expiry time here does not (yet) cause
|
||||
// a client to be disconnected, however they will have to re-auth the machine if they attempt to reconnect after the
|
||||
// expiry time.
|
||||
func (h *Headscale) updateMachineExpiry(m *Machine) {
|
||||
if m.isExpired() {
|
||||
func (h *Headscale) updateMachineExpiry(machine *Machine) {
|
||||
if machine.isExpired() {
|
||||
now := time.Now().UTC()
|
||||
maxExpiry := now.Add(h.cfg.MaxMachineRegistrationDuration) // calculate the maximum expiry
|
||||
defaultExpiry := now.Add(h.cfg.DefaultMachineRegistrationDuration) // calculate the default expiry
|
||||
maxExpiry := now.Add(
|
||||
h.cfg.MaxMachineRegistrationDuration,
|
||||
) // calculate the maximum expiry
|
||||
defaultExpiry := now.Add(
|
||||
h.cfg.DefaultMachineRegistrationDuration,
|
||||
) // calculate the default expiry
|
||||
|
||||
// clamp the expiry time of the machine registration to the maximum allowed, or use the default if none supplied
|
||||
if maxExpiry.Before(*m.RequestedExpiry) {
|
||||
if maxExpiry.Before(*machine.RequestedExpiry) {
|
||||
log.Debug().
|
||||
Msgf("Clamping registration expiry time to maximum: %v (%v)", maxExpiry, h.cfg.MaxMachineRegistrationDuration)
|
||||
m.Expiry = &maxExpiry
|
||||
} else if m.RequestedExpiry.IsZero() {
|
||||
machine.Expiry = &maxExpiry
|
||||
} else if machine.RequestedExpiry.IsZero() {
|
||||
log.Debug().Msgf("Using default machine registration expiry time: %v (%v)", defaultExpiry, h.cfg.DefaultMachineRegistrationDuration)
|
||||
m.Expiry = &defaultExpiry
|
||||
machine.Expiry = &defaultExpiry
|
||||
} else {
|
||||
log.Debug().Msgf("Using requested machine registration expiry time: %v", m.RequestedExpiry)
|
||||
m.Expiry = m.RequestedExpiry
|
||||
log.Debug().Msgf("Using requested machine registration expiry time: %v", machine.RequestedExpiry)
|
||||
machine.Expiry = machine.RequestedExpiry
|
||||
}
|
||||
|
||||
h.db.Save(&m)
|
||||
h.db.Save(&machine)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Headscale) getDirectPeers(m *Machine) (Machines, error) {
|
||||
func (h *Headscale) getDirectPeers(machine *Machine) (Machines, error) {
|
||||
log.Trace().
|
||||
Caller().
|
||||
Str("machine", m.Name).
|
||||
Str("machine", machine.Name).
|
||||
Msg("Finding direct peers")
|
||||
|
||||
machines := Machines{}
|
||||
if err := h.db.Preload("Namespace").Where("namespace_id = ? AND machine_key <> ? AND registered",
|
||||
m.NamespaceID, m.MachineKey).Find(&machines).Error; err != nil {
|
||||
machine.NamespaceID, machine.MachineKey).Find(&machines).Error; err != nil {
|
||||
log.Error().Err(err).Msg("Error accessing db")
|
||||
|
||||
return Machines{}, err
|
||||
}
|
||||
|
||||
|
@ -110,21 +120,22 @@ func (h *Headscale) getDirectPeers(m *Machine) (Machines, error) {
|
|||
|
||||
log.Trace().
|
||||
Caller().
|
||||
Str("machine", m.Name).
|
||||
Str("machine", machine.Name).
|
||||
Msgf("Found direct machines: %s", machines.String())
|
||||
|
||||
return machines, nil
|
||||
}
|
||||
|
||||
// getShared fetches machines that are shared to the `Namespace` of the machine we are getting peers for
|
||||
func (h *Headscale) getShared(m *Machine) (Machines, error) {
|
||||
// getShared fetches machines that are shared to the `Namespace` of the machine we are getting peers for.
|
||||
func (h *Headscale) getShared(machine *Machine) (Machines, error) {
|
||||
log.Trace().
|
||||
Caller().
|
||||
Str("machine", m.Name).
|
||||
Str("machine", machine.Name).
|
||||
Msg("Finding shared peers")
|
||||
|
||||
sharedMachines := []SharedMachine{}
|
||||
if err := h.db.Preload("Namespace").Preload("Machine").Preload("Machine.Namespace").Where("namespace_id = ?",
|
||||
m.NamespaceID).Find(&sharedMachines).Error; err != nil {
|
||||
machine.NamespaceID).Find(&sharedMachines).Error; err != nil {
|
||||
return Machines{}, err
|
||||
}
|
||||
|
||||
|
@ -137,27 +148,30 @@ func (h *Headscale) getShared(m *Machine) (Machines, error) {
|
|||
|
||||
log.Trace().
|
||||
Caller().
|
||||
Str("machine", m.Name).
|
||||
Str("machine", machine.Name).
|
||||
Msgf("Found shared peers: %s", peers.String())
|
||||
|
||||
return peers, nil
|
||||
}
|
||||
|
||||
// getSharedTo fetches the machines of the namespaces this machine is shared in
|
||||
func (h *Headscale) getSharedTo(m *Machine) (Machines, error) {
|
||||
// getSharedTo fetches the machines of the namespaces this machine is shared in.
|
||||
func (h *Headscale) getSharedTo(machine *Machine) (Machines, error) {
|
||||
log.Trace().
|
||||
Caller().
|
||||
Str("machine", m.Name).
|
||||
Str("machine", machine.Name).
|
||||
Msg("Finding peers in namespaces this machine is shared with")
|
||||
|
||||
sharedMachines := []SharedMachine{}
|
||||
if err := h.db.Preload("Namespace").Preload("Machine").Preload("Machine.Namespace").Where("machine_id = ?",
|
||||
m.ID).Find(&sharedMachines).Error; err != nil {
|
||||
machine.ID).Find(&sharedMachines).Error; err != nil {
|
||||
return Machines{}, err
|
||||
}
|
||||
|
||||
peers := make(Machines, 0)
|
||||
for _, sharedMachine := range sharedMachines {
|
||||
namespaceMachines, err := h.ListMachinesInNamespace(sharedMachine.Namespace.Name)
|
||||
namespaceMachines, err := h.ListMachinesInNamespace(
|
||||
sharedMachine.Namespace.Name,
|
||||
)
|
||||
if err != nil {
|
||||
return Machines{}, err
|
||||
}
|
||||
|
@ -168,36 +182,40 @@ func (h *Headscale) getSharedTo(m *Machine) (Machines, error) {
|
|||
|
||||
log.Trace().
|
||||
Caller().
|
||||
Str("machine", m.Name).
|
||||
Str("machine", machine.Name).
|
||||
Msgf("Found peers we are shared with: %s", peers.String())
|
||||
|
||||
return peers, nil
|
||||
}
|
||||
|
||||
func (h *Headscale) getPeers(m *Machine) (Machines, error) {
|
||||
direct, err := h.getDirectPeers(m)
|
||||
func (h *Headscale) getPeers(machine *Machine) (Machines, error) {
|
||||
direct, err := h.getDirectPeers(machine)
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Caller().
|
||||
Err(err).
|
||||
Msg("Cannot fetch peers")
|
||||
|
||||
return Machines{}, err
|
||||
}
|
||||
|
||||
shared, err := h.getShared(m)
|
||||
shared, err := h.getShared(machine)
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Caller().
|
||||
Err(err).
|
||||
Msg("Cannot fetch peers")
|
||||
|
||||
return Machines{}, err
|
||||
}
|
||||
|
||||
sharedTo, err := h.getSharedTo(m)
|
||||
sharedTo, err := h.getSharedTo(machine)
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Caller().
|
||||
Err(err).
|
||||
Msg("Cannot fetch peers")
|
||||
|
||||
return Machines{}, err
|
||||
}
|
||||
|
||||
|
@ -208,7 +226,7 @@ func (h *Headscale) getPeers(m *Machine) (Machines, error) {
|
|||
|
||||
log.Trace().
|
||||
Caller().
|
||||
Str("machine", m.Name).
|
||||
Str("machine", machine.Name).
|
||||
Msgf("Found total peers: %s", peers.String())
|
||||
|
||||
return peers, nil
|
||||
|
@ -219,10 +237,11 @@ func (h *Headscale) ListMachines() ([]Machine, error) {
|
|||
if err := h.db.Preload("AuthKey").Preload("AuthKey.Namespace").Preload("Namespace").Find(&machines).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return machines, nil
|
||||
}
|
||||
|
||||
// GetMachine finds a Machine by name and namespace and returns the Machine struct
|
||||
// GetMachine finds a Machine by name and namespace and returns the Machine struct.
|
||||
func (h *Headscale) GetMachine(namespace string, name string) (*Machine, error) {
|
||||
machines, err := h.ListMachinesInNamespace(namespace)
|
||||
if err != nil {
|
||||
|
@ -234,73 +253,77 @@ func (h *Headscale) GetMachine(namespace string, name string) (*Machine, error)
|
|||
return &m, nil
|
||||
}
|
||||
}
|
||||
return nil, fmt.Errorf("machine not found")
|
||||
|
||||
return nil, errMachineNotFound
|
||||
}
|
||||
|
||||
// GetMachineByID finds a Machine by ID and returns the Machine struct
|
||||
// GetMachineByID finds a Machine by ID and returns the Machine struct.
|
||||
func (h *Headscale) GetMachineByID(id uint64) (*Machine, error) {
|
||||
m := Machine{}
|
||||
if result := h.db.Preload("Namespace").Find(&Machine{ID: id}).First(&m); result.Error != nil {
|
||||
return nil, result.Error
|
||||
}
|
||||
|
||||
return &m, nil
|
||||
}
|
||||
|
||||
// GetMachineByMachineKey finds a Machine by ID and returns the Machine struct
|
||||
func (h *Headscale) GetMachineByMachineKey(mKey string) (*Machine, error) {
|
||||
// GetMachineByMachineKey finds a Machine by ID and returns the Machine struct.
|
||||
func (h *Headscale) GetMachineByMachineKey(machineKey string) (*Machine, error) {
|
||||
m := Machine{}
|
||||
if result := h.db.Preload("Namespace").First(&m, "machine_key = ?", mKey); result.Error != nil {
|
||||
if result := h.db.Preload("Namespace").First(&m, "machine_key = ?", machineKey); result.Error != nil {
|
||||
return nil, result.Error
|
||||
}
|
||||
|
||||
return &m, nil
|
||||
}
|
||||
|
||||
// UpdateMachine takes a Machine struct pointer (typically already loaded from database
|
||||
// and updates it with the latest data from the database.
|
||||
func (h *Headscale) UpdateMachine(m *Machine) error {
|
||||
if result := h.db.Find(m).First(&m); result.Error != nil {
|
||||
func (h *Headscale) UpdateMachine(machine *Machine) error {
|
||||
if result := h.db.Find(machine).First(&machine); result.Error != nil {
|
||||
return result.Error
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteMachine softs deletes a Machine from the database
|
||||
func (h *Headscale) DeleteMachine(m *Machine) error {
|
||||
err := h.RemoveSharedMachineFromAllNamespaces(m)
|
||||
if err != nil && err != errorMachineNotShared {
|
||||
// DeleteMachine softs deletes a Machine from the database.
|
||||
func (h *Headscale) DeleteMachine(machine *Machine) error {
|
||||
err := h.RemoveSharedMachineFromAllNamespaces(machine)
|
||||
if err != nil && errors.Is(err, errMachineNotShared) {
|
||||
return err
|
||||
}
|
||||
|
||||
m.Registered = false
|
||||
namespaceID := m.NamespaceID
|
||||
h.db.Save(&m) // we mark it as unregistered, just in case
|
||||
if err := h.db.Delete(&m).Error; err != nil {
|
||||
machine.Registered = false
|
||||
namespaceID := machine.NamespaceID
|
||||
h.db.Save(&machine) // we mark it as unregistered, just in case
|
||||
if err := h.db.Delete(&machine).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return h.RequestMapUpdates(namespaceID)
|
||||
}
|
||||
|
||||
// HardDeleteMachine hard deletes a Machine from the database
|
||||
func (h *Headscale) HardDeleteMachine(m *Machine) error {
|
||||
err := h.RemoveSharedMachineFromAllNamespaces(m)
|
||||
if err != nil && err != errorMachineNotShared {
|
||||
// HardDeleteMachine hard deletes a Machine from the database.
|
||||
func (h *Headscale) HardDeleteMachine(machine *Machine) error {
|
||||
err := h.RemoveSharedMachineFromAllNamespaces(machine)
|
||||
if err != nil && errors.Is(err, errMachineNotShared) {
|
||||
return err
|
||||
}
|
||||
|
||||
namespaceID := m.NamespaceID
|
||||
if err := h.db.Unscoped().Delete(&m).Error; err != nil {
|
||||
namespaceID := machine.NamespaceID
|
||||
if err := h.db.Unscoped().Delete(&machine).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return h.RequestMapUpdates(namespaceID)
|
||||
}
|
||||
|
||||
// GetHostInfo returns a Hostinfo struct for the machine
|
||||
func (m *Machine) GetHostInfo() (*tailcfg.Hostinfo, error) {
|
||||
// GetHostInfo returns a Hostinfo struct for the machine.
|
||||
func (machine *Machine) GetHostInfo() (*tailcfg.Hostinfo, error) {
|
||||
hostinfo := tailcfg.Hostinfo{}
|
||||
if len(m.HostInfo) != 0 {
|
||||
hi, err := m.HostInfo.MarshalJSON()
|
||||
if len(machine.HostInfo) != 0 {
|
||||
hi, err := machine.HostInfo.MarshalJSON()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -309,21 +332,21 @@ func (m *Machine) GetHostInfo() (*tailcfg.Hostinfo, error) {
|
|||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return &hostinfo, nil
|
||||
}
|
||||
|
||||
func (h *Headscale) isOutdated(m *Machine) bool {
|
||||
err := h.UpdateMachine(m)
|
||||
if err != nil {
|
||||
func (h *Headscale) isOutdated(machine *Machine) bool {
|
||||
if err := h.UpdateMachine(machine); err != nil {
|
||||
// It does not seem meaningful to propagate this error as the end result
|
||||
// will have to be that the machine has to be considered outdated.
|
||||
return true
|
||||
}
|
||||
|
||||
sharedMachines, _ := h.getShared(m)
|
||||
sharedMachines, _ := h.getShared(machine)
|
||||
|
||||
namespaceSet := set.New(set.ThreadSafe)
|
||||
namespaceSet.Add(m.Namespace.Name)
|
||||
namespaceSet.Add(machine.Namespace.Name)
|
||||
|
||||
// Check if any of our shared namespaces has updates that we have
|
||||
// not propagated.
|
||||
|
@ -333,27 +356,30 @@ func (h *Headscale) isOutdated(m *Machine) bool {
|
|||
|
||||
namespaces := make([]string, namespaceSet.Size())
|
||||
for index, namespace := range namespaceSet.List() {
|
||||
namespaces[index] = namespace.(string)
|
||||
if name, ok := namespace.(string); ok {
|
||||
namespaces[index] = name
|
||||
}
|
||||
}
|
||||
|
||||
lastChange := h.getLastStateChange(namespaces...)
|
||||
log.Trace().
|
||||
Caller().
|
||||
Str("machine", m.Name).
|
||||
Time("last_successful_update", *m.LastSuccessfulUpdate).
|
||||
Str("machine", machine.Name).
|
||||
Time("last_successful_update", *machine.LastSuccessfulUpdate).
|
||||
Time("last_state_change", lastChange).
|
||||
Msgf("Checking if %s is missing updates", m.Name)
|
||||
return m.LastSuccessfulUpdate.Before(lastChange)
|
||||
Msgf("Checking if %s is missing updates", machine.Name)
|
||||
|
||||
return machine.LastSuccessfulUpdate.Before(lastChange)
|
||||
}
|
||||
|
||||
func (m Machine) String() string {
|
||||
return m.Name
|
||||
func (machine Machine) String() string {
|
||||
return machine.Name
|
||||
}
|
||||
|
||||
func (ms Machines) String() string {
|
||||
temp := make([]string, len(ms))
|
||||
func (machines Machines) String() string {
|
||||
temp := make([]string, len(machines))
|
||||
|
||||
for index, machine := range ms {
|
||||
for index, machine := range machines {
|
||||
temp[index] = machine.Name
|
||||
}
|
||||
|
||||
|
@ -361,24 +387,24 @@ func (ms Machines) String() string {
|
|||
}
|
||||
|
||||
// TODO(kradalby): Remove when we have generics...
|
||||
func (ms MachinesP) String() string {
|
||||
temp := make([]string, len(ms))
|
||||
func (machines MachinesP) String() string {
|
||||
temp := make([]string, len(machines))
|
||||
|
||||
for index, machine := range ms {
|
||||
for index, machine := range machines {
|
||||
temp[index] = machine.Name
|
||||
}
|
||||
|
||||
return fmt.Sprintf("[ %s ](%d)", strings.Join(temp, ", "), len(temp))
|
||||
}
|
||||
|
||||
func (ms Machines) toNodes(
|
||||
func (machines Machines) toNodes(
|
||||
baseDomain string,
|
||||
dnsConfig *tailcfg.DNSConfig,
|
||||
includeRoutes bool,
|
||||
) ([]*tailcfg.Node, error) {
|
||||
nodes := make([]*tailcfg.Node, len(ms))
|
||||
nodes := make([]*tailcfg.Node, len(machines))
|
||||
|
||||
for index, machine := range ms {
|
||||
for index, machine := range machines {
|
||||
node, err := machine.toNode(baseDomain, dnsConfig, includeRoutes)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -391,20 +417,25 @@ func (ms Machines) toNodes(
|
|||
}
|
||||
|
||||
// toNode converts a Machine into a Tailscale Node. includeRoutes is false for shared nodes
|
||||
// as per the expected behaviour in the official SaaS
|
||||
func (m Machine) toNode(baseDomain string, dnsConfig *tailcfg.DNSConfig, includeRoutes bool) (*tailcfg.Node, error) {
|
||||
nKey, err := wgkey.ParseHex(m.NodeKey)
|
||||
// as per the expected behaviour in the official SaaS.
|
||||
func (machine Machine) toNode(
|
||||
baseDomain string,
|
||||
dnsConfig *tailcfg.DNSConfig,
|
||||
includeRoutes bool,
|
||||
) (*tailcfg.Node, error) {
|
||||
nodeKey, err := wgkey.ParseHex(machine.NodeKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
mKey, err := wgkey.ParseHex(m.MachineKey)
|
||||
|
||||
machineKey, err := wgkey.ParseHex(machine.MachineKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var discoKey tailcfg.DiscoKey
|
||||
if m.DiscoKey != "" {
|
||||
dKey, err := wgkey.ParseHex(m.DiscoKey)
|
||||
if machine.DiscoKey != "" {
|
||||
dKey, err := wgkey.ParseHex(machine.DiscoKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -414,23 +445,27 @@ func (m Machine) toNode(baseDomain string, dnsConfig *tailcfg.DNSConfig, include
|
|||
}
|
||||
|
||||
addrs := []netaddr.IPPrefix{}
|
||||
ip, err := netaddr.ParseIPPrefix(fmt.Sprintf("%s/32", m.IPAddress))
|
||||
ip, err := netaddr.ParseIPPrefix(fmt.Sprintf("%s/32", machine.IPAddress))
|
||||
if err != nil {
|
||||
log.Trace().
|
||||
Caller().
|
||||
Str("ip", m.IPAddress).
|
||||
Msgf("Failed to parse IP Prefix from IP: %s", m.IPAddress)
|
||||
Str("ip", machine.IPAddress).
|
||||
Msgf("Failed to parse IP Prefix from IP: %s", machine.IPAddress)
|
||||
|
||||
return nil, err
|
||||
}
|
||||
addrs = append(addrs, ip) // missing the ipv6 ?
|
||||
|
||||
allowedIPs := []netaddr.IPPrefix{}
|
||||
allowedIPs = append(allowedIPs, ip) // we append the node own IP, as it is required by the clients
|
||||
allowedIPs = append(
|
||||
allowedIPs,
|
||||
ip,
|
||||
) // we append the node own IP, as it is required by the clients
|
||||
|
||||
if includeRoutes {
|
||||
routesStr := []string{}
|
||||
if len(m.EnabledRoutes) != 0 {
|
||||
allwIps, err := m.EnabledRoutes.MarshalJSON()
|
||||
if len(machine.EnabledRoutes) != 0 {
|
||||
allwIps, err := machine.EnabledRoutes.MarshalJSON()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -450,8 +485,8 @@ func (m Machine) toNode(baseDomain string, dnsConfig *tailcfg.DNSConfig, include
|
|||
}
|
||||
|
||||
endpoints := []string{}
|
||||
if len(m.Endpoints) != 0 {
|
||||
be, err := m.Endpoints.MarshalJSON()
|
||||
if len(machine.Endpoints) != 0 {
|
||||
be, err := machine.Endpoints.MarshalJSON()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -462,8 +497,8 @@ func (m Machine) toNode(baseDomain string, dnsConfig *tailcfg.DNSConfig, include
|
|||
}
|
||||
|
||||
hostinfo := tailcfg.Hostinfo{}
|
||||
if len(m.HostInfo) != 0 {
|
||||
hi, err := m.HostInfo.MarshalJSON()
|
||||
if len(machine.HostInfo) != 0 {
|
||||
hi, err := machine.HostInfo.MarshalJSON()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -481,29 +516,34 @@ func (m Machine) toNode(baseDomain string, dnsConfig *tailcfg.DNSConfig, include
|
|||
}
|
||||
|
||||
var keyExpiry time.Time
|
||||
if m.Expiry != nil {
|
||||
keyExpiry = *m.Expiry
|
||||
if machine.Expiry != nil {
|
||||
keyExpiry = *machine.Expiry
|
||||
} else {
|
||||
keyExpiry = time.Time{}
|
||||
}
|
||||
|
||||
var hostname string
|
||||
if dnsConfig != nil && dnsConfig.Proxied { // MagicDNS
|
||||
hostname = fmt.Sprintf("%s.%s.%s", m.Name, m.Namespace.Name, baseDomain)
|
||||
hostname = fmt.Sprintf(
|
||||
"%s.%s.%s",
|
||||
machine.Name,
|
||||
machine.Namespace.Name,
|
||||
baseDomain,
|
||||
)
|
||||
} else {
|
||||
hostname = m.Name
|
||||
hostname = machine.Name
|
||||
}
|
||||
|
||||
n := tailcfg.Node{
|
||||
ID: tailcfg.NodeID(m.ID), // this is the actual ID
|
||||
node := tailcfg.Node{
|
||||
ID: tailcfg.NodeID(machine.ID), // this is the actual ID
|
||||
StableID: tailcfg.StableNodeID(
|
||||
strconv.FormatUint(m.ID, 10),
|
||||
strconv.FormatUint(machine.ID, Base10),
|
||||
), // in headscale, unlike tailcontrol server, IDs are permanent
|
||||
Name: hostname,
|
||||
User: tailcfg.UserID(m.NamespaceID),
|
||||
Key: tailcfg.NodeKey(nKey),
|
||||
User: tailcfg.UserID(machine.NamespaceID),
|
||||
Key: tailcfg.NodeKey(nodeKey),
|
||||
KeyExpiry: keyExpiry,
|
||||
Machine: tailcfg.MachineKey(mKey),
|
||||
Machine: tailcfg.MachineKey(machineKey),
|
||||
DiscoKey: discoKey,
|
||||
Addresses: addrs,
|
||||
AllowedIPs: allowedIPs,
|
||||
|
@ -511,81 +551,90 @@ func (m Machine) toNode(baseDomain string, dnsConfig *tailcfg.DNSConfig, include
|
|||
DERP: derp,
|
||||
|
||||
Hostinfo: hostinfo,
|
||||
Created: m.CreatedAt,
|
||||
LastSeen: m.LastSeen,
|
||||
Created: machine.CreatedAt,
|
||||
LastSeen: machine.LastSeen,
|
||||
|
||||
KeepAlive: true,
|
||||
MachineAuthorized: m.Registered,
|
||||
MachineAuthorized: machine.Registered,
|
||||
Capabilities: []string{tailcfg.CapabilityFileSharing},
|
||||
}
|
||||
return &n, nil
|
||||
|
||||
return &node, nil
|
||||
}
|
||||
|
||||
func (m *Machine) toProto() *v1.Machine {
|
||||
machine := &v1.Machine{
|
||||
Id: m.ID,
|
||||
MachineKey: m.MachineKey,
|
||||
func (machine *Machine) toProto() *v1.Machine {
|
||||
machineProto := &v1.Machine{
|
||||
Id: machine.ID,
|
||||
MachineKey: machine.MachineKey,
|
||||
|
||||
NodeKey: m.NodeKey,
|
||||
DiscoKey: m.DiscoKey,
|
||||
IpAddress: m.IPAddress,
|
||||
Name: m.Name,
|
||||
Namespace: m.Namespace.toProto(),
|
||||
NodeKey: machine.NodeKey,
|
||||
DiscoKey: machine.DiscoKey,
|
||||
IpAddress: machine.IPAddress,
|
||||
Name: machine.Name,
|
||||
Namespace: machine.Namespace.toProto(),
|
||||
|
||||
Registered: m.Registered,
|
||||
Registered: machine.Registered,
|
||||
|
||||
// TODO(kradalby): Implement register method enum converter
|
||||
// RegisterMethod: ,
|
||||
|
||||
CreatedAt: timestamppb.New(m.CreatedAt),
|
||||
CreatedAt: timestamppb.New(machine.CreatedAt),
|
||||
}
|
||||
|
||||
if m.AuthKey != nil {
|
||||
machine.PreAuthKey = m.AuthKey.toProto()
|
||||
if machine.AuthKey != nil {
|
||||
machineProto.PreAuthKey = machine.AuthKey.toProto()
|
||||
}
|
||||
|
||||
if m.LastSeen != nil {
|
||||
machine.LastSeen = timestamppb.New(*m.LastSeen)
|
||||
if machine.LastSeen != nil {
|
||||
machineProto.LastSeen = timestamppb.New(*machine.LastSeen)
|
||||
}
|
||||
|
||||
if m.LastSuccessfulUpdate != nil {
|
||||
machine.LastSuccessfulUpdate = timestamppb.New(*m.LastSuccessfulUpdate)
|
||||
if machine.LastSuccessfulUpdate != nil {
|
||||
machineProto.LastSuccessfulUpdate = timestamppb.New(
|
||||
*machine.LastSuccessfulUpdate,
|
||||
)
|
||||
}
|
||||
|
||||
if m.Expiry != nil {
|
||||
machine.Expiry = timestamppb.New(*m.Expiry)
|
||||
if machine.Expiry != nil {
|
||||
machineProto.Expiry = timestamppb.New(*machine.Expiry)
|
||||
}
|
||||
|
||||
return machine
|
||||
return machineProto
|
||||
}
|
||||
|
||||
// RegisterMachine is executed from the CLI to register a new Machine using its MachineKey
|
||||
func (h *Headscale) RegisterMachine(key string, namespace string) (*Machine, error) {
|
||||
ns, err := h.GetNamespace(namespace)
|
||||
// RegisterMachine is executed from the CLI to register a new Machine using its MachineKey.
|
||||
func (h *Headscale) RegisterMachine(
|
||||
key string,
|
||||
namespaceName string,
|
||||
) (*Machine, error) {
|
||||
namespace, err := h.GetNamespace(namespaceName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
mKey, err := wgkey.ParseHex(key)
|
||||
machineKey, err := wgkey.ParseHex(key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
m := Machine{}
|
||||
if result := h.db.First(&m, "machine_key = ?", mKey.HexString()); errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return nil, errors.New("Machine not found")
|
||||
machine := Machine{}
|
||||
if result := h.db.First(&machine, "machine_key = ?", machineKey.HexString()); errors.Is(
|
||||
result.Error,
|
||||
gorm.ErrRecordNotFound,
|
||||
) {
|
||||
return nil, errMachineNotFound
|
||||
}
|
||||
|
||||
log.Trace().
|
||||
Caller().
|
||||
Str("machine", m.Name).
|
||||
Str("machine", machine.Name).
|
||||
Msg("Attempting to register machine")
|
||||
|
||||
if m.isAlreadyRegistered() {
|
||||
err := errors.New("Machine already registered")
|
||||
if machine.isAlreadyRegistered() {
|
||||
err := errMachineAlreadyRegistered
|
||||
log.Error().
|
||||
Caller().
|
||||
Err(err).
|
||||
Str("machine", m.Name).
|
||||
Str("machine", machine.Name).
|
||||
Msg("Attempting to register machine")
|
||||
|
||||
return nil, err
|
||||
|
@ -596,42 +645,44 @@ func (h *Headscale) RegisterMachine(key string, namespace string) (*Machine, err
|
|||
log.Error().
|
||||
Caller().
|
||||
Err(err).
|
||||
Str("machine", m.Name).
|
||||
Str("machine", machine.Name).
|
||||
Msg("Could not find IP for the new machine")
|
||||
|
||||
return nil, err
|
||||
}
|
||||
|
||||
log.Trace().
|
||||
Caller().
|
||||
Str("machine", m.Name).
|
||||
Str("machine", machine.Name).
|
||||
Str("ip", ip.String()).
|
||||
Msg("Found IP for host")
|
||||
|
||||
m.IPAddress = ip.String()
|
||||
m.NamespaceID = ns.ID
|
||||
m.Registered = true
|
||||
m.RegisterMethod = "cli"
|
||||
h.db.Save(&m)
|
||||
machine.IPAddress = ip.String()
|
||||
machine.NamespaceID = namespace.ID
|
||||
machine.Registered = true
|
||||
machine.RegisterMethod = "cli"
|
||||
h.db.Save(&machine)
|
||||
|
||||
log.Trace().
|
||||
Caller().
|
||||
Str("machine", m.Name).
|
||||
Str("machine", machine.Name).
|
||||
Str("ip", ip.String()).
|
||||
Msg("Machine registered with the database")
|
||||
|
||||
return &m, nil
|
||||
return &machine, nil
|
||||
}
|
||||
|
||||
func (m *Machine) GetAdvertisedRoutes() ([]netaddr.IPPrefix, error) {
|
||||
hostInfo, err := m.GetHostInfo()
|
||||
func (machine *Machine) GetAdvertisedRoutes() ([]netaddr.IPPrefix, error) {
|
||||
hostInfo, err := machine.GetHostInfo()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return hostInfo.RoutableIPs, nil
|
||||
}
|
||||
|
||||
func (m *Machine) GetEnabledRoutes() ([]netaddr.IPPrefix, error) {
|
||||
data, err := m.EnabledRoutes.MarshalJSON()
|
||||
func (machine *Machine) GetEnabledRoutes() ([]netaddr.IPPrefix, error) {
|
||||
data, err := machine.EnabledRoutes.MarshalJSON()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -654,13 +705,13 @@ func (m *Machine) GetEnabledRoutes() ([]netaddr.IPPrefix, error) {
|
|||
return routes, nil
|
||||
}
|
||||
|
||||
func (m *Machine) IsRoutesEnabled(routeStr string) bool {
|
||||
func (machine *Machine) IsRoutesEnabled(routeStr string) bool {
|
||||
route, err := netaddr.ParseIPPrefix(routeStr)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
enabledRoutes, err := m.GetEnabledRoutes()
|
||||
enabledRoutes, err := machine.GetEnabledRoutes()
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
@ -670,12 +721,13 @@ func (m *Machine) IsRoutesEnabled(routeStr string) bool {
|
|||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// EnableNodeRoute enables new routes based on a list of new routes. It will _replace_ the
|
||||
// previous list of routes.
|
||||
func (h *Headscale) EnableRoutes(m *Machine, routeStrs ...string) error {
|
||||
func (h *Headscale) EnableRoutes(machine *Machine, routeStrs ...string) error {
|
||||
newRoutes := make([]netaddr.IPPrefix, len(routeStrs))
|
||||
for index, routeStr := range routeStrs {
|
||||
route, err := netaddr.ParseIPPrefix(routeStr)
|
||||
|
@ -686,14 +738,18 @@ func (h *Headscale) EnableRoutes(m *Machine, routeStrs ...string) error {
|
|||
newRoutes[index] = route
|
||||
}
|
||||
|
||||
availableRoutes, err := m.GetAdvertisedRoutes()
|
||||
availableRoutes, err := machine.GetAdvertisedRoutes()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, newRoute := range newRoutes {
|
||||
if !containsIpPrefix(availableRoutes, newRoute) {
|
||||
return fmt.Errorf("route (%s) is not available on node %s", m.Name, newRoute)
|
||||
if !containsIPPrefix(availableRoutes, newRoute) {
|
||||
return fmt.Errorf(
|
||||
"route (%s) is not available on node %s: %w",
|
||||
machine.Name,
|
||||
newRoute, errMachineRouteIsNotAvailable,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -702,10 +758,10 @@ func (h *Headscale) EnableRoutes(m *Machine, routeStrs ...string) error {
|
|||
return err
|
||||
}
|
||||
|
||||
m.EnabledRoutes = datatypes.JSON(routes)
|
||||
h.db.Save(&m)
|
||||
machine.EnabledRoutes = datatypes.JSON(routes)
|
||||
h.db.Save(&machine)
|
||||
|
||||
err = h.RequestMapUpdates(m.NamespaceID)
|
||||
err = h.RequestMapUpdates(machine.NamespaceID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -713,13 +769,13 @@ func (h *Headscale) EnableRoutes(m *Machine, routeStrs ...string) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (m *Machine) RoutesToProto() (*v1.Routes, error) {
|
||||
availableRoutes, err := m.GetAdvertisedRoutes()
|
||||
func (machine *Machine) RoutesToProto() (*v1.Routes, error) {
|
||||
availableRoutes, err := machine.GetAdvertisedRoutes()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
enabledRoutes, err := m.GetEnabledRoutes()
|
||||
enabledRoutes, err := machine.GetEnabledRoutes()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
113
machine_test.go
113
machine_test.go
|
@ -8,152 +8,159 @@ import (
|
|||
)
|
||||
|
||||
func (s *Suite) TestGetMachine(c *check.C) {
|
||||
n, err := h.CreateNamespace("test")
|
||||
namespace, err := app.CreateNamespace("test")
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
pak, err := h.CreatePreAuthKey(n.Name, false, false, nil)
|
||||
pak, err := app.CreatePreAuthKey(namespace.Name, false, false, nil)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
_, err = h.GetMachine("test", "testmachine")
|
||||
_, err = app.GetMachine("test", "testmachine")
|
||||
c.Assert(err, check.NotNil)
|
||||
|
||||
m := &Machine{
|
||||
machine := &Machine{
|
||||
ID: 0,
|
||||
MachineKey: "foo",
|
||||
NodeKey: "bar",
|
||||
DiscoKey: "faa",
|
||||
Name: "testmachine",
|
||||
NamespaceID: n.ID,
|
||||
NamespaceID: namespace.ID,
|
||||
Registered: true,
|
||||
RegisterMethod: "authKey",
|
||||
AuthKeyID: uint(pak.ID),
|
||||
}
|
||||
h.db.Save(m)
|
||||
app.db.Save(machine)
|
||||
|
||||
m1, err := h.GetMachine("test", "testmachine")
|
||||
machineFromDB, err := app.GetMachine("test", "testmachine")
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
_, err = m1.GetHostInfo()
|
||||
_, err = machineFromDB.GetHostInfo()
|
||||
c.Assert(err, check.IsNil)
|
||||
}
|
||||
|
||||
func (s *Suite) TestGetMachineByID(c *check.C) {
|
||||
n, err := h.CreateNamespace("test")
|
||||
namespace, err := app.CreateNamespace("test")
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
pak, err := h.CreatePreAuthKey(n.Name, false, false, nil)
|
||||
pak, err := app.CreatePreAuthKey(namespace.Name, false, false, nil)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
_, err = h.GetMachineByID(0)
|
||||
_, err = app.GetMachineByID(0)
|
||||
c.Assert(err, check.NotNil)
|
||||
|
||||
m := Machine{
|
||||
machine := Machine{
|
||||
ID: 0,
|
||||
MachineKey: "foo",
|
||||
NodeKey: "bar",
|
||||
DiscoKey: "faa",
|
||||
Name: "testmachine",
|
||||
NamespaceID: n.ID,
|
||||
NamespaceID: namespace.ID,
|
||||
Registered: true,
|
||||
RegisterMethod: "authKey",
|
||||
AuthKeyID: uint(pak.ID),
|
||||
}
|
||||
h.db.Save(&m)
|
||||
app.db.Save(&machine)
|
||||
|
||||
m1, err := h.GetMachineByID(0)
|
||||
machineByID, err := app.GetMachineByID(0)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
_, err = m1.GetHostInfo()
|
||||
_, err = machineByID.GetHostInfo()
|
||||
c.Assert(err, check.IsNil)
|
||||
}
|
||||
|
||||
func (s *Suite) TestDeleteMachine(c *check.C) {
|
||||
n, err := h.CreateNamespace("test")
|
||||
namespace, err := app.CreateNamespace("test")
|
||||
c.Assert(err, check.IsNil)
|
||||
m := Machine{
|
||||
machine := Machine{
|
||||
ID: 0,
|
||||
MachineKey: "foo",
|
||||
NodeKey: "bar",
|
||||
DiscoKey: "faa",
|
||||
Name: "testmachine",
|
||||
NamespaceID: n.ID,
|
||||
NamespaceID: namespace.ID,
|
||||
Registered: true,
|
||||
RegisterMethod: "authKey",
|
||||
AuthKeyID: uint(1),
|
||||
}
|
||||
h.db.Save(&m)
|
||||
err = h.DeleteMachine(&m)
|
||||
app.db.Save(&machine)
|
||||
|
||||
err = app.DeleteMachine(&machine)
|
||||
c.Assert(err, check.IsNil)
|
||||
v, err := h.getValue("namespaces_pending_updates")
|
||||
|
||||
namespacesPendingUpdates, err := app.getValue("namespaces_pending_updates")
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
names := []string{}
|
||||
err = json.Unmarshal([]byte(v), &names)
|
||||
err = json.Unmarshal([]byte(namespacesPendingUpdates), &names)
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(names, check.DeepEquals, []string{n.Name})
|
||||
h.checkForNamespacesPendingUpdates()
|
||||
v, _ = h.getValue("namespaces_pending_updates")
|
||||
c.Assert(v, check.Equals, "")
|
||||
_, err = h.GetMachine(n.Name, "testmachine")
|
||||
c.Assert(names, check.DeepEquals, []string{namespace.Name})
|
||||
|
||||
app.checkForNamespacesPendingUpdates()
|
||||
|
||||
namespacesPendingUpdates, _ = app.getValue("namespaces_pending_updates")
|
||||
c.Assert(namespacesPendingUpdates, check.Equals, "")
|
||||
_, err = app.GetMachine(namespace.Name, "testmachine")
|
||||
c.Assert(err, check.NotNil)
|
||||
}
|
||||
|
||||
func (s *Suite) TestHardDeleteMachine(c *check.C) {
|
||||
n, err := h.CreateNamespace("test")
|
||||
namespace, err := app.CreateNamespace("test")
|
||||
c.Assert(err, check.IsNil)
|
||||
m := Machine{
|
||||
machine := Machine{
|
||||
ID: 0,
|
||||
MachineKey: "foo",
|
||||
NodeKey: "bar",
|
||||
DiscoKey: "faa",
|
||||
Name: "testmachine3",
|
||||
NamespaceID: n.ID,
|
||||
NamespaceID: namespace.ID,
|
||||
Registered: true,
|
||||
RegisterMethod: "authKey",
|
||||
AuthKeyID: uint(1),
|
||||
}
|
||||
h.db.Save(&m)
|
||||
err = h.HardDeleteMachine(&m)
|
||||
app.db.Save(&machine)
|
||||
|
||||
err = app.HardDeleteMachine(&machine)
|
||||
c.Assert(err, check.IsNil)
|
||||
_, err = h.GetMachine(n.Name, "testmachine3")
|
||||
|
||||
_, err = app.GetMachine(namespace.Name, "testmachine3")
|
||||
c.Assert(err, check.NotNil)
|
||||
}
|
||||
|
||||
func (s *Suite) TestGetDirectPeers(c *check.C) {
|
||||
n, err := h.CreateNamespace("test")
|
||||
namespace, err := app.CreateNamespace("test")
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
pak, err := h.CreatePreAuthKey(n.Name, false, false, nil)
|
||||
pak, err := app.CreatePreAuthKey(namespace.Name, false, false, nil)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
_, err = h.GetMachineByID(0)
|
||||
_, err = app.GetMachineByID(0)
|
||||
c.Assert(err, check.NotNil)
|
||||
|
||||
for i := 0; i <= 10; i++ {
|
||||
m := Machine{
|
||||
ID: uint64(i),
|
||||
MachineKey: "foo" + strconv.Itoa(i),
|
||||
NodeKey: "bar" + strconv.Itoa(i),
|
||||
DiscoKey: "faa" + strconv.Itoa(i),
|
||||
Name: "testmachine" + strconv.Itoa(i),
|
||||
NamespaceID: n.ID,
|
||||
for index := 0; index <= 10; index++ {
|
||||
machine := Machine{
|
||||
ID: uint64(index),
|
||||
MachineKey: "foo" + strconv.Itoa(index),
|
||||
NodeKey: "bar" + strconv.Itoa(index),
|
||||
DiscoKey: "faa" + strconv.Itoa(index),
|
||||
Name: "testmachine" + strconv.Itoa(index),
|
||||
NamespaceID: namespace.ID,
|
||||
Registered: true,
|
||||
RegisterMethod: "authKey",
|
||||
AuthKeyID: uint(pak.ID),
|
||||
}
|
||||
h.db.Save(&m)
|
||||
app.db.Save(&machine)
|
||||
}
|
||||
|
||||
m1, err := h.GetMachineByID(0)
|
||||
machine0ByID, err := app.GetMachineByID(0)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
_, err = m1.GetHostInfo()
|
||||
_, err = machine0ByID.GetHostInfo()
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
peers, err := h.getDirectPeers(m1)
|
||||
peersOfMachine0, err := app.getDirectPeers(machine0ByID)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
c.Assert(len(peers), check.Equals, 9)
|
||||
c.Assert(peers[0].Name, check.Equals, "testmachine2")
|
||||
c.Assert(peers[5].Name, check.Equals, "testmachine7")
|
||||
c.Assert(peers[8].Name, check.Equals, "testmachine10")
|
||||
c.Assert(len(peersOfMachine0), check.Equals, 9)
|
||||
c.Assert(peersOfMachine0[0].Name, check.Equals, "testmachine2")
|
||||
c.Assert(peersOfMachine0[5].Name, check.Equals, "testmachine7")
|
||||
c.Assert(peersOfMachine0[8].Name, check.Equals, "testmachine10")
|
||||
}
|
||||
|
|
|
@ -32,7 +32,7 @@ var (
|
|||
Name: "update_request_sent_to_node_total",
|
||||
Help: "The number of calls/messages issued on a specific nodes update channel",
|
||||
}, []string{"namespace", "machine", "status"})
|
||||
//TODO(kradalby): This is very debugging, we might want to remove it.
|
||||
// TODO(kradalby): This is very debugging, we might want to remove it.
|
||||
updateRequestsReceivedOnChannel = promauto.NewCounterVec(prometheus.CounterOpts{
|
||||
Namespace: prometheusNamespace,
|
||||
Name: "update_request_received_on_channel_total",
|
||||
|
|
147
namespaces.go
147
namespaces.go
|
@ -15,9 +15,9 @@ import (
|
|||
)
|
||||
|
||||
const (
|
||||
errorNamespaceExists = Error("Namespace already exists")
|
||||
errorNamespaceNotFound = Error("Namespace not found")
|
||||
errorNamespaceNotEmptyOfNodes = Error("Namespace not empty: node(s) found")
|
||||
errNamespaceExists = Error("Namespace already exists")
|
||||
errNamespaceNotFound = Error("Namespace not found")
|
||||
errNamespaceNotEmptyOfNodes = Error("Namespace not empty: node(s) found")
|
||||
)
|
||||
|
||||
// Namespace is the way Headscale implements the concept of users in Tailscale
|
||||
|
@ -30,51 +30,53 @@ type Namespace struct {
|
|||
}
|
||||
|
||||
// CreateNamespace creates a new Namespace. Returns error if could not be created
|
||||
// or another namespace already exists
|
||||
// or another namespace already exists.
|
||||
func (h *Headscale) CreateNamespace(name string) (*Namespace, error) {
|
||||
n := Namespace{}
|
||||
if err := h.db.Where("name = ?", name).First(&n).Error; err == nil {
|
||||
return nil, errorNamespaceExists
|
||||
namespace := Namespace{}
|
||||
if err := h.db.Where("name = ?", name).First(&namespace).Error; err == nil {
|
||||
return nil, errNamespaceExists
|
||||
}
|
||||
n.Name = name
|
||||
if err := h.db.Create(&n).Error; err != nil {
|
||||
namespace.Name = name
|
||||
if err := h.db.Create(&namespace).Error; err != nil {
|
||||
log.Error().
|
||||
Str("func", "CreateNamespace").
|
||||
Err(err).
|
||||
Msg("Could not create row")
|
||||
|
||||
return nil, err
|
||||
}
|
||||
return &n, nil
|
||||
|
||||
return &namespace, nil
|
||||
}
|
||||
|
||||
// DestroyNamespace destroys a Namespace. Returns error if the Namespace does
|
||||
// not exist or if there are machines associated with it.
|
||||
func (h *Headscale) DestroyNamespace(name string) error {
|
||||
n, err := h.GetNamespace(name)
|
||||
namespace, err := h.GetNamespace(name)
|
||||
if err != nil {
|
||||
return errorNamespaceNotFound
|
||||
return errNamespaceNotFound
|
||||
}
|
||||
|
||||
m, err := h.ListMachinesInNamespace(name)
|
||||
machines, err := h.ListMachinesInNamespace(name)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if len(m) > 0 {
|
||||
return errorNamespaceNotEmptyOfNodes
|
||||
if len(machines) > 0 {
|
||||
return errNamespaceNotEmptyOfNodes
|
||||
}
|
||||
|
||||
keys, err := h.ListPreAuthKeys(name)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, p := range keys {
|
||||
err = h.DestroyPreAuthKey(&p)
|
||||
for _, key := range keys {
|
||||
err = h.DestroyPreAuthKey(key)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if result := h.db.Unscoped().Delete(&n); result.Error != nil {
|
||||
if result := h.db.Unscoped().Delete(&namespace); result.Error != nil {
|
||||
return result.Error
|
||||
}
|
||||
|
||||
|
@ -84,25 +86,25 @@ func (h *Headscale) DestroyNamespace(name string) error {
|
|||
// RenameNamespace renames a Namespace. Returns error if the Namespace does
|
||||
// not exist or if another Namespace exists with the new name.
|
||||
func (h *Headscale) RenameNamespace(oldName, newName string) error {
|
||||
n, err := h.GetNamespace(oldName)
|
||||
oldNamespace, err := h.GetNamespace(oldName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = h.GetNamespace(newName)
|
||||
if err == nil {
|
||||
return errorNamespaceExists
|
||||
return errNamespaceExists
|
||||
}
|
||||
if !errors.Is(err, errorNamespaceNotFound) {
|
||||
if !errors.Is(err, errNamespaceNotFound) {
|
||||
return err
|
||||
}
|
||||
|
||||
n.Name = newName
|
||||
oldNamespace.Name = newName
|
||||
|
||||
if result := h.db.Save(&n); result.Error != nil {
|
||||
if result := h.db.Save(&oldNamespace); result.Error != nil {
|
||||
return result.Error
|
||||
}
|
||||
|
||||
err = h.RequestMapUpdates(n.ID)
|
||||
err = h.RequestMapUpdates(oldNamespace.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -110,39 +112,45 @@ func (h *Headscale) RenameNamespace(oldName, newName string) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
// GetNamespace fetches a namespace by name
|
||||
// GetNamespace fetches a namespace by name.
|
||||
func (h *Headscale) GetNamespace(name string) (*Namespace, error) {
|
||||
n := Namespace{}
|
||||
if result := h.db.First(&n, "name = ?", name); errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return nil, errorNamespaceNotFound
|
||||
namespace := Namespace{}
|
||||
if result := h.db.First(&namespace, "name = ?", name); errors.Is(
|
||||
result.Error,
|
||||
gorm.ErrRecordNotFound,
|
||||
) {
|
||||
return nil, errNamespaceNotFound
|
||||
}
|
||||
return &n, nil
|
||||
|
||||
return &namespace, nil
|
||||
}
|
||||
|
||||
// ListNamespaces gets all the existing namespaces
|
||||
// ListNamespaces gets all the existing namespaces.
|
||||
func (h *Headscale) ListNamespaces() ([]Namespace, error) {
|
||||
namespaces := []Namespace{}
|
||||
if err := h.db.Find(&namespaces).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return namespaces, nil
|
||||
}
|
||||
|
||||
// ListMachinesInNamespace gets all the nodes in a given namespace
|
||||
// ListMachinesInNamespace gets all the nodes in a given namespace.
|
||||
func (h *Headscale) ListMachinesInNamespace(name string) ([]Machine, error) {
|
||||
n, err := h.GetNamespace(name)
|
||||
namespace, err := h.GetNamespace(name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
machines := []Machine{}
|
||||
if err := h.db.Preload("AuthKey").Preload("AuthKey.Namespace").Preload("Namespace").Where(&Machine{NamespaceID: n.ID}).Find(&machines).Error; err != nil {
|
||||
if err := h.db.Preload("AuthKey").Preload("AuthKey.Namespace").Preload("Namespace").Where(&Machine{NamespaceID: namespace.ID}).Find(&machines).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return machines, nil
|
||||
}
|
||||
|
||||
// ListSharedMachinesInNamespace returns all the machines that are shared to the specified namespace
|
||||
// ListSharedMachinesInNamespace returns all the machines that are shared to the specified namespace.
|
||||
func (h *Headscale) ListSharedMachinesInNamespace(name string) ([]Machine, error) {
|
||||
namespace, err := h.GetNamespace(name)
|
||||
if err != nil {
|
||||
|
@ -155,48 +163,61 @@ func (h *Headscale) ListSharedMachinesInNamespace(name string) ([]Machine, error
|
|||
|
||||
machines := []Machine{}
|
||||
for _, sharedMachine := range sharedMachines {
|
||||
machine, err := h.GetMachineByID(sharedMachine.MachineID) // otherwise not everything comes filled
|
||||
machine, err := h.GetMachineByID(
|
||||
sharedMachine.MachineID,
|
||||
) // otherwise not everything comes filled
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
machines = append(machines, *machine)
|
||||
}
|
||||
|
||||
return machines, nil
|
||||
}
|
||||
|
||||
// SetMachineNamespace assigns a Machine to a namespace
|
||||
func (h *Headscale) SetMachineNamespace(m *Machine, namespaceName string) error {
|
||||
n, err := h.GetNamespace(namespaceName)
|
||||
// SetMachineNamespace assigns a Machine to a namespace.
|
||||
func (h *Headscale) SetMachineNamespace(machine *Machine, namespaceName string) error {
|
||||
namespace, err := h.GetNamespace(namespaceName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
m.NamespaceID = n.ID
|
||||
h.db.Save(&m)
|
||||
machine.NamespaceID = namespace.ID
|
||||
h.db.Save(&machine)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// RequestMapUpdates signals the KV worker to update the maps for this namespace
|
||||
// TODO(kradalby): Remove the need for this.
|
||||
// RequestMapUpdates signals the KV worker to update the maps for this namespace.
|
||||
func (h *Headscale) RequestMapUpdates(namespaceID uint) error {
|
||||
namespace := Namespace{}
|
||||
if err := h.db.First(&namespace, namespaceID).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
v, err := h.getValue("namespaces_pending_updates")
|
||||
if err != nil || v == "" {
|
||||
err = h.setValue("namespaces_pending_updates", fmt.Sprintf(`["%s"]`, namespace.Name))
|
||||
namespacesPendingUpdates, err := h.getValue("namespaces_pending_updates")
|
||||
if err != nil || namespacesPendingUpdates == "" {
|
||||
err = h.setValue(
|
||||
"namespaces_pending_updates",
|
||||
fmt.Sprintf(`["%s"]`, namespace.Name),
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
names := []string{}
|
||||
err = json.Unmarshal([]byte(v), &names)
|
||||
err = json.Unmarshal([]byte(namespacesPendingUpdates), &names)
|
||||
if err != nil {
|
||||
err = h.setValue("namespaces_pending_updates", fmt.Sprintf(`["%s"]`, namespace.Name))
|
||||
err = h.setValue(
|
||||
"namespaces_pending_updates",
|
||||
fmt.Sprintf(`["%s"]`, namespace.Name),
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -207,22 +228,24 @@ func (h *Headscale) RequestMapUpdates(namespaceID uint) error {
|
|||
Str("func", "RequestMapUpdates").
|
||||
Err(err).
|
||||
Msg("Could not marshal namespaces_pending_updates")
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
return h.setValue("namespaces_pending_updates", string(data))
|
||||
}
|
||||
|
||||
func (h *Headscale) checkForNamespacesPendingUpdates() {
|
||||
v, err := h.getValue("namespaces_pending_updates")
|
||||
namespacesPendingUpdates, err := h.getValue("namespaces_pending_updates")
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if v == "" {
|
||||
if namespacesPendingUpdates == "" {
|
||||
return
|
||||
}
|
||||
|
||||
namespaces := []string{}
|
||||
err = json.Unmarshal([]byte(v), &namespaces)
|
||||
err = json.Unmarshal([]byte(namespacesPendingUpdates), &namespaces)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
@ -233,24 +256,25 @@ func (h *Headscale) checkForNamespacesPendingUpdates() {
|
|||
Msg("Sending updates to nodes in namespacespace")
|
||||
h.setLastStateChangeToNow(namespace)
|
||||
}
|
||||
newV, err := h.getValue("namespaces_pending_updates")
|
||||
newPendingUpdateValue, err := h.getValue("namespaces_pending_updates")
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if v == newV { // only clear when no changes, so we notified everybody
|
||||
if namespacesPendingUpdates == newPendingUpdateValue { // only clear when no changes, so we notified everybody
|
||||
err = h.setValue("namespaces_pending_updates", "")
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Str("func", "checkForNamespacesPendingUpdates").
|
||||
Err(err).
|
||||
Msg("Could not save to KV")
|
||||
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (n *Namespace) toUser() *tailcfg.User {
|
||||
u := tailcfg.User{
|
||||
user := tailcfg.User{
|
||||
ID: tailcfg.UserID(n.ID),
|
||||
LoginName: n.Name,
|
||||
DisplayName: n.Name,
|
||||
|
@ -259,25 +283,27 @@ func (n *Namespace) toUser() *tailcfg.User {
|
|||
Logins: []tailcfg.LoginID{},
|
||||
Created: time.Time{},
|
||||
}
|
||||
return &u
|
||||
|
||||
return &user
|
||||
}
|
||||
|
||||
func (n *Namespace) toLogin() *tailcfg.Login {
|
||||
l := tailcfg.Login{
|
||||
login := tailcfg.Login{
|
||||
ID: tailcfg.LoginID(n.ID),
|
||||
LoginName: n.Name,
|
||||
DisplayName: n.Name,
|
||||
ProfilePicURL: "",
|
||||
Domain: "headscale.net",
|
||||
}
|
||||
return &l
|
||||
|
||||
return &login
|
||||
}
|
||||
|
||||
func getMapResponseUserProfiles(m Machine, peers Machines) []tailcfg.UserProfile {
|
||||
func getMapResponseUserProfiles(machine Machine, peers Machines) []tailcfg.UserProfile {
|
||||
namespaceMap := make(map[string]Namespace)
|
||||
namespaceMap[m.Namespace.Name] = m.Namespace
|
||||
for _, p := range peers {
|
||||
namespaceMap[p.Namespace.Name] = p.Namespace // not worth checking if already is there
|
||||
namespaceMap[machine.Namespace.Name] = machine.Namespace
|
||||
for _, peer := range peers {
|
||||
namespaceMap[peer.Namespace.Name] = peer.Namespace // not worth checking if already is there
|
||||
}
|
||||
|
||||
profiles := []tailcfg.UserProfile{}
|
||||
|
@ -289,12 +315,13 @@ func getMapResponseUserProfiles(m Machine, peers Machines) []tailcfg.UserProfile
|
|||
DisplayName: namespace.Name,
|
||||
})
|
||||
}
|
||||
|
||||
return profiles
|
||||
}
|
||||
|
||||
func (n *Namespace) toProto() *v1.Namespace {
|
||||
return &v1.Namespace{
|
||||
Id: strconv.FormatUint(uint64(n.ID), 10),
|
||||
Id: strconv.FormatUint(uint64(n.ID), Base10),
|
||||
Name: n.Name,
|
||||
CreatedAt: timestamppb.New(n.CreatedAt),
|
||||
}
|
||||
|
|
|
@ -7,207 +7,232 @@ import (
|
|||
)
|
||||
|
||||
func (s *Suite) TestCreateAndDestroyNamespace(c *check.C) {
|
||||
n, err := h.CreateNamespace("test")
|
||||
namespace, err := app.CreateNamespace("test")
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(n.Name, check.Equals, "test")
|
||||
c.Assert(namespace.Name, check.Equals, "test")
|
||||
|
||||
ns, err := h.ListNamespaces()
|
||||
namespaces, err := app.ListNamespaces()
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(len(ns), check.Equals, 1)
|
||||
c.Assert(len(namespaces), check.Equals, 1)
|
||||
|
||||
err = h.DestroyNamespace("test")
|
||||
err = app.DestroyNamespace("test")
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
_, err = h.GetNamespace("test")
|
||||
_, err = app.GetNamespace("test")
|
||||
c.Assert(err, check.NotNil)
|
||||
}
|
||||
|
||||
func (s *Suite) TestDestroyNamespaceErrors(c *check.C) {
|
||||
err := h.DestroyNamespace("test")
|
||||
c.Assert(err, check.Equals, errorNamespaceNotFound)
|
||||
err := app.DestroyNamespace("test")
|
||||
c.Assert(err, check.Equals, errNamespaceNotFound)
|
||||
|
||||
n, err := h.CreateNamespace("test")
|
||||
namespace, err := app.CreateNamespace("test")
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
pak, err := h.CreatePreAuthKey(n.Name, false, false, nil)
|
||||
pak, err := app.CreatePreAuthKey(namespace.Name, false, false, nil)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
err = h.DestroyNamespace("test")
|
||||
err = app.DestroyNamespace("test")
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
result := h.db.Preload("Namespace").First(&pak, "key = ?", pak.Key)
|
||||
result := app.db.Preload("Namespace").First(&pak, "key = ?", pak.Key)
|
||||
// destroying a namespace also deletes all associated preauthkeys
|
||||
c.Assert(result.Error, check.Equals, gorm.ErrRecordNotFound)
|
||||
|
||||
n, err = h.CreateNamespace("test")
|
||||
namespace, err = app.CreateNamespace("test")
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
pak, err = h.CreatePreAuthKey(n.Name, false, false, nil)
|
||||
pak, err = app.CreatePreAuthKey(namespace.Name, false, false, nil)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
m := Machine{
|
||||
machine := Machine{
|
||||
ID: 0,
|
||||
MachineKey: "foo",
|
||||
NodeKey: "bar",
|
||||
DiscoKey: "faa",
|
||||
Name: "testmachine",
|
||||
NamespaceID: n.ID,
|
||||
NamespaceID: namespace.ID,
|
||||
Registered: true,
|
||||
RegisterMethod: "authKey",
|
||||
AuthKeyID: uint(pak.ID),
|
||||
}
|
||||
h.db.Save(&m)
|
||||
app.db.Save(&machine)
|
||||
|
||||
err = h.DestroyNamespace("test")
|
||||
c.Assert(err, check.Equals, errorNamespaceNotEmptyOfNodes)
|
||||
err = app.DestroyNamespace("test")
|
||||
c.Assert(err, check.Equals, errNamespaceNotEmptyOfNodes)
|
||||
}
|
||||
|
||||
func (s *Suite) TestRenameNamespace(c *check.C) {
|
||||
n, err := h.CreateNamespace("test")
|
||||
namespaceTest, err := app.CreateNamespace("test")
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(n.Name, check.Equals, "test")
|
||||
c.Assert(namespaceTest.Name, check.Equals, "test")
|
||||
|
||||
ns, err := h.ListNamespaces()
|
||||
namespaces, err := app.ListNamespaces()
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(len(ns), check.Equals, 1)
|
||||
c.Assert(len(namespaces), check.Equals, 1)
|
||||
|
||||
err = h.RenameNamespace("test", "test_renamed")
|
||||
err = app.RenameNamespace("test", "test_renamed")
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
_, err = h.GetNamespace("test")
|
||||
c.Assert(err, check.Equals, errorNamespaceNotFound)
|
||||
_, err = app.GetNamespace("test")
|
||||
c.Assert(err, check.Equals, errNamespaceNotFound)
|
||||
|
||||
_, err = h.GetNamespace("test_renamed")
|
||||
_, err = app.GetNamespace("test_renamed")
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
err = h.RenameNamespace("test_does_not_exit", "test")
|
||||
c.Assert(err, check.Equals, errorNamespaceNotFound)
|
||||
err = app.RenameNamespace("test_does_not_exit", "test")
|
||||
c.Assert(err, check.Equals, errNamespaceNotFound)
|
||||
|
||||
n2, err := h.CreateNamespace("test2")
|
||||
namespaceTest2, err := app.CreateNamespace("test2")
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(n2.Name, check.Equals, "test2")
|
||||
c.Assert(namespaceTest2.Name, check.Equals, "test2")
|
||||
|
||||
err = h.RenameNamespace("test2", "test_renamed")
|
||||
c.Assert(err, check.Equals, errorNamespaceExists)
|
||||
err = app.RenameNamespace("test2", "test_renamed")
|
||||
c.Assert(err, check.Equals, errNamespaceExists)
|
||||
}
|
||||
|
||||
func (s *Suite) TestGetMapResponseUserProfiles(c *check.C) {
|
||||
n1, err := h.CreateNamespace("shared1")
|
||||
namespaceShared1, err := app.CreateNamespace("shared1")
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
n2, err := h.CreateNamespace("shared2")
|
||||
namespaceShared2, err := app.CreateNamespace("shared2")
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
n3, err := h.CreateNamespace("shared3")
|
||||
namespaceShared3, err := app.CreateNamespace("shared3")
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
pak1n1, err := h.CreatePreAuthKey(n1.Name, false, false, nil)
|
||||
preAuthKeyShared1, err := app.CreatePreAuthKey(
|
||||
namespaceShared1.Name,
|
||||
false,
|
||||
false,
|
||||
nil,
|
||||
)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
pak2n2, err := h.CreatePreAuthKey(n2.Name, false, false, nil)
|
||||
preAuthKeyShared2, err := app.CreatePreAuthKey(
|
||||
namespaceShared2.Name,
|
||||
false,
|
||||
false,
|
||||
nil,
|
||||
)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
pak3n3, err := h.CreatePreAuthKey(n3.Name, false, false, nil)
|
||||
preAuthKeyShared3, err := app.CreatePreAuthKey(
|
||||
namespaceShared3.Name,
|
||||
false,
|
||||
false,
|
||||
nil,
|
||||
)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
pak4n1, err := h.CreatePreAuthKey(n1.Name, false, false, nil)
|
||||
preAuthKey2Shared1, err := app.CreatePreAuthKey(
|
||||
namespaceShared1.Name,
|
||||
false,
|
||||
false,
|
||||
nil,
|
||||
)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
_, err = h.GetMachine(n1.Name, "test_get_shared_nodes_1")
|
||||
_, err = app.GetMachine(namespaceShared1.Name, "test_get_shared_nodes_1")
|
||||
c.Assert(err, check.NotNil)
|
||||
|
||||
m1 := &Machine{
|
||||
machineInShared1 := &Machine{
|
||||
ID: 1,
|
||||
MachineKey: "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66",
|
||||
NodeKey: "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66",
|
||||
DiscoKey: "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66",
|
||||
Name: "test_get_shared_nodes_1",
|
||||
NamespaceID: n1.ID,
|
||||
Namespace: *n1,
|
||||
NamespaceID: namespaceShared1.ID,
|
||||
Namespace: *namespaceShared1,
|
||||
Registered: true,
|
||||
RegisterMethod: "authKey",
|
||||
IPAddress: "100.64.0.1",
|
||||
AuthKeyID: uint(pak1n1.ID),
|
||||
AuthKeyID: uint(preAuthKeyShared1.ID),
|
||||
}
|
||||
h.db.Save(m1)
|
||||
app.db.Save(machineInShared1)
|
||||
|
||||
_, err = h.GetMachine(n1.Name, m1.Name)
|
||||
_, err = app.GetMachine(namespaceShared1.Name, machineInShared1.Name)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
m2 := &Machine{
|
||||
machineInShared2 := &Machine{
|
||||
ID: 2,
|
||||
MachineKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
||||
NodeKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
||||
DiscoKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
||||
Name: "test_get_shared_nodes_2",
|
||||
NamespaceID: n2.ID,
|
||||
Namespace: *n2,
|
||||
NamespaceID: namespaceShared2.ID,
|
||||
Namespace: *namespaceShared2,
|
||||
Registered: true,
|
||||
RegisterMethod: "authKey",
|
||||
IPAddress: "100.64.0.2",
|
||||
AuthKeyID: uint(pak2n2.ID),
|
||||
AuthKeyID: uint(preAuthKeyShared2.ID),
|
||||
}
|
||||
h.db.Save(m2)
|
||||
app.db.Save(machineInShared2)
|
||||
|
||||
_, err = h.GetMachine(n2.Name, m2.Name)
|
||||
_, err = app.GetMachine(namespaceShared2.Name, machineInShared2.Name)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
m3 := &Machine{
|
||||
machineInShared3 := &Machine{
|
||||
ID: 3,
|
||||
MachineKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
||||
NodeKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
||||
DiscoKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
||||
Name: "test_get_shared_nodes_3",
|
||||
NamespaceID: n3.ID,
|
||||
Namespace: *n3,
|
||||
NamespaceID: namespaceShared3.ID,
|
||||
Namespace: *namespaceShared3,
|
||||
Registered: true,
|
||||
RegisterMethod: "authKey",
|
||||
IPAddress: "100.64.0.3",
|
||||
AuthKeyID: uint(pak3n3.ID),
|
||||
AuthKeyID: uint(preAuthKeyShared3.ID),
|
||||
}
|
||||
h.db.Save(m3)
|
||||
app.db.Save(machineInShared3)
|
||||
|
||||
_, err = h.GetMachine(n3.Name, m3.Name)
|
||||
_, err = app.GetMachine(namespaceShared3.Name, machineInShared3.Name)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
m4 := &Machine{
|
||||
machine2InShared1 := &Machine{
|
||||
ID: 4,
|
||||
MachineKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
||||
NodeKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
||||
DiscoKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
||||
Name: "test_get_shared_nodes_4",
|
||||
NamespaceID: n1.ID,
|
||||
Namespace: *n1,
|
||||
NamespaceID: namespaceShared1.ID,
|
||||
Namespace: *namespaceShared1,
|
||||
Registered: true,
|
||||
RegisterMethod: "authKey",
|
||||
IPAddress: "100.64.0.4",
|
||||
AuthKeyID: uint(pak4n1.ID),
|
||||
AuthKeyID: uint(preAuthKey2Shared1.ID),
|
||||
}
|
||||
h.db.Save(m4)
|
||||
app.db.Save(machine2InShared1)
|
||||
|
||||
err = h.AddSharedMachineToNamespace(m2, n1)
|
||||
err = app.AddSharedMachineToNamespace(machineInShared2, namespaceShared1)
|
||||
c.Assert(err, check.IsNil)
|
||||
m1peers, err := h.getPeers(m1)
|
||||
peersOfMachine1InShared1, err := app.getPeers(machineInShared1)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
userProfiles := getMapResponseUserProfiles(*m1, m1peers)
|
||||
userProfiles := getMapResponseUserProfiles(
|
||||
*machineInShared1,
|
||||
peersOfMachine1InShared1,
|
||||
)
|
||||
|
||||
log.Trace().Msgf("userProfiles %#v", userProfiles)
|
||||
c.Assert(len(userProfiles), check.Equals, 2)
|
||||
|
||||
found := false
|
||||
for _, up := range userProfiles {
|
||||
if up.DisplayName == n1.Name {
|
||||
for _, userProfiles := range userProfiles {
|
||||
if userProfiles.DisplayName == namespaceShared1.Name {
|
||||
found = true
|
||||
|
||||
break
|
||||
}
|
||||
}
|
||||
c.Assert(found, check.Equals, true)
|
||||
|
||||
found = false
|
||||
for _, up := range userProfiles {
|
||||
if up.DisplayName == n2.Name {
|
||||
for _, userProfile := range userProfiles {
|
||||
if userProfile.DisplayName == namespaceShared2.Name {
|
||||
found = true
|
||||
|
||||
break
|
||||
}
|
||||
}
|
||||
|
|
150
oidc.go
150
oidc.go
|
@ -17,6 +17,12 @@ import (
|
|||
"golang.org/x/oauth2"
|
||||
)
|
||||
|
||||
const (
|
||||
oidcStateCacheExpiration = time.Minute * 5
|
||||
oidcStateCacheCleanupInterval = time.Minute * 10
|
||||
randomByteSize = 16
|
||||
)
|
||||
|
||||
type IDTokenClaims struct {
|
||||
Name string `json:"name,omitempty"`
|
||||
Groups []string `json:"groups,omitempty"`
|
||||
|
@ -32,6 +38,7 @@ func (h *Headscale) initOIDC() error {
|
|||
|
||||
if err != nil {
|
||||
log.Error().Msgf("Could not retrieve OIDC Config: %s", err.Error())
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
|
@ -39,14 +46,20 @@ func (h *Headscale) initOIDC() error {
|
|||
ClientID: h.cfg.OIDC.ClientID,
|
||||
ClientSecret: h.cfg.OIDC.ClientSecret,
|
||||
Endpoint: h.oidcProvider.Endpoint(),
|
||||
RedirectURL: fmt.Sprintf("%s/oidc/callback", strings.TrimSuffix(h.cfg.ServerURL, "/")),
|
||||
Scopes: []string{oidc.ScopeOpenID, "profile", "email"},
|
||||
RedirectURL: fmt.Sprintf(
|
||||
"%s/oidc/callback",
|
||||
strings.TrimSuffix(h.cfg.ServerURL, "/"),
|
||||
),
|
||||
Scopes: []string{oidc.ScopeOpenID, "profile", "email"},
|
||||
}
|
||||
}
|
||||
|
||||
// init the state cache if it hasn't been already
|
||||
if h.oidcStateCache == nil {
|
||||
h.oidcStateCache = cache.New(time.Minute*5, time.Minute*10)
|
||||
h.oidcStateCache = cache.New(
|
||||
oidcStateCacheExpiration,
|
||||
oidcStateCacheCleanupInterval,
|
||||
)
|
||||
}
|
||||
|
||||
return nil
|
||||
|
@ -54,50 +67,53 @@ func (h *Headscale) initOIDC() error {
|
|||
|
||||
// RegisterOIDC redirects to the OIDC provider for authentication
|
||||
// Puts machine key in cache so the callback can retrieve it using the oidc state param
|
||||
// Listens in /oidc/register/:mKey
|
||||
func (h *Headscale) RegisterOIDC(c *gin.Context) {
|
||||
mKeyStr := c.Param("mkey")
|
||||
// Listens in /oidc/register/:mKey.
|
||||
func (h *Headscale) RegisterOIDC(ctx *gin.Context) {
|
||||
mKeyStr := ctx.Param("mkey")
|
||||
if mKeyStr == "" {
|
||||
c.String(http.StatusBadRequest, "Wrong params")
|
||||
ctx.String(http.StatusBadRequest, "Wrong params")
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
b := make([]byte, 16)
|
||||
_, err := rand.Read(b)
|
||||
if err != nil {
|
||||
randomBlob := make([]byte, randomByteSize)
|
||||
if _, err := rand.Read(randomBlob); err != nil {
|
||||
log.Error().Msg("could not read 16 bytes from rand")
|
||||
c.String(http.StatusInternalServerError, "could not read 16 bytes from rand")
|
||||
ctx.String(http.StatusInternalServerError, "could not read 16 bytes from rand")
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
stateStr := hex.EncodeToString(b)[:32]
|
||||
stateStr := hex.EncodeToString(randomBlob)[:32]
|
||||
|
||||
// place the machine key into the state cache, so it can be retrieved later
|
||||
h.oidcStateCache.Set(stateStr, mKeyStr, time.Minute*5)
|
||||
h.oidcStateCache.Set(stateStr, mKeyStr, oidcStateCacheExpiration)
|
||||
|
||||
authUrl := h.oauth2Config.AuthCodeURL(stateStr)
|
||||
log.Debug().Msgf("Redirecting to %s for authentication", authUrl)
|
||||
authURL := h.oauth2Config.AuthCodeURL(stateStr)
|
||||
log.Debug().Msgf("Redirecting to %s for authentication", authURL)
|
||||
|
||||
c.Redirect(http.StatusFound, authUrl)
|
||||
ctx.Redirect(http.StatusFound, authURL)
|
||||
}
|
||||
|
||||
// OIDCCallback handles the callback from the OIDC endpoint
|
||||
// Retrieves the mkey from the state cache and adds the machine to the users email namespace
|
||||
// TODO: A confirmation page for new machines should be added to avoid phishing vulnerabilities
|
||||
// TODO: Add groups information from OIDC tokens into machine HostInfo
|
||||
// Listens in /oidc/callback
|
||||
func (h *Headscale) OIDCCallback(c *gin.Context) {
|
||||
code := c.Query("code")
|
||||
state := c.Query("state")
|
||||
// Listens in /oidc/callback.
|
||||
func (h *Headscale) OIDCCallback(ctx *gin.Context) {
|
||||
code := ctx.Query("code")
|
||||
state := ctx.Query("state")
|
||||
|
||||
if code == "" || state == "" {
|
||||
c.String(http.StatusBadRequest, "Wrong params")
|
||||
ctx.String(http.StatusBadRequest, "Wrong params")
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
oauth2Token, err := h.oauth2Config.Exchange(context.Background(), code)
|
||||
if err != nil {
|
||||
c.String(http.StatusBadRequest, "Could not exchange code for token")
|
||||
ctx.String(http.StatusBadRequest, "Could not exchange code for token")
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -105,7 +121,8 @@ func (h *Headscale) OIDCCallback(c *gin.Context) {
|
|||
|
||||
rawIDToken, rawIDTokenOK := oauth2Token.Extra("id_token").(string)
|
||||
if !rawIDTokenOK {
|
||||
c.String(http.StatusBadRequest, "Could not extract ID Token")
|
||||
ctx.String(http.StatusBadRequest, "Could not extract ID Token")
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -113,21 +130,26 @@ func (h *Headscale) OIDCCallback(c *gin.Context) {
|
|||
|
||||
idToken, err := verifier.Verify(context.Background(), rawIDToken)
|
||||
if err != nil {
|
||||
c.String(http.StatusBadRequest, "Failed to verify id token: %s", err.Error())
|
||||
ctx.String(http.StatusBadRequest, "Failed to verify id token: %s", err.Error())
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// TODO: we can use userinfo at some point to grab additional information about the user (groups membership, etc)
|
||||
//userInfo, err := oidcProvider.UserInfo(context.Background(), oauth2.StaticTokenSource(oauth2Token))
|
||||
//if err != nil {
|
||||
// c.String(http.StatusBadRequest, fmt.Sprintf("Failed to retrieve userinfo: %s", err))
|
||||
// return
|
||||
//}
|
||||
// userInfo, err := oidcProvider.UserInfo(context.Background(), oauth2.StaticTokenSource(oauth2Token))
|
||||
// if err != nil {
|
||||
// c.String(http.StatusBadRequest, fmt.Sprintf("Failed to retrieve userinfo: %s", err))
|
||||
// return
|
||||
// }
|
||||
|
||||
// Extract custom claims
|
||||
var claims IDTokenClaims
|
||||
if err = idToken.Claims(&claims); err != nil {
|
||||
c.String(http.StatusBadRequest, fmt.Sprintf("Failed to decode id token claims: %s", err))
|
||||
ctx.String(
|
||||
http.StatusBadRequest,
|
||||
fmt.Sprintf("Failed to decode id token claims: %s", err),
|
||||
)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -135,62 +157,80 @@ func (h *Headscale) OIDCCallback(c *gin.Context) {
|
|||
mKeyIf, mKeyFound := h.oidcStateCache.Get(state)
|
||||
|
||||
if !mKeyFound {
|
||||
log.Error().Msg("requested machine state key expired before authorisation completed")
|
||||
c.String(http.StatusBadRequest, "state has expired")
|
||||
log.Error().
|
||||
Msg("requested machine state key expired before authorisation completed")
|
||||
ctx.String(http.StatusBadRequest, "state has expired")
|
||||
|
||||
return
|
||||
}
|
||||
mKeyStr, mKeyOK := mKeyIf.(string)
|
||||
|
||||
if !mKeyOK {
|
||||
log.Error().Msg("could not get machine key from cache")
|
||||
c.String(http.StatusInternalServerError, "could not get machine key from cache")
|
||||
ctx.String(
|
||||
http.StatusInternalServerError,
|
||||
"could not get machine key from cache",
|
||||
)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// retrieve machine information
|
||||
m, err := h.GetMachineByMachineKey(mKeyStr)
|
||||
machine, err := h.GetMachineByMachineKey(mKeyStr)
|
||||
if err != nil {
|
||||
log.Error().Msg("machine key not found in database")
|
||||
c.String(http.StatusInternalServerError, "could not get machine info from database")
|
||||
ctx.String(
|
||||
http.StatusInternalServerError,
|
||||
"could not get machine info from database",
|
||||
)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
now := time.Now().UTC()
|
||||
|
||||
if nsName, ok := h.getNamespaceFromEmail(claims.Email); ok {
|
||||
if namespaceName, ok := h.getNamespaceFromEmail(claims.Email); ok {
|
||||
// register the machine if it's new
|
||||
if !m.Registered {
|
||||
|
||||
if !machine.Registered {
|
||||
log.Debug().Msg("Registering new machine after successful callback")
|
||||
|
||||
ns, err := h.GetNamespace(nsName)
|
||||
namespace, err := h.GetNamespace(namespaceName)
|
||||
if err != nil {
|
||||
ns, err = h.CreateNamespace(nsName)
|
||||
namespace, err = h.CreateNamespace(namespaceName)
|
||||
|
||||
if err != nil {
|
||||
log.Error().Msgf("could not create new namespace '%s'", claims.Email)
|
||||
c.String(http.StatusInternalServerError, "could not create new namespace")
|
||||
log.Error().
|
||||
Msgf("could not create new namespace '%s'", claims.Email)
|
||||
ctx.String(
|
||||
http.StatusInternalServerError,
|
||||
"could not create new namespace",
|
||||
)
|
||||
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
ip, err := h.getAvailableIP()
|
||||
if err != nil {
|
||||
c.String(http.StatusInternalServerError, "could not get an IP from the pool")
|
||||
ctx.String(
|
||||
http.StatusInternalServerError,
|
||||
"could not get an IP from the pool",
|
||||
)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
m.IPAddress = ip.String()
|
||||
m.NamespaceID = ns.ID
|
||||
m.Registered = true
|
||||
m.RegisterMethod = "oidc"
|
||||
m.LastSuccessfulUpdate = &now
|
||||
h.db.Save(&m)
|
||||
machine.IPAddress = ip.String()
|
||||
machine.NamespaceID = namespace.ID
|
||||
machine.Registered = true
|
||||
machine.RegisterMethod = "oidc"
|
||||
machine.LastSuccessfulUpdate = &now
|
||||
h.db.Save(&machine)
|
||||
}
|
||||
|
||||
h.updateMachineExpiry(m)
|
||||
h.updateMachineExpiry(machine)
|
||||
|
||||
c.Data(http.StatusOK, "text/html; charset=utf-8", []byte(fmt.Sprintf(`
|
||||
ctx.Data(http.StatusOK, "text/html; charset=utf-8", []byte(fmt.Sprintf(`
|
||||
<html>
|
||||
<body>
|
||||
<h1>headscale</h1>
|
||||
|
@ -201,15 +241,17 @@ func (h *Headscale) OIDCCallback(c *gin.Context) {
|
|||
</html>
|
||||
|
||||
`, claims.Email)))
|
||||
|
||||
}
|
||||
|
||||
log.Error().
|
||||
Str("email", claims.Email).
|
||||
Str("username", claims.Username).
|
||||
Str("machine", m.Name).
|
||||
Str("machine", machine.Name).
|
||||
Msg("Email could not be mapped to a namespace")
|
||||
c.String(http.StatusBadRequest, "email from claim could not be mapped to a namespace")
|
||||
ctx.String(
|
||||
http.StatusBadRequest,
|
||||
"email from claim could not be mapped to a namespace",
|
||||
)
|
||||
}
|
||||
|
||||
// getNamespaceFromEmail passes the users email through a list of "matchers"
|
||||
|
|
50
oidc_test.go
50
oidc_test.go
|
@ -145,29 +145,37 @@ func TestHeadscale_getNamespaceFromEmail(t *testing.T) {
|
|||
},
|
||||
}
|
||||
//nolint
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
h := &Headscale{
|
||||
cfg: tt.fields.cfg,
|
||||
db: tt.fields.db,
|
||||
dbString: tt.fields.dbString,
|
||||
dbType: tt.fields.dbType,
|
||||
dbDebug: tt.fields.dbDebug,
|
||||
publicKey: tt.fields.publicKey,
|
||||
privateKey: tt.fields.privateKey,
|
||||
aclPolicy: tt.fields.aclPolicy,
|
||||
aclRules: tt.fields.aclRules,
|
||||
lastStateChange: tt.fields.lastStateChange,
|
||||
oidcProvider: tt.fields.oidcProvider,
|
||||
oauth2Config: tt.fields.oauth2Config,
|
||||
oidcStateCache: tt.fields.oidcStateCache,
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
app := &Headscale{
|
||||
cfg: test.fields.cfg,
|
||||
db: test.fields.db,
|
||||
dbString: test.fields.dbString,
|
||||
dbType: test.fields.dbType,
|
||||
dbDebug: test.fields.dbDebug,
|
||||
publicKey: test.fields.publicKey,
|
||||
privateKey: test.fields.privateKey,
|
||||
aclPolicy: test.fields.aclPolicy,
|
||||
aclRules: test.fields.aclRules,
|
||||
lastStateChange: test.fields.lastStateChange,
|
||||
oidcProvider: test.fields.oidcProvider,
|
||||
oauth2Config: test.fields.oauth2Config,
|
||||
oidcStateCache: test.fields.oidcStateCache,
|
||||
}
|
||||
got, got1 := h.getNamespaceFromEmail(tt.args.email)
|
||||
if got != tt.want {
|
||||
t.Errorf("Headscale.getNamespaceFromEmail() got = %v, want %v", got, tt.want)
|
||||
got, got1 := app.getNamespaceFromEmail(test.args.email)
|
||||
if got != test.want {
|
||||
t.Errorf(
|
||||
"Headscale.getNamespaceFromEmail() got = %v, want %v",
|
||||
got,
|
||||
test.want,
|
||||
)
|
||||
}
|
||||
if got1 != tt.want1 {
|
||||
t.Errorf("Headscale.getNamespaceFromEmail() got1 = %v, want %v", got1, tt.want1)
|
||||
if got1 != test.want1 {
|
||||
t.Errorf(
|
||||
"Headscale.getNamespaceFromEmail() got1 = %v, want %v",
|
||||
got1,
|
||||
test.want1,
|
||||
)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
|
265
poll.go
265
poll.go
|
@ -15,6 +15,11 @@ import (
|
|||
"tailscale.com/types/wgkey"
|
||||
)
|
||||
|
||||
const (
|
||||
keepAliveInterval = 60 * time.Second
|
||||
updateCheckInterval = 10 * time.Second
|
||||
)
|
||||
|
||||
// PollNetMapHandler takes care of /machine/:id/map
|
||||
//
|
||||
// This is the busiest endpoint, as it keeps the HTTP long poll that updates
|
||||
|
@ -24,20 +29,21 @@ import (
|
|||
// only after their first request (marked with the ReadOnly field).
|
||||
//
|
||||
// At this moment the updates are sent in a quite horrendous way, but they kinda work.
|
||||
func (h *Headscale) PollNetMapHandler(c *gin.Context) {
|
||||
func (h *Headscale) PollNetMapHandler(ctx *gin.Context) {
|
||||
log.Trace().
|
||||
Str("handler", "PollNetMap").
|
||||
Str("id", c.Param("id")).
|
||||
Str("id", ctx.Param("id")).
|
||||
Msg("PollNetMapHandler called")
|
||||
body, _ := io.ReadAll(c.Request.Body)
|
||||
mKeyStr := c.Param("id")
|
||||
body, _ := io.ReadAll(ctx.Request.Body)
|
||||
mKeyStr := ctx.Param("id")
|
||||
mKey, err := wgkey.ParseHex(mKeyStr)
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Str("handler", "PollNetMap").
|
||||
Err(err).
|
||||
Msg("Cannot parse client key")
|
||||
c.String(http.StatusBadRequest, "")
|
||||
ctx.String(http.StatusBadRequest, "")
|
||||
|
||||
return
|
||||
}
|
||||
req := tailcfg.MapRequest{}
|
||||
|
@ -47,34 +53,36 @@ func (h *Headscale) PollNetMapHandler(c *gin.Context) {
|
|||
Str("handler", "PollNetMap").
|
||||
Err(err).
|
||||
Msg("Cannot decode message")
|
||||
c.String(http.StatusBadRequest, "")
|
||||
ctx.String(http.StatusBadRequest, "")
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
m, err := h.GetMachineByMachineKey(mKey.HexString())
|
||||
machine, err := h.GetMachineByMachineKey(mKey.HexString())
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
log.Warn().
|
||||
Str("handler", "PollNetMap").
|
||||
Msgf("Ignoring request, cannot find machine with key %s", mKey.HexString())
|
||||
c.String(http.StatusUnauthorized, "")
|
||||
ctx.String(http.StatusUnauthorized, "")
|
||||
|
||||
return
|
||||
}
|
||||
log.Error().
|
||||
Str("handler", "PollNetMap").
|
||||
Msgf("Failed to fetch machine from the database with Machine key: %s", mKey.HexString())
|
||||
c.String(http.StatusInternalServerError, "")
|
||||
ctx.String(http.StatusInternalServerError, "")
|
||||
}
|
||||
log.Trace().
|
||||
Str("handler", "PollNetMap").
|
||||
Str("id", c.Param("id")).
|
||||
Str("machine", m.Name).
|
||||
Str("id", ctx.Param("id")).
|
||||
Str("machine", machine.Name).
|
||||
Msg("Found machine in database")
|
||||
|
||||
hostinfo, _ := json.Marshal(req.Hostinfo)
|
||||
m.Name = req.Hostinfo.Hostname
|
||||
m.HostInfo = datatypes.JSON(hostinfo)
|
||||
m.DiscoKey = wgkey.Key(req.DiscoKey).HexString()
|
||||
machine.Name = req.Hostinfo.Hostname
|
||||
machine.HostInfo = datatypes.JSON(hostinfo)
|
||||
machine.DiscoKey = wgkey.Key(req.DiscoKey).HexString()
|
||||
now := time.Now().UTC()
|
||||
|
||||
// From Tailscale client:
|
||||
|
@ -87,20 +95,21 @@ func (h *Headscale) PollNetMapHandler(c *gin.Context) {
|
|||
// before their first real endpoint update.
|
||||
if !req.ReadOnly {
|
||||
endpoints, _ := json.Marshal(req.Endpoints)
|
||||
m.Endpoints = datatypes.JSON(endpoints)
|
||||
m.LastSeen = &now
|
||||
machine.Endpoints = datatypes.JSON(endpoints)
|
||||
machine.LastSeen = &now
|
||||
}
|
||||
h.db.Save(&m)
|
||||
h.db.Save(&machine)
|
||||
|
||||
data, err := h.getMapResponse(mKey, req, m)
|
||||
data, err := h.getMapResponse(mKey, req, machine)
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Str("handler", "PollNetMap").
|
||||
Str("id", c.Param("id")).
|
||||
Str("machine", m.Name).
|
||||
Str("id", ctx.Param("id")).
|
||||
Str("machine", machine.Name).
|
||||
Err(err).
|
||||
Msg("Failed to get Map response")
|
||||
c.String(http.StatusInternalServerError, ":(")
|
||||
ctx.String(http.StatusInternalServerError, ":(")
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -111,8 +120,8 @@ func (h *Headscale) PollNetMapHandler(c *gin.Context) {
|
|||
// Details on the protocol can be found in https://github.com/tailscale/tailscale/blob/main/tailcfg/tailcfg.go#L696
|
||||
log.Debug().
|
||||
Str("handler", "PollNetMap").
|
||||
Str("id", c.Param("id")).
|
||||
Str("machine", m.Name).
|
||||
Str("id", ctx.Param("id")).
|
||||
Str("machine", machine.Name).
|
||||
Bool("readOnly", req.ReadOnly).
|
||||
Bool("omitPeers", req.OmitPeers).
|
||||
Bool("stream", req.Stream).
|
||||
|
@ -121,15 +130,16 @@ func (h *Headscale) PollNetMapHandler(c *gin.Context) {
|
|||
if req.ReadOnly {
|
||||
log.Info().
|
||||
Str("handler", "PollNetMap").
|
||||
Str("machine", m.Name).
|
||||
Str("machine", machine.Name).
|
||||
Msg("Client is starting up. Probably interested in a DERP map")
|
||||
c.Data(200, "application/json; charset=utf-8", data)
|
||||
ctx.Data(http.StatusOK, "application/json; charset=utf-8", data)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// There has been an update to _any_ of the nodes that the other nodes would
|
||||
// need to know about
|
||||
h.setLastStateChangeToNow(m.Namespace.Name)
|
||||
h.setLastStateChangeToNow(machine.Namespace.Name)
|
||||
|
||||
// The request is not ReadOnly, so we need to set up channels for updating
|
||||
// peers via longpoll
|
||||
|
@ -137,8 +147,8 @@ func (h *Headscale) PollNetMapHandler(c *gin.Context) {
|
|||
// Only create update channel if it has not been created
|
||||
log.Trace().
|
||||
Str("handler", "PollNetMap").
|
||||
Str("id", c.Param("id")).
|
||||
Str("machine", m.Name).
|
||||
Str("id", ctx.Param("id")).
|
||||
Str("machine", machine.Name).
|
||||
Msg("Loading or creating update channel")
|
||||
updateChan := make(chan struct{})
|
||||
|
||||
|
@ -152,46 +162,59 @@ func (h *Headscale) PollNetMapHandler(c *gin.Context) {
|
|||
if req.OmitPeers && !req.Stream {
|
||||
log.Info().
|
||||
Str("handler", "PollNetMap").
|
||||
Str("machine", m.Name).
|
||||
Str("machine", machine.Name).
|
||||
Msg("Client sent endpoint update and is ok with a response without peer list")
|
||||
c.Data(200, "application/json; charset=utf-8", data)
|
||||
ctx.Data(http.StatusOK, "application/json; charset=utf-8", data)
|
||||
|
||||
// It sounds like we should update the nodes when we have received a endpoint update
|
||||
// even tho the comments in the tailscale code dont explicitly say so.
|
||||
updateRequestsFromNode.WithLabelValues(m.Name, m.Namespace.Name, "endpoint-update").Inc()
|
||||
updateRequestsFromNode.WithLabelValues(machine.Name, machine.Namespace.Name, "endpoint-update").
|
||||
Inc()
|
||||
go func() { updateChan <- struct{}{} }()
|
||||
|
||||
return
|
||||
} else if req.OmitPeers && req.Stream {
|
||||
log.Warn().
|
||||
Str("handler", "PollNetMap").
|
||||
Str("machine", m.Name).
|
||||
Str("machine", machine.Name).
|
||||
Msg("Ignoring request, don't know how to handle it")
|
||||
c.String(http.StatusBadRequest, "")
|
||||
ctx.String(http.StatusBadRequest, "")
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
log.Info().
|
||||
Str("handler", "PollNetMap").
|
||||
Str("machine", m.Name).
|
||||
Str("machine", machine.Name).
|
||||
Msg("Client is ready to access the tailnet")
|
||||
log.Info().
|
||||
Str("handler", "PollNetMap").
|
||||
Str("machine", m.Name).
|
||||
Str("machine", machine.Name).
|
||||
Msg("Sending initial map")
|
||||
go func() { pollDataChan <- data }()
|
||||
|
||||
log.Info().
|
||||
Str("handler", "PollNetMap").
|
||||
Str("machine", m.Name).
|
||||
Str("machine", machine.Name).
|
||||
Msg("Notifying peers")
|
||||
updateRequestsFromNode.WithLabelValues(m.Name, m.Namespace.Name, "full-update").Inc()
|
||||
updateRequestsFromNode.WithLabelValues(machine.Name, machine.Namespace.Name, "full-update").
|
||||
Inc()
|
||||
go func() { updateChan <- struct{}{} }()
|
||||
|
||||
h.PollNetMapStream(c, m, req, mKey, pollDataChan, keepAliveChan, updateChan, cancelKeepAlive)
|
||||
h.PollNetMapStream(
|
||||
ctx,
|
||||
machine,
|
||||
req,
|
||||
mKey,
|
||||
pollDataChan,
|
||||
keepAliveChan,
|
||||
updateChan,
|
||||
cancelKeepAlive,
|
||||
)
|
||||
log.Trace().
|
||||
Str("handler", "PollNetMap").
|
||||
Str("id", c.Param("id")).
|
||||
Str("machine", m.Name).
|
||||
Str("id", ctx.Param("id")).
|
||||
Str("machine", machine.Name).
|
||||
Msg("Finished stream, closing PollNetMap session")
|
||||
}
|
||||
|
||||
|
@ -199,165 +222,181 @@ func (h *Headscale) PollNetMapHandler(c *gin.Context) {
|
|||
// stream logic, ensuring we communicate updates and data
|
||||
// to the connected clients.
|
||||
func (h *Headscale) PollNetMapStream(
|
||||
c *gin.Context,
|
||||
m *Machine,
|
||||
req tailcfg.MapRequest,
|
||||
mKey wgkey.Key,
|
||||
ctx *gin.Context,
|
||||
machine *Machine,
|
||||
mapRequest tailcfg.MapRequest,
|
||||
machineKey wgkey.Key,
|
||||
pollDataChan chan []byte,
|
||||
keepAliveChan chan []byte,
|
||||
updateChan chan struct{},
|
||||
cancelKeepAlive chan struct{},
|
||||
) {
|
||||
go h.scheduledPollWorker(cancelKeepAlive, updateChan, keepAliveChan, mKey, req, m)
|
||||
go h.scheduledPollWorker(
|
||||
cancelKeepAlive,
|
||||
updateChan,
|
||||
keepAliveChan,
|
||||
machineKey,
|
||||
mapRequest,
|
||||
machine,
|
||||
)
|
||||
|
||||
c.Stream(func(w io.Writer) bool {
|
||||
ctx.Stream(func(writer io.Writer) bool {
|
||||
log.Trace().
|
||||
Str("handler", "PollNetMapStream").
|
||||
Str("machine", m.Name).
|
||||
Str("machine", machine.Name).
|
||||
Msg("Waiting for data to stream...")
|
||||
|
||||
log.Trace().
|
||||
Str("handler", "PollNetMapStream").
|
||||
Str("machine", m.Name).
|
||||
Str("machine", machine.Name).
|
||||
Msgf("pollData is %#v, keepAliveChan is %#v, updateChan is %#v", pollDataChan, keepAliveChan, updateChan)
|
||||
|
||||
select {
|
||||
case data := <-pollDataChan:
|
||||
log.Trace().
|
||||
Str("handler", "PollNetMapStream").
|
||||
Str("machine", m.Name).
|
||||
Str("machine", machine.Name).
|
||||
Str("channel", "pollData").
|
||||
Int("bytes", len(data)).
|
||||
Msg("Sending data received via pollData channel")
|
||||
_, err := w.Write(data)
|
||||
_, err := writer.Write(data)
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Str("handler", "PollNetMapStream").
|
||||
Str("machine", m.Name).
|
||||
Str("machine", machine.Name).
|
||||
Str("channel", "pollData").
|
||||
Err(err).
|
||||
Msg("Cannot write data")
|
||||
|
||||
return false
|
||||
}
|
||||
log.Trace().
|
||||
Str("handler", "PollNetMapStream").
|
||||
Str("machine", m.Name).
|
||||
Str("machine", machine.Name).
|
||||
Str("channel", "pollData").
|
||||
Int("bytes", len(data)).
|
||||
Msg("Data from pollData channel written successfully")
|
||||
// TODO(kradalby): Abstract away all the database calls, this can cause race conditions
|
||||
// when an outdated machine object is kept alive, e.g. db is update from
|
||||
// command line, but then overwritten.
|
||||
err = h.UpdateMachine(m)
|
||||
err = h.UpdateMachine(machine)
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Str("handler", "PollNetMapStream").
|
||||
Str("machine", m.Name).
|
||||
Str("machine", machine.Name).
|
||||
Str("channel", "pollData").
|
||||
Err(err).
|
||||
Msg("Cannot update machine from database")
|
||||
}
|
||||
now := time.Now().UTC()
|
||||
m.LastSeen = &now
|
||||
machine.LastSeen = &now
|
||||
|
||||
lastStateUpdate.WithLabelValues(m.Namespace.Name, m.Name).Set(float64(now.Unix()))
|
||||
m.LastSuccessfulUpdate = &now
|
||||
lastStateUpdate.WithLabelValues(machine.Namespace.Name, machine.Name).
|
||||
Set(float64(now.Unix()))
|
||||
machine.LastSuccessfulUpdate = &now
|
||||
|
||||
h.db.Save(&m)
|
||||
h.db.Save(&machine)
|
||||
log.Trace().
|
||||
Str("handler", "PollNetMapStream").
|
||||
Str("machine", m.Name).
|
||||
Str("machine", machine.Name).
|
||||
Str("channel", "pollData").
|
||||
Int("bytes", len(data)).
|
||||
Msg("Machine entry in database updated successfully after sending pollData")
|
||||
|
||||
return true
|
||||
|
||||
case data := <-keepAliveChan:
|
||||
log.Trace().
|
||||
Str("handler", "PollNetMapStream").
|
||||
Str("machine", m.Name).
|
||||
Str("machine", machine.Name).
|
||||
Str("channel", "keepAlive").
|
||||
Int("bytes", len(data)).
|
||||
Msg("Sending keep alive message")
|
||||
_, err := w.Write(data)
|
||||
_, err := writer.Write(data)
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Str("handler", "PollNetMapStream").
|
||||
Str("machine", m.Name).
|
||||
Str("machine", machine.Name).
|
||||
Str("channel", "keepAlive").
|
||||
Err(err).
|
||||
Msg("Cannot write keep alive message")
|
||||
|
||||
return false
|
||||
}
|
||||
log.Trace().
|
||||
Str("handler", "PollNetMapStream").
|
||||
Str("machine", m.Name).
|
||||
Str("machine", machine.Name).
|
||||
Str("channel", "keepAlive").
|
||||
Int("bytes", len(data)).
|
||||
Msg("Keep alive sent successfully")
|
||||
// TODO(kradalby): Abstract away all the database calls, this can cause race conditions
|
||||
// when an outdated machine object is kept alive, e.g. db is update from
|
||||
// command line, but then overwritten.
|
||||
err = h.UpdateMachine(m)
|
||||
err = h.UpdateMachine(machine)
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Str("handler", "PollNetMapStream").
|
||||
Str("machine", m.Name).
|
||||
Str("machine", machine.Name).
|
||||
Str("channel", "keepAlive").
|
||||
Err(err).
|
||||
Msg("Cannot update machine from database")
|
||||
}
|
||||
now := time.Now().UTC()
|
||||
m.LastSeen = &now
|
||||
h.db.Save(&m)
|
||||
machine.LastSeen = &now
|
||||
h.db.Save(&machine)
|
||||
log.Trace().
|
||||
Str("handler", "PollNetMapStream").
|
||||
Str("machine", m.Name).
|
||||
Str("machine", machine.Name).
|
||||
Str("channel", "keepAlive").
|
||||
Int("bytes", len(data)).
|
||||
Msg("Machine updated successfully after sending keep alive")
|
||||
|
||||
return true
|
||||
|
||||
case <-updateChan:
|
||||
log.Trace().
|
||||
Str("handler", "PollNetMapStream").
|
||||
Str("machine", m.Name).
|
||||
Str("machine", machine.Name).
|
||||
Str("channel", "update").
|
||||
Msg("Received a request for update")
|
||||
updateRequestsReceivedOnChannel.WithLabelValues(m.Name, m.Namespace.Name).Inc()
|
||||
if h.isOutdated(m) {
|
||||
updateRequestsReceivedOnChannel.WithLabelValues(machine.Name, machine.Namespace.Name).
|
||||
Inc()
|
||||
if h.isOutdated(machine) {
|
||||
log.Debug().
|
||||
Str("handler", "PollNetMapStream").
|
||||
Str("machine", m.Name).
|
||||
Time("last_successful_update", *m.LastSuccessfulUpdate).
|
||||
Time("last_state_change", h.getLastStateChange(m.Namespace.Name)).
|
||||
Msgf("There has been updates since the last successful update to %s", m.Name)
|
||||
data, err := h.getMapResponse(mKey, req, m)
|
||||
Str("machine", machine.Name).
|
||||
Time("last_successful_update", *machine.LastSuccessfulUpdate).
|
||||
Time("last_state_change", h.getLastStateChange(machine.Namespace.Name)).
|
||||
Msgf("There has been updates since the last successful update to %s", machine.Name)
|
||||
data, err := h.getMapResponse(machineKey, mapRequest, machine)
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Str("handler", "PollNetMapStream").
|
||||
Str("machine", m.Name).
|
||||
Str("machine", machine.Name).
|
||||
Str("channel", "update").
|
||||
Err(err).
|
||||
Msg("Could not get the map update")
|
||||
}
|
||||
_, err = w.Write(data)
|
||||
_, err = writer.Write(data)
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Str("handler", "PollNetMapStream").
|
||||
Str("machine", m.Name).
|
||||
Str("machine", machine.Name).
|
||||
Str("channel", "update").
|
||||
Err(err).
|
||||
Msg("Could not write the map response")
|
||||
updateRequestsSentToNode.WithLabelValues(m.Name, m.Namespace.Name, "failed").Inc()
|
||||
updateRequestsSentToNode.WithLabelValues(machine.Name, machine.Namespace.Name, "failed").
|
||||
Inc()
|
||||
|
||||
return false
|
||||
}
|
||||
log.Trace().
|
||||
Str("handler", "PollNetMapStream").
|
||||
Str("machine", m.Name).
|
||||
Str("machine", machine.Name).
|
||||
Str("channel", "update").
|
||||
Msg("Updated Map has been sent")
|
||||
updateRequestsSentToNode.WithLabelValues(m.Name, m.Namespace.Name, "success").Inc()
|
||||
updateRequestsSentToNode.WithLabelValues(machine.Name, machine.Namespace.Name, "success").
|
||||
Inc()
|
||||
|
||||
// Keep track of the last successful update,
|
||||
// we sometimes end in a state were the update
|
||||
|
@ -366,77 +405,79 @@ func (h *Headscale) PollNetMapStream(
|
|||
// TODO(kradalby): Abstract away all the database calls, this can cause race conditions
|
||||
// when an outdated machine object is kept alive, e.g. db is update from
|
||||
// command line, but then overwritten.
|
||||
err = h.UpdateMachine(m)
|
||||
err = h.UpdateMachine(machine)
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Str("handler", "PollNetMapStream").
|
||||
Str("machine", m.Name).
|
||||
Str("machine", machine.Name).
|
||||
Str("channel", "update").
|
||||
Err(err).
|
||||
Msg("Cannot update machine from database")
|
||||
}
|
||||
now := time.Now().UTC()
|
||||
|
||||
lastStateUpdate.WithLabelValues(m.Namespace.Name, m.Name).Set(float64(now.Unix()))
|
||||
m.LastSuccessfulUpdate = &now
|
||||
lastStateUpdate.WithLabelValues(machine.Namespace.Name, machine.Name).
|
||||
Set(float64(now.Unix()))
|
||||
machine.LastSuccessfulUpdate = &now
|
||||
|
||||
h.db.Save(&m)
|
||||
h.db.Save(&machine)
|
||||
} else {
|
||||
log.Trace().
|
||||
Str("handler", "PollNetMapStream").
|
||||
Str("machine", m.Name).
|
||||
Time("last_successful_update", *m.LastSuccessfulUpdate).
|
||||
Time("last_state_change", h.getLastStateChange(m.Namespace.Name)).
|
||||
Msgf("%s is up to date", m.Name)
|
||||
Str("machine", machine.Name).
|
||||
Time("last_successful_update", *machine.LastSuccessfulUpdate).
|
||||
Time("last_state_change", h.getLastStateChange(machine.Namespace.Name)).
|
||||
Msgf("%s is up to date", machine.Name)
|
||||
}
|
||||
|
||||
return true
|
||||
|
||||
case <-c.Request.Context().Done():
|
||||
case <-ctx.Request.Context().Done():
|
||||
log.Info().
|
||||
Str("handler", "PollNetMapStream").
|
||||
Str("machine", m.Name).
|
||||
Str("machine", machine.Name).
|
||||
Msg("The client has closed the connection")
|
||||
// TODO: Abstract away all the database calls, this can cause race conditions
|
||||
// when an outdated machine object is kept alive, e.g. db is update from
|
||||
// command line, but then overwritten.
|
||||
err := h.UpdateMachine(m)
|
||||
err := h.UpdateMachine(machine)
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Str("handler", "PollNetMapStream").
|
||||
Str("machine", m.Name).
|
||||
Str("machine", machine.Name).
|
||||
Str("channel", "Done").
|
||||
Err(err).
|
||||
Msg("Cannot update machine from database")
|
||||
}
|
||||
now := time.Now().UTC()
|
||||
m.LastSeen = &now
|
||||
h.db.Save(&m)
|
||||
machine.LastSeen = &now
|
||||
h.db.Save(&machine)
|
||||
|
||||
log.Trace().
|
||||
Str("handler", "PollNetMapStream").
|
||||
Str("machine", m.Name).
|
||||
Str("machine", machine.Name).
|
||||
Str("channel", "Done").
|
||||
Msg("Cancelling keepAlive channel")
|
||||
cancelKeepAlive <- struct{}{}
|
||||
|
||||
log.Trace().
|
||||
Str("handler", "PollNetMapStream").
|
||||
Str("machine", m.Name).
|
||||
Str("machine", machine.Name).
|
||||
Str("channel", "Done").
|
||||
Msg("Closing update channel")
|
||||
//h.closeUpdateChannel(m)
|
||||
// h.closeUpdateChannel(m)
|
||||
close(updateChan)
|
||||
|
||||
log.Trace().
|
||||
Str("handler", "PollNetMapStream").
|
||||
Str("machine", m.Name).
|
||||
Str("machine", machine.Name).
|
||||
Str("channel", "Done").
|
||||
Msg("Closing pollData channel")
|
||||
close(pollDataChan)
|
||||
|
||||
log.Trace().
|
||||
Str("handler", "PollNetMapStream").
|
||||
Str("machine", m.Name).
|
||||
Str("machine", machine.Name).
|
||||
Str("channel", "Done").
|
||||
Msg("Closing keepAliveChan channel")
|
||||
close(keepAliveChan)
|
||||
|
@ -450,12 +491,12 @@ func (h *Headscale) scheduledPollWorker(
|
|||
cancelChan <-chan struct{},
|
||||
updateChan chan<- struct{},
|
||||
keepAliveChan chan<- []byte,
|
||||
mKey wgkey.Key,
|
||||
req tailcfg.MapRequest,
|
||||
m *Machine,
|
||||
machineKey wgkey.Key,
|
||||
mapRequest tailcfg.MapRequest,
|
||||
machine *Machine,
|
||||
) {
|
||||
keepAliveTicker := time.NewTicker(60 * time.Second)
|
||||
updateCheckerTicker := time.NewTicker(10 * time.Second)
|
||||
keepAliveTicker := time.NewTicker(keepAliveInterval)
|
||||
updateCheckerTicker := time.NewTicker(updateCheckInterval)
|
||||
|
||||
for {
|
||||
select {
|
||||
|
@ -463,27 +504,29 @@ func (h *Headscale) scheduledPollWorker(
|
|||
return
|
||||
|
||||
case <-keepAliveTicker.C:
|
||||
data, err := h.getMapKeepAliveResponse(mKey, req, m)
|
||||
data, err := h.getMapKeepAliveResponse(machineKey, mapRequest)
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Str("func", "keepAlive").
|
||||
Err(err).
|
||||
Msg("Error generating the keep alive msg")
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
log.Debug().
|
||||
Str("func", "keepAlive").
|
||||
Str("machine", m.Name).
|
||||
Str("machine", machine.Name).
|
||||
Msg("Sending keepalive")
|
||||
keepAliveChan <- data
|
||||
|
||||
case <-updateCheckerTicker.C:
|
||||
log.Debug().
|
||||
Str("func", "scheduledPollWorker").
|
||||
Str("machine", m.Name).
|
||||
Str("machine", machine.Name).
|
||||
Msg("Sending update request")
|
||||
updateRequestsFromNode.WithLabelValues(m.Name, m.Namespace.Name, "scheduled-update").Inc()
|
||||
updateRequestsFromNode.WithLabelValues(machine.Name, machine.Namespace.Name, "scheduled-update").
|
||||
Inc()
|
||||
updateChan <- struct{}{}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -7,19 +7,19 @@ import (
|
|||
"strconv"
|
||||
"time"
|
||||
|
||||
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
"gorm.io/gorm"
|
||||
|
||||
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
|
||||
)
|
||||
|
||||
const (
|
||||
errorAuthKeyNotFound = Error("AuthKey not found")
|
||||
errorAuthKeyExpired = Error("AuthKey expired")
|
||||
errPreAuthKeyNotFound = Error("AuthKey not found")
|
||||
errPreAuthKeyExpired = Error("AuthKey expired")
|
||||
errSingleUseAuthKeyHasBeenUsed = Error("AuthKey has already been used")
|
||||
errNamespaceMismatch = Error("namespace mismatch")
|
||||
)
|
||||
|
||||
// PreAuthKey describes a pre-authorization key usable in a particular namespace
|
||||
// PreAuthKey describes a pre-authorization key usable in a particular namespace.
|
||||
type PreAuthKey struct {
|
||||
ID uint64 `gorm:"primary_key"`
|
||||
Key string
|
||||
|
@ -33,14 +33,14 @@ type PreAuthKey struct {
|
|||
Expiration *time.Time
|
||||
}
|
||||
|
||||
// CreatePreAuthKey creates a new PreAuthKey in a namespace, and returns it
|
||||
// CreatePreAuthKey creates a new PreAuthKey in a namespace, and returns it.
|
||||
func (h *Headscale) CreatePreAuthKey(
|
||||
namespaceName string,
|
||||
reusable bool,
|
||||
ephemeral bool,
|
||||
expiration *time.Time,
|
||||
) (*PreAuthKey, error) {
|
||||
n, err := h.GetNamespace(namespaceName)
|
||||
namespace, err := h.GetNamespace(namespaceName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -51,35 +51,36 @@ func (h *Headscale) CreatePreAuthKey(
|
|||
return nil, err
|
||||
}
|
||||
|
||||
k := PreAuthKey{
|
||||
key := PreAuthKey{
|
||||
Key: kstr,
|
||||
NamespaceID: n.ID,
|
||||
Namespace: *n,
|
||||
NamespaceID: namespace.ID,
|
||||
Namespace: *namespace,
|
||||
Reusable: reusable,
|
||||
Ephemeral: ephemeral,
|
||||
CreatedAt: &now,
|
||||
Expiration: expiration,
|
||||
}
|
||||
h.db.Save(&k)
|
||||
h.db.Save(&key)
|
||||
|
||||
return &k, nil
|
||||
return &key, nil
|
||||
}
|
||||
|
||||
// ListPreAuthKeys returns the list of PreAuthKeys for a namespace
|
||||
// ListPreAuthKeys returns the list of PreAuthKeys for a namespace.
|
||||
func (h *Headscale) ListPreAuthKeys(namespaceName string) ([]PreAuthKey, error) {
|
||||
n, err := h.GetNamespace(namespaceName)
|
||||
namespace, err := h.GetNamespace(namespaceName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
keys := []PreAuthKey{}
|
||||
if err := h.db.Preload("Namespace").Where(&PreAuthKey{NamespaceID: n.ID}).Find(&keys).Error; err != nil {
|
||||
if err := h.db.Preload("Namespace").Where(&PreAuthKey{NamespaceID: namespace.ID}).Find(&keys).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return keys, nil
|
||||
}
|
||||
|
||||
// GetPreAuthKey returns a PreAuthKey for a given key
|
||||
// GetPreAuthKey returns a PreAuthKey for a given key.
|
||||
func (h *Headscale) GetPreAuthKey(namespace string, key string) (*PreAuthKey, error) {
|
||||
pak, err := h.checkKeyValidity(key)
|
||||
if err != nil {
|
||||
|
@ -87,7 +88,7 @@ func (h *Headscale) GetPreAuthKey(namespace string, key string) (*PreAuthKey, er
|
|||
}
|
||||
|
||||
if pak.Namespace.Name != namespace {
|
||||
return nil, errors.New("Namespace mismatch")
|
||||
return nil, errNamespaceMismatch
|
||||
}
|
||||
|
||||
return pak, nil
|
||||
|
@ -95,32 +96,36 @@ func (h *Headscale) GetPreAuthKey(namespace string, key string) (*PreAuthKey, er
|
|||
|
||||
// DestroyPreAuthKey destroys a preauthkey. Returns error if the PreAuthKey
|
||||
// does not exist.
|
||||
func (h *Headscale) DestroyPreAuthKey(pak *PreAuthKey) error {
|
||||
if result := h.db.Unscoped().Delete(&pak); result.Error != nil {
|
||||
func (h *Headscale) DestroyPreAuthKey(pak PreAuthKey) error {
|
||||
if result := h.db.Unscoped().Delete(pak); result.Error != nil {
|
||||
return result.Error
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// MarkExpirePreAuthKey marks a PreAuthKey as expired
|
||||
// MarkExpirePreAuthKey marks a PreAuthKey as expired.
|
||||
func (h *Headscale) ExpirePreAuthKey(k *PreAuthKey) error {
|
||||
if err := h.db.Model(&k).Update("Expiration", time.Now()).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// checkKeyValidity does the heavy lifting for validation of the PreAuthKey coming from a node
|
||||
// If returns no error and a PreAuthKey, it can be used
|
||||
// If returns no error and a PreAuthKey, it can be used.
|
||||
func (h *Headscale) checkKeyValidity(k string) (*PreAuthKey, error) {
|
||||
pak := PreAuthKey{}
|
||||
if result := h.db.Preload("Namespace").First(&pak, "key = ?", k); errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return nil, errorAuthKeyNotFound
|
||||
if result := h.db.Preload("Namespace").First(&pak, "key = ?", k); errors.Is(
|
||||
result.Error,
|
||||
gorm.ErrRecordNotFound,
|
||||
) {
|
||||
return nil, errPreAuthKeyNotFound
|
||||
}
|
||||
|
||||
if pak.Expiration != nil && pak.Expiration.Before(time.Now()) {
|
||||
return nil, errorAuthKeyExpired
|
||||
return nil, errPreAuthKeyExpired
|
||||
}
|
||||
|
||||
if pak.Reusable || pak.Ephemeral { // we don't need to check if has been used before
|
||||
|
@ -145,13 +150,14 @@ func (h *Headscale) generateKey() (string, error) {
|
|||
if _, err := rand.Read(bytes); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return hex.EncodeToString(bytes), nil
|
||||
}
|
||||
|
||||
func (key *PreAuthKey) toProto() *v1.PreAuthKey {
|
||||
protoKey := v1.PreAuthKey{
|
||||
Namespace: key.Namespace.Name,
|
||||
Id: strconv.FormatUint(key.ID, 10),
|
||||
Id: strconv.FormatUint(key.ID, Base10),
|
||||
Key: key.Key,
|
||||
Ephemeral: key.Ephemeral,
|
||||
Reusable: key.Reusable,
|
||||
|
|
|
@ -7,189 +7,189 @@ import (
|
|||
)
|
||||
|
||||
func (*Suite) TestCreatePreAuthKey(c *check.C) {
|
||||
_, err := h.CreatePreAuthKey("bogus", true, false, nil)
|
||||
_, err := app.CreatePreAuthKey("bogus", true, false, nil)
|
||||
|
||||
c.Assert(err, check.NotNil)
|
||||
|
||||
n, err := h.CreateNamespace("test")
|
||||
namespace, err := app.CreateNamespace("test")
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
k, err := h.CreatePreAuthKey(n.Name, true, false, nil)
|
||||
key, err := app.CreatePreAuthKey(namespace.Name, true, false, nil)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
// Did we get a valid key?
|
||||
c.Assert(k.Key, check.NotNil)
|
||||
c.Assert(len(k.Key), check.Equals, 48)
|
||||
c.Assert(key.Key, check.NotNil)
|
||||
c.Assert(len(key.Key), check.Equals, 48)
|
||||
|
||||
// Make sure the Namespace association is populated
|
||||
c.Assert(k.Namespace.Name, check.Equals, n.Name)
|
||||
c.Assert(key.Namespace.Name, check.Equals, namespace.Name)
|
||||
|
||||
_, err = h.ListPreAuthKeys("bogus")
|
||||
_, err = app.ListPreAuthKeys("bogus")
|
||||
c.Assert(err, check.NotNil)
|
||||
|
||||
keys, err := h.ListPreAuthKeys(n.Name)
|
||||
keys, err := app.ListPreAuthKeys(namespace.Name)
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(len(keys), check.Equals, 1)
|
||||
|
||||
// Make sure the Namespace association is populated
|
||||
c.Assert((keys)[0].Namespace.Name, check.Equals, n.Name)
|
||||
c.Assert((keys)[0].Namespace.Name, check.Equals, namespace.Name)
|
||||
}
|
||||
|
||||
func (*Suite) TestExpiredPreAuthKey(c *check.C) {
|
||||
n, err := h.CreateNamespace("test2")
|
||||
namespace, err := app.CreateNamespace("test2")
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
now := time.Now()
|
||||
pak, err := h.CreatePreAuthKey(n.Name, true, false, &now)
|
||||
pak, err := app.CreatePreAuthKey(namespace.Name, true, false, &now)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
p, err := h.checkKeyValidity(pak.Key)
|
||||
c.Assert(err, check.Equals, errorAuthKeyExpired)
|
||||
c.Assert(p, check.IsNil)
|
||||
key, err := app.checkKeyValidity(pak.Key)
|
||||
c.Assert(err, check.Equals, errPreAuthKeyExpired)
|
||||
c.Assert(key, check.IsNil)
|
||||
}
|
||||
|
||||
func (*Suite) TestPreAuthKeyDoesNotExist(c *check.C) {
|
||||
p, err := h.checkKeyValidity("potatoKey")
|
||||
c.Assert(err, check.Equals, errorAuthKeyNotFound)
|
||||
c.Assert(p, check.IsNil)
|
||||
key, err := app.checkKeyValidity("potatoKey")
|
||||
c.Assert(err, check.Equals, errPreAuthKeyNotFound)
|
||||
c.Assert(key, check.IsNil)
|
||||
}
|
||||
|
||||
func (*Suite) TestValidateKeyOk(c *check.C) {
|
||||
n, err := h.CreateNamespace("test3")
|
||||
namespace, err := app.CreateNamespace("test3")
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
pak, err := h.CreatePreAuthKey(n.Name, true, false, nil)
|
||||
pak, err := app.CreatePreAuthKey(namespace.Name, true, false, nil)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
p, err := h.checkKeyValidity(pak.Key)
|
||||
key, err := app.checkKeyValidity(pak.Key)
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(p.ID, check.Equals, pak.ID)
|
||||
c.Assert(key.ID, check.Equals, pak.ID)
|
||||
}
|
||||
|
||||
func (*Suite) TestAlreadyUsedKey(c *check.C) {
|
||||
n, err := h.CreateNamespace("test4")
|
||||
namespace, err := app.CreateNamespace("test4")
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
pak, err := h.CreatePreAuthKey(n.Name, false, false, nil)
|
||||
pak, err := app.CreatePreAuthKey(namespace.Name, false, false, nil)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
m := Machine{
|
||||
machine := Machine{
|
||||
ID: 0,
|
||||
MachineKey: "foo",
|
||||
NodeKey: "bar",
|
||||
DiscoKey: "faa",
|
||||
Name: "testest",
|
||||
NamespaceID: n.ID,
|
||||
NamespaceID: namespace.ID,
|
||||
Registered: true,
|
||||
RegisterMethod: "authKey",
|
||||
AuthKeyID: uint(pak.ID),
|
||||
}
|
||||
h.db.Save(&m)
|
||||
app.db.Save(&machine)
|
||||
|
||||
p, err := h.checkKeyValidity(pak.Key)
|
||||
key, err := app.checkKeyValidity(pak.Key)
|
||||
c.Assert(err, check.Equals, errSingleUseAuthKeyHasBeenUsed)
|
||||
c.Assert(p, check.IsNil)
|
||||
c.Assert(key, check.IsNil)
|
||||
}
|
||||
|
||||
func (*Suite) TestReusableBeingUsedKey(c *check.C) {
|
||||
n, err := h.CreateNamespace("test5")
|
||||
namespace, err := app.CreateNamespace("test5")
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
pak, err := h.CreatePreAuthKey(n.Name, true, false, nil)
|
||||
pak, err := app.CreatePreAuthKey(namespace.Name, true, false, nil)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
m := Machine{
|
||||
machine := Machine{
|
||||
ID: 1,
|
||||
MachineKey: "foo",
|
||||
NodeKey: "bar",
|
||||
DiscoKey: "faa",
|
||||
Name: "testest",
|
||||
NamespaceID: n.ID,
|
||||
NamespaceID: namespace.ID,
|
||||
Registered: true,
|
||||
RegisterMethod: "authKey",
|
||||
AuthKeyID: uint(pak.ID),
|
||||
}
|
||||
h.db.Save(&m)
|
||||
app.db.Save(&machine)
|
||||
|
||||
p, err := h.checkKeyValidity(pak.Key)
|
||||
key, err := app.checkKeyValidity(pak.Key)
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(p.ID, check.Equals, pak.ID)
|
||||
c.Assert(key.ID, check.Equals, pak.ID)
|
||||
}
|
||||
|
||||
func (*Suite) TestNotReusableNotBeingUsedKey(c *check.C) {
|
||||
n, err := h.CreateNamespace("test6")
|
||||
namespace, err := app.CreateNamespace("test6")
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
pak, err := h.CreatePreAuthKey(n.Name, false, false, nil)
|
||||
pak, err := app.CreatePreAuthKey(namespace.Name, false, false, nil)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
p, err := h.checkKeyValidity(pak.Key)
|
||||
key, err := app.checkKeyValidity(pak.Key)
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(p.ID, check.Equals, pak.ID)
|
||||
c.Assert(key.ID, check.Equals, pak.ID)
|
||||
}
|
||||
|
||||
func (*Suite) TestEphemeralKey(c *check.C) {
|
||||
n, err := h.CreateNamespace("test7")
|
||||
namespace, err := app.CreateNamespace("test7")
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
pak, err := h.CreatePreAuthKey(n.Name, false, true, nil)
|
||||
pak, err := app.CreatePreAuthKey(namespace.Name, false, true, nil)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
now := time.Now()
|
||||
m := Machine{
|
||||
machine := Machine{
|
||||
ID: 0,
|
||||
MachineKey: "foo",
|
||||
NodeKey: "bar",
|
||||
DiscoKey: "faa",
|
||||
Name: "testest",
|
||||
NamespaceID: n.ID,
|
||||
NamespaceID: namespace.ID,
|
||||
Registered: true,
|
||||
RegisterMethod: "authKey",
|
||||
LastSeen: &now,
|
||||
AuthKeyID: uint(pak.ID),
|
||||
}
|
||||
h.db.Save(&m)
|
||||
app.db.Save(&machine)
|
||||
|
||||
_, err = h.checkKeyValidity(pak.Key)
|
||||
_, err = app.checkKeyValidity(pak.Key)
|
||||
// Ephemeral keys are by definition reusable
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
_, err = h.GetMachine("test7", "testest")
|
||||
_, err = app.GetMachine("test7", "testest")
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
h.expireEphemeralNodesWorker()
|
||||
app.expireEphemeralNodesWorker()
|
||||
|
||||
// The machine record should have been deleted
|
||||
_, err = h.GetMachine("test7", "testest")
|
||||
_, err = app.GetMachine("test7", "testest")
|
||||
c.Assert(err, check.NotNil)
|
||||
}
|
||||
|
||||
func (*Suite) TestExpirePreauthKey(c *check.C) {
|
||||
n, err := h.CreateNamespace("test3")
|
||||
namespace, err := app.CreateNamespace("test3")
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
pak, err := h.CreatePreAuthKey(n.Name, true, false, nil)
|
||||
pak, err := app.CreatePreAuthKey(namespace.Name, true, false, nil)
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(pak.Expiration, check.IsNil)
|
||||
|
||||
err = h.ExpirePreAuthKey(pak)
|
||||
err = app.ExpirePreAuthKey(pak)
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(pak.Expiration, check.NotNil)
|
||||
|
||||
p, err := h.checkKeyValidity(pak.Key)
|
||||
c.Assert(err, check.Equals, errorAuthKeyExpired)
|
||||
c.Assert(p, check.IsNil)
|
||||
key, err := app.checkKeyValidity(pak.Key)
|
||||
c.Assert(err, check.Equals, errPreAuthKeyExpired)
|
||||
c.Assert(key, check.IsNil)
|
||||
}
|
||||
|
||||
func (*Suite) TestNotReusableMarkedAsUsed(c *check.C) {
|
||||
n, err := h.CreateNamespace("test6")
|
||||
namespace, err := app.CreateNamespace("test6")
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
pak, err := h.CreatePreAuthKey(n.Name, false, false, nil)
|
||||
pak, err := app.CreatePreAuthKey(namespace.Name, false, false, nil)
|
||||
c.Assert(err, check.IsNil)
|
||||
pak.Used = true
|
||||
h.db.Save(&pak)
|
||||
app.db.Save(&pak)
|
||||
|
||||
_, err = h.checkKeyValidity(pak.Key)
|
||||
_, err = app.checkKeyValidity(pak.Key)
|
||||
c.Assert(err, check.Equals, errSingleUseAuthKeyHasBeenUsed)
|
||||
}
|
||||
|
|
|
@ -12,115 +12,115 @@ import "headscale/v1/routes.proto";
|
|||
|
||||
service HeadscaleService {
|
||||
// --- Namespace start ---
|
||||
rpc GetNamespace(GetNamespaceRequest) returns(GetNamespaceResponse) {
|
||||
option(google.api.http) = {
|
||||
get : "/api/v1/namespace/{name}"
|
||||
rpc GetNamespace(GetNamespaceRequest) returns (GetNamespaceResponse) {
|
||||
option (google.api.http) = {
|
||||
get: "/api/v1/namespace/{name}"
|
||||
};
|
||||
}
|
||||
|
||||
rpc CreateNamespace(CreateNamespaceRequest) returns(CreateNamespaceResponse) {
|
||||
option(google.api.http) = {
|
||||
post : "/api/v1/namespace"
|
||||
body : "*"
|
||||
rpc CreateNamespace(CreateNamespaceRequest) returns (CreateNamespaceResponse) {
|
||||
option (google.api.http) = {
|
||||
post: "/api/v1/namespace"
|
||||
body: "*"
|
||||
};
|
||||
}
|
||||
|
||||
rpc RenameNamespace(RenameNamespaceRequest) returns(RenameNamespaceResponse) {
|
||||
option(google.api.http) = {
|
||||
post : "/api/v1/namespace/{old_name}/rename/{new_name}"
|
||||
rpc RenameNamespace(RenameNamespaceRequest) returns (RenameNamespaceResponse) {
|
||||
option (google.api.http) = {
|
||||
post: "/api/v1/namespace/{old_name}/rename/{new_name}"
|
||||
};
|
||||
}
|
||||
|
||||
rpc DeleteNamespace(DeleteNamespaceRequest) returns(DeleteNamespaceResponse) {
|
||||
option(google.api.http) = {
|
||||
delete : "/api/v1/namespace/{name}"
|
||||
rpc DeleteNamespace(DeleteNamespaceRequest) returns (DeleteNamespaceResponse) {
|
||||
option (google.api.http) = {
|
||||
delete: "/api/v1/namespace/{name}"
|
||||
};
|
||||
}
|
||||
|
||||
rpc ListNamespaces(ListNamespacesRequest) returns(ListNamespacesResponse) {
|
||||
option(google.api.http) = {
|
||||
get : "/api/v1/namespace"
|
||||
rpc ListNamespaces(ListNamespacesRequest) returns (ListNamespacesResponse) {
|
||||
option (google.api.http) = {
|
||||
get: "/api/v1/namespace"
|
||||
};
|
||||
}
|
||||
// --- Namespace end ---
|
||||
|
||||
// --- PreAuthKeys start ---
|
||||
rpc CreatePreAuthKey(CreatePreAuthKeyRequest) returns(CreatePreAuthKeyResponse) {
|
||||
option(google.api.http) = {
|
||||
post : "/api/v1/preauthkey"
|
||||
body : "*"
|
||||
rpc CreatePreAuthKey(CreatePreAuthKeyRequest) returns (CreatePreAuthKeyResponse) {
|
||||
option (google.api.http) = {
|
||||
post: "/api/v1/preauthkey"
|
||||
body: "*"
|
||||
};
|
||||
}
|
||||
|
||||
rpc ExpirePreAuthKey(ExpirePreAuthKeyRequest) returns(ExpirePreAuthKeyResponse) {
|
||||
option(google.api.http) = {
|
||||
post : "/api/v1/preauthkey/expire"
|
||||
body : "*"
|
||||
rpc ExpirePreAuthKey(ExpirePreAuthKeyRequest) returns (ExpirePreAuthKeyResponse) {
|
||||
option (google.api.http) = {
|
||||
post: "/api/v1/preauthkey/expire"
|
||||
body: "*"
|
||||
};
|
||||
}
|
||||
|
||||
rpc ListPreAuthKeys(ListPreAuthKeysRequest) returns(ListPreAuthKeysResponse) {
|
||||
option(google.api.http) = {
|
||||
get : "/api/v1/preauthkey"
|
||||
rpc ListPreAuthKeys(ListPreAuthKeysRequest) returns (ListPreAuthKeysResponse) {
|
||||
option (google.api.http) = {
|
||||
get: "/api/v1/preauthkey"
|
||||
};
|
||||
}
|
||||
// --- PreAuthKeys end ---
|
||||
|
||||
// --- Machine start ---
|
||||
rpc DebugCreateMachine(DebugCreateMachineRequest) returns(DebugCreateMachineResponse) {
|
||||
option(google.api.http) = {
|
||||
post : "/api/v1/debug/machine"
|
||||
body : "*"
|
||||
rpc DebugCreateMachine(DebugCreateMachineRequest) returns (DebugCreateMachineResponse) {
|
||||
option (google.api.http) = {
|
||||
post: "/api/v1/debug/machine"
|
||||
body: "*"
|
||||
};
|
||||
}
|
||||
|
||||
rpc GetMachine(GetMachineRequest) returns(GetMachineResponse) {
|
||||
option(google.api.http) = {
|
||||
get : "/api/v1/machine/{machine_id}"
|
||||
rpc GetMachine(GetMachineRequest) returns (GetMachineResponse) {
|
||||
option (google.api.http) = {
|
||||
get: "/api/v1/machine/{machine_id}"
|
||||
};
|
||||
}
|
||||
|
||||
rpc RegisterMachine(RegisterMachineRequest) returns(RegisterMachineResponse) {
|
||||
option(google.api.http) = {
|
||||
post : "/api/v1/machine/register"
|
||||
rpc RegisterMachine(RegisterMachineRequest) returns (RegisterMachineResponse) {
|
||||
option (google.api.http) = {
|
||||
post: "/api/v1/machine/register"
|
||||
};
|
||||
}
|
||||
|
||||
rpc DeleteMachine(DeleteMachineRequest) returns(DeleteMachineResponse) {
|
||||
option(google.api.http) = {
|
||||
delete : "/api/v1/machine/{machine_id}"
|
||||
rpc DeleteMachine(DeleteMachineRequest) returns (DeleteMachineResponse) {
|
||||
option (google.api.http) = {
|
||||
delete: "/api/v1/machine/{machine_id}"
|
||||
};
|
||||
}
|
||||
|
||||
rpc ListMachines(ListMachinesRequest) returns(ListMachinesResponse) {
|
||||
option(google.api.http) = {
|
||||
get : "/api/v1/machine"
|
||||
rpc ListMachines(ListMachinesRequest) returns (ListMachinesResponse) {
|
||||
option (google.api.http) = {
|
||||
get: "/api/v1/machine"
|
||||
};
|
||||
}
|
||||
|
||||
rpc ShareMachine(ShareMachineRequest) returns(ShareMachineResponse) {
|
||||
option(google.api.http) = {
|
||||
post : "/api/v1/machine/{machine_id}/share/{namespace}"
|
||||
rpc ShareMachine(ShareMachineRequest) returns (ShareMachineResponse) {
|
||||
option (google.api.http) = {
|
||||
post: "/api/v1/machine/{machine_id}/share/{namespace}"
|
||||
};
|
||||
}
|
||||
|
||||
rpc UnshareMachine(UnshareMachineRequest) returns(UnshareMachineResponse) {
|
||||
option(google.api.http) = {
|
||||
post : "/api/v1/machine/{machine_id}/unshare/{namespace}"
|
||||
rpc UnshareMachine(UnshareMachineRequest) returns (UnshareMachineResponse) {
|
||||
option (google.api.http) = {
|
||||
post: "/api/v1/machine/{machine_id}/unshare/{namespace}"
|
||||
};
|
||||
}
|
||||
// --- Machine end ---
|
||||
|
||||
// --- Route start ---
|
||||
rpc GetMachineRoute(GetMachineRouteRequest) returns(GetMachineRouteResponse) {
|
||||
option(google.api.http) = {
|
||||
get : "/api/v1/machine/{machine_id}/routes"
|
||||
rpc GetMachineRoute(GetMachineRouteRequest) returns (GetMachineRouteResponse) {
|
||||
option (google.api.http) = {
|
||||
get: "/api/v1/machine/{machine_id}/routes"
|
||||
};
|
||||
}
|
||||
|
||||
rpc EnableMachineRoutes(EnableMachineRoutesRequest) returns(EnableMachineRoutesResponse) {
|
||||
option(google.api.http) = {
|
||||
post : "/api/v1/machine/{machine_id}/routes"
|
||||
rpc EnableMachineRoutes(EnableMachineRoutesRequest) returns (EnableMachineRoutesResponse) {
|
||||
option (google.api.http) = {
|
||||
post: "/api/v1/machine/{machine_id}/routes"
|
||||
};
|
||||
}
|
||||
// --- Route end ---
|
||||
|
|
55
routes.go
55
routes.go
|
@ -2,38 +2,48 @@ package headscale
|
|||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"gorm.io/datatypes"
|
||||
"inet.af/netaddr"
|
||||
)
|
||||
|
||||
const (
|
||||
errRouteIsNotAvailable = Error("route is not available")
|
||||
)
|
||||
|
||||
// Deprecated: use machine function instead
|
||||
// GetAdvertisedNodeRoutes returns the subnet routes advertised by a node (identified by
|
||||
// namespace and node name)
|
||||
func (h *Headscale) GetAdvertisedNodeRoutes(namespace string, nodeName string) (*[]netaddr.IPPrefix, error) {
|
||||
m, err := h.GetMachine(namespace, nodeName)
|
||||
// namespace and node name).
|
||||
func (h *Headscale) GetAdvertisedNodeRoutes(
|
||||
namespace string,
|
||||
nodeName string,
|
||||
) (*[]netaddr.IPPrefix, error) {
|
||||
machine, err := h.GetMachine(namespace, nodeName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
hostInfo, err := m.GetHostInfo()
|
||||
hostInfo, err := machine.GetHostInfo()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &hostInfo.RoutableIPs, nil
|
||||
}
|
||||
|
||||
// Deprecated: use machine function instead
|
||||
// GetEnabledNodeRoutes returns the subnet routes enabled by a node (identified by
|
||||
// namespace and node name)
|
||||
func (h *Headscale) GetEnabledNodeRoutes(namespace string, nodeName string) ([]netaddr.IPPrefix, error) {
|
||||
m, err := h.GetMachine(namespace, nodeName)
|
||||
// namespace and node name).
|
||||
func (h *Headscale) GetEnabledNodeRoutes(
|
||||
namespace string,
|
||||
nodeName string,
|
||||
) ([]netaddr.IPPrefix, error) {
|
||||
machine, err := h.GetMachine(namespace, nodeName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
data, err := m.EnabledRoutes.MarshalJSON()
|
||||
data, err := machine.EnabledRoutes.MarshalJSON()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -57,8 +67,12 @@ func (h *Headscale) GetEnabledNodeRoutes(namespace string, nodeName string) ([]n
|
|||
}
|
||||
|
||||
// Deprecated: use machine function instead
|
||||
// IsNodeRouteEnabled checks if a certain route has been enabled
|
||||
func (h *Headscale) IsNodeRouteEnabled(namespace string, nodeName string, routeStr string) bool {
|
||||
// IsNodeRouteEnabled checks if a certain route has been enabled.
|
||||
func (h *Headscale) IsNodeRouteEnabled(
|
||||
namespace string,
|
||||
nodeName string,
|
||||
routeStr string,
|
||||
) bool {
|
||||
route, err := netaddr.ParseIPPrefix(routeStr)
|
||||
if err != nil {
|
||||
return false
|
||||
|
@ -74,14 +88,19 @@ func (h *Headscale) IsNodeRouteEnabled(namespace string, nodeName string, routeS
|
|||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// Deprecated: use EnableRoute in machine.go
|
||||
// EnableNodeRoute enables a subnet route advertised by a node (identified by
|
||||
// namespace and node name)
|
||||
func (h *Headscale) EnableNodeRoute(namespace string, nodeName string, routeStr string) error {
|
||||
m, err := h.GetMachine(namespace, nodeName)
|
||||
// namespace and node name).
|
||||
func (h *Headscale) EnableNodeRoute(
|
||||
namespace string,
|
||||
nodeName string,
|
||||
routeStr string,
|
||||
) error {
|
||||
machine, err := h.GetMachine(namespace, nodeName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -113,7 +132,7 @@ func (h *Headscale) EnableNodeRoute(namespace string, nodeName string, routeStr
|
|||
}
|
||||
|
||||
if !available {
|
||||
return fmt.Errorf("route (%s) is not available on node %s", nodeName, routeStr)
|
||||
return errRouteIsNotAvailable
|
||||
}
|
||||
|
||||
routes, err := json.Marshal(enabledRoutes)
|
||||
|
@ -121,10 +140,10 @@ func (h *Headscale) EnableNodeRoute(namespace string, nodeName string, routeStr
|
|||
return err
|
||||
}
|
||||
|
||||
m.EnabledRoutes = datatypes.JSON(routes)
|
||||
h.db.Save(&m)
|
||||
machine.EnabledRoutes = datatypes.JSON(routes)
|
||||
h.db.Save(&machine)
|
||||
|
||||
err = h.RequestMapUpdates(m.NamespaceID)
|
||||
err = h.RequestMapUpdates(machine.NamespaceID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
|
@ -10,57 +10,60 @@ import (
|
|||
)
|
||||
|
||||
func (s *Suite) TestGetRoutes(c *check.C) {
|
||||
n, err := h.CreateNamespace("test")
|
||||
namespace, err := app.CreateNamespace("test")
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
pak, err := h.CreatePreAuthKey(n.Name, false, false, nil)
|
||||
pak, err := app.CreatePreAuthKey(namespace.Name, false, false, nil)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
_, err = h.GetMachine("test", "test_get_route_machine")
|
||||
_, err = app.GetMachine("test", "test_get_route_machine")
|
||||
c.Assert(err, check.NotNil)
|
||||
|
||||
route, err := netaddr.ParseIPPrefix("10.0.0.0/24")
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
hi := tailcfg.Hostinfo{
|
||||
hostInfo := tailcfg.Hostinfo{
|
||||
RoutableIPs: []netaddr.IPPrefix{route},
|
||||
}
|
||||
hostinfo, err := json.Marshal(hi)
|
||||
hostinfo, err := json.Marshal(hostInfo)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
m := Machine{
|
||||
machine := Machine{
|
||||
ID: 0,
|
||||
MachineKey: "foo",
|
||||
NodeKey: "bar",
|
||||
DiscoKey: "faa",
|
||||
Name: "test_get_route_machine",
|
||||
NamespaceID: n.ID,
|
||||
NamespaceID: namespace.ID,
|
||||
Registered: true,
|
||||
RegisterMethod: "authKey",
|
||||
AuthKeyID: uint(pak.ID),
|
||||
HostInfo: datatypes.JSON(hostinfo),
|
||||
}
|
||||
h.db.Save(&m)
|
||||
app.db.Save(&machine)
|
||||
|
||||
r, err := h.GetAdvertisedNodeRoutes("test", "test_get_route_machine")
|
||||
advertisedRoutes, err := app.GetAdvertisedNodeRoutes(
|
||||
"test",
|
||||
"test_get_route_machine",
|
||||
)
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(len(*r), check.Equals, 1)
|
||||
c.Assert(len(*advertisedRoutes), check.Equals, 1)
|
||||
|
||||
err = h.EnableNodeRoute("test", "test_get_route_machine", "192.168.0.0/24")
|
||||
err = app.EnableNodeRoute("test", "test_get_route_machine", "192.168.0.0/24")
|
||||
c.Assert(err, check.NotNil)
|
||||
|
||||
err = h.EnableNodeRoute("test", "test_get_route_machine", "10.0.0.0/24")
|
||||
err = app.EnableNodeRoute("test", "test_get_route_machine", "10.0.0.0/24")
|
||||
c.Assert(err, check.IsNil)
|
||||
}
|
||||
|
||||
func (s *Suite) TestGetEnableRoutes(c *check.C) {
|
||||
n, err := h.CreateNamespace("test")
|
||||
namespace, err := app.CreateNamespace("test")
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
pak, err := h.CreatePreAuthKey(n.Name, false, false, nil)
|
||||
pak, err := app.CreatePreAuthKey(namespace.Name, false, false, nil)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
_, err = h.GetMachine("test", "test_enable_route_machine")
|
||||
_, err = app.GetMachine("test", "test_enable_route_machine")
|
||||
c.Assert(err, check.NotNil)
|
||||
|
||||
route, err := netaddr.ParseIPPrefix(
|
||||
|
@ -73,56 +76,68 @@ func (s *Suite) TestGetEnableRoutes(c *check.C) {
|
|||
)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
hi := tailcfg.Hostinfo{
|
||||
hostInfo := tailcfg.Hostinfo{
|
||||
RoutableIPs: []netaddr.IPPrefix{route, route2},
|
||||
}
|
||||
hostinfo, err := json.Marshal(hi)
|
||||
hostinfo, err := json.Marshal(hostInfo)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
m := Machine{
|
||||
machine := Machine{
|
||||
ID: 0,
|
||||
MachineKey: "foo",
|
||||
NodeKey: "bar",
|
||||
DiscoKey: "faa",
|
||||
Name: "test_enable_route_machine",
|
||||
NamespaceID: n.ID,
|
||||
NamespaceID: namespace.ID,
|
||||
Registered: true,
|
||||
RegisterMethod: "authKey",
|
||||
AuthKeyID: uint(pak.ID),
|
||||
HostInfo: datatypes.JSON(hostinfo),
|
||||
}
|
||||
h.db.Save(&m)
|
||||
app.db.Save(&machine)
|
||||
|
||||
availableRoutes, err := h.GetAdvertisedNodeRoutes("test", "test_enable_route_machine")
|
||||
availableRoutes, err := app.GetAdvertisedNodeRoutes(
|
||||
"test",
|
||||
"test_enable_route_machine",
|
||||
)
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(len(*availableRoutes), check.Equals, 2)
|
||||
|
||||
enabledRoutes, err := h.GetEnabledNodeRoutes("test", "test_enable_route_machine")
|
||||
noEnabledRoutes, err := app.GetEnabledNodeRoutes(
|
||||
"test",
|
||||
"test_enable_route_machine",
|
||||
)
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(len(enabledRoutes), check.Equals, 0)
|
||||
c.Assert(len(noEnabledRoutes), check.Equals, 0)
|
||||
|
||||
err = h.EnableNodeRoute("test", "test_enable_route_machine", "192.168.0.0/24")
|
||||
err = app.EnableNodeRoute("test", "test_enable_route_machine", "192.168.0.0/24")
|
||||
c.Assert(err, check.NotNil)
|
||||
|
||||
err = h.EnableNodeRoute("test", "test_enable_route_machine", "10.0.0.0/24")
|
||||
err = app.EnableNodeRoute("test", "test_enable_route_machine", "10.0.0.0/24")
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
enabledRoutes1, err := h.GetEnabledNodeRoutes("test", "test_enable_route_machine")
|
||||
enabledRoutes, err := app.GetEnabledNodeRoutes("test", "test_enable_route_machine")
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(len(enabledRoutes1), check.Equals, 1)
|
||||
c.Assert(len(enabledRoutes), check.Equals, 1)
|
||||
|
||||
// Adding it twice will just let it pass through
|
||||
err = h.EnableNodeRoute("test", "test_enable_route_machine", "10.0.0.0/24")
|
||||
err = app.EnableNodeRoute("test", "test_enable_route_machine", "10.0.0.0/24")
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
enabledRoutes2, err := h.GetEnabledNodeRoutes("test", "test_enable_route_machine")
|
||||
enableRoutesAfterDoubleApply, err := app.GetEnabledNodeRoutes(
|
||||
"test",
|
||||
"test_enable_route_machine",
|
||||
)
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(len(enabledRoutes2), check.Equals, 1)
|
||||
c.Assert(len(enableRoutesAfterDoubleApply), check.Equals, 1)
|
||||
|
||||
err = h.EnableNodeRoute("test", "test_enable_route_machine", "150.0.10.0/25")
|
||||
err = app.EnableNodeRoute("test", "test_enable_route_machine", "150.0.10.0/25")
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
enabledRoutes3, err := h.GetEnabledNodeRoutes("test", "test_enable_route_machine")
|
||||
enabledRoutesWithAdditionalRoute, err := app.GetEnabledNodeRoutes(
|
||||
"test",
|
||||
"test_enable_route_machine",
|
||||
)
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(len(enabledRoutes3), check.Equals, 2)
|
||||
c.Assert(len(enabledRoutesWithAdditionalRoute), check.Equals, 2)
|
||||
}
|
||||
|
|
58
sharing.go
58
sharing.go
|
@ -2,11 +2,13 @@ package headscale
|
|||
|
||||
import "gorm.io/gorm"
|
||||
|
||||
const errorSameNamespace = Error("Destination namespace same as origin")
|
||||
const errorMachineAlreadyShared = Error("Node already shared to this namespace")
|
||||
const errorMachineNotShared = Error("Machine not shared to this namespace")
|
||||
const (
|
||||
errSameNamespace = Error("Destination namespace same as origin")
|
||||
errMachineAlreadyShared = Error("Node already shared to this namespace")
|
||||
errMachineNotShared = Error("Machine not shared to this namespace")
|
||||
)
|
||||
|
||||
// SharedMachine is a join table to support sharing nodes between namespaces
|
||||
// SharedMachine is a join table to support sharing nodes between namespaces.
|
||||
type SharedMachine struct {
|
||||
gorm.Model
|
||||
MachineID uint64
|
||||
|
@ -15,49 +17,57 @@ type SharedMachine struct {
|
|||
Namespace Namespace
|
||||
}
|
||||
|
||||
// AddSharedMachineToNamespace adds a machine as a shared node to a namespace
|
||||
func (h *Headscale) AddSharedMachineToNamespace(m *Machine, ns *Namespace) error {
|
||||
if m.NamespaceID == ns.ID {
|
||||
return errorSameNamespace
|
||||
// AddSharedMachineToNamespace adds a machine as a shared node to a namespace.
|
||||
func (h *Headscale) AddSharedMachineToNamespace(
|
||||
machine *Machine,
|
||||
namespace *Namespace,
|
||||
) error {
|
||||
if machine.NamespaceID == namespace.ID {
|
||||
return errSameNamespace
|
||||
}
|
||||
|
||||
sharedMachines := []SharedMachine{}
|
||||
if err := h.db.Where("machine_id = ? AND namespace_id = ?", m.ID, ns.ID).Find(&sharedMachines).Error; err != nil {
|
||||
if err := h.db.Where("machine_id = ? AND namespace_id = ?", machine.ID, namespace.ID).Find(&sharedMachines).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
if len(sharedMachines) > 0 {
|
||||
return errorMachineAlreadyShared
|
||||
return errMachineAlreadyShared
|
||||
}
|
||||
|
||||
sharedMachine := SharedMachine{
|
||||
MachineID: m.ID,
|
||||
Machine: *m,
|
||||
NamespaceID: ns.ID,
|
||||
Namespace: *ns,
|
||||
MachineID: machine.ID,
|
||||
Machine: *machine,
|
||||
NamespaceID: namespace.ID,
|
||||
Namespace: *namespace,
|
||||
}
|
||||
h.db.Save(&sharedMachine)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// RemoveSharedMachineFromNamespace removes a shared machine from a namespace
|
||||
func (h *Headscale) RemoveSharedMachineFromNamespace(m *Machine, ns *Namespace) error {
|
||||
if m.NamespaceID == ns.ID {
|
||||
// RemoveSharedMachineFromNamespace removes a shared machine from a namespace.
|
||||
func (h *Headscale) RemoveSharedMachineFromNamespace(
|
||||
machine *Machine,
|
||||
namespace *Namespace,
|
||||
) error {
|
||||
if machine.NamespaceID == namespace.ID {
|
||||
// Can't unshare from primary namespace
|
||||
return errorMachineNotShared
|
||||
return errMachineNotShared
|
||||
}
|
||||
|
||||
sharedMachine := SharedMachine{}
|
||||
result := h.db.Where("machine_id = ? AND namespace_id = ?", m.ID, ns.ID).Unscoped().Delete(&sharedMachine)
|
||||
result := h.db.Where("machine_id = ? AND namespace_id = ?", machine.ID, namespace.ID).
|
||||
Unscoped().
|
||||
Delete(&sharedMachine)
|
||||
if result.Error != nil {
|
||||
return result.Error
|
||||
}
|
||||
|
||||
if result.RowsAffected == 0 {
|
||||
return errorMachineNotShared
|
||||
return errMachineNotShared
|
||||
}
|
||||
|
||||
err := h.RequestMapUpdates(ns.ID)
|
||||
err := h.RequestMapUpdates(namespace.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -65,10 +75,10 @@ func (h *Headscale) RemoveSharedMachineFromNamespace(m *Machine, ns *Namespace)
|
|||
return nil
|
||||
}
|
||||
|
||||
// RemoveSharedMachineFromAllNamespaces removes a machine as a shared node from all namespaces
|
||||
func (h *Headscale) RemoveSharedMachineFromAllNamespaces(m *Machine) error {
|
||||
// RemoveSharedMachineFromAllNamespaces removes a machine as a shared node from all namespaces.
|
||||
func (h *Headscale) RemoveSharedMachineFromAllNamespaces(machine *Machine) error {
|
||||
sharedMachine := SharedMachine{}
|
||||
if result := h.db.Where("machine_id = ?", m.ID).Unscoped().Delete(&sharedMachine); result.Error != nil {
|
||||
if result := h.db.Where("machine_id = ?", machine.ID).Unscoped().Delete(&sharedMachine); result.Error != nil {
|
||||
return result.Error
|
||||
}
|
||||
|
||||
|
|
227
sharing_test.go
227
sharing_test.go
|
@ -4,45 +4,48 @@ import (
|
|||
"gopkg.in/check.v1"
|
||||
)
|
||||
|
||||
func CreateNodeNamespace(c *check.C, namespace, node, key, IP string) (*Namespace, *Machine) {
|
||||
n1, err := h.CreateNamespace(namespace)
|
||||
func CreateNodeNamespace(
|
||||
c *check.C,
|
||||
namespaceName, node, key, ip string,
|
||||
) (*Namespace, *Machine) {
|
||||
namespace, err := app.CreateNamespace(namespaceName)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
pak1, err := h.CreatePreAuthKey(n1.Name, false, false, nil)
|
||||
pak1, err := app.CreatePreAuthKey(namespace.Name, false, false, nil)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
_, err = h.GetMachine(n1.Name, node)
|
||||
_, err = app.GetMachine(namespace.Name, node)
|
||||
c.Assert(err, check.NotNil)
|
||||
|
||||
m1 := &Machine{
|
||||
machine := &Machine{
|
||||
ID: 0,
|
||||
MachineKey: key,
|
||||
NodeKey: key,
|
||||
DiscoKey: key,
|
||||
Name: node,
|
||||
NamespaceID: n1.ID,
|
||||
NamespaceID: namespace.ID,
|
||||
Registered: true,
|
||||
RegisterMethod: "authKey",
|
||||
IPAddress: IP,
|
||||
IPAddress: ip,
|
||||
AuthKeyID: uint(pak1.ID),
|
||||
}
|
||||
h.db.Save(m1)
|
||||
app.db.Save(machine)
|
||||
|
||||
_, err = h.GetMachine(n1.Name, m1.Name)
|
||||
_, err = app.GetMachine(namespace.Name, machine.Name)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
return n1, m1
|
||||
return namespace, machine
|
||||
}
|
||||
|
||||
func (s *Suite) TestBasicSharedNodesInNamespace(c *check.C) {
|
||||
n1, m1 := CreateNodeNamespace(
|
||||
namespace1, machine1 := CreateNodeNamespace(
|
||||
c,
|
||||
"shared1",
|
||||
"test_get_shared_nodes_1",
|
||||
"686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66",
|
||||
"100.64.0.1",
|
||||
)
|
||||
_, m2 := CreateNodeNamespace(
|
||||
_, machine2 := CreateNodeNamespace(
|
||||
c,
|
||||
"shared2",
|
||||
"test_get_shared_nodes_2",
|
||||
|
@ -50,21 +53,21 @@ func (s *Suite) TestBasicSharedNodesInNamespace(c *check.C) {
|
|||
"100.64.0.2",
|
||||
)
|
||||
|
||||
p1s, err := h.getPeers(m1)
|
||||
peersOfMachine1BeforeShared, err := app.getPeers(machine1)
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(len(p1s), check.Equals, 0)
|
||||
c.Assert(len(peersOfMachine1BeforeShared), check.Equals, 0)
|
||||
|
||||
err = h.AddSharedMachineToNamespace(m2, n1)
|
||||
err = app.AddSharedMachineToNamespace(machine2, namespace1)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
p1sAfter, err := h.getPeers(m1)
|
||||
peersOfMachine1AfterShared, err := app.getPeers(machine1)
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(len(p1sAfter), check.Equals, 1)
|
||||
c.Assert(p1sAfter[0].ID, check.Equals, m2.ID)
|
||||
c.Assert(len(peersOfMachine1AfterShared), check.Equals, 1)
|
||||
c.Assert(peersOfMachine1AfterShared[0].ID, check.Equals, machine2.ID)
|
||||
}
|
||||
|
||||
func (s *Suite) TestSameNamespace(c *check.C) {
|
||||
n1, m1 := CreateNodeNamespace(
|
||||
namespace1, machine1 := CreateNodeNamespace(
|
||||
c,
|
||||
"shared1",
|
||||
"test_get_shared_nodes_1",
|
||||
|
@ -72,23 +75,23 @@ func (s *Suite) TestSameNamespace(c *check.C) {
|
|||
"100.64.0.1",
|
||||
)
|
||||
|
||||
p1s, err := h.getPeers(m1)
|
||||
peersOfMachine1BeforeShare, err := app.getPeers(machine1)
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(len(p1s), check.Equals, 0)
|
||||
c.Assert(len(peersOfMachine1BeforeShare), check.Equals, 0)
|
||||
|
||||
err = h.AddSharedMachineToNamespace(m1, n1)
|
||||
c.Assert(err, check.Equals, errorSameNamespace)
|
||||
err = app.AddSharedMachineToNamespace(machine1, namespace1)
|
||||
c.Assert(err, check.Equals, errSameNamespace)
|
||||
}
|
||||
|
||||
func (s *Suite) TestUnshare(c *check.C) {
|
||||
n1, m1 := CreateNodeNamespace(
|
||||
namespace1, machine1 := CreateNodeNamespace(
|
||||
c,
|
||||
"shared1",
|
||||
"test_unshare_1",
|
||||
"686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66",
|
||||
"100.64.0.1",
|
||||
)
|
||||
_, m2 := CreateNodeNamespace(
|
||||
_, machine2 := CreateNodeNamespace(
|
||||
c,
|
||||
"shared2",
|
||||
"test_unshare_2",
|
||||
|
@ -96,40 +99,40 @@ func (s *Suite) TestUnshare(c *check.C) {
|
|||
"100.64.0.2",
|
||||
)
|
||||
|
||||
p1s, err := h.getPeers(m1)
|
||||
peersOfMachine1BeforeShare, err := app.getPeers(machine1)
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(len(p1s), check.Equals, 0)
|
||||
c.Assert(len(peersOfMachine1BeforeShare), check.Equals, 0)
|
||||
|
||||
err = h.AddSharedMachineToNamespace(m2, n1)
|
||||
err = app.AddSharedMachineToNamespace(machine2, namespace1)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
p1s, err = h.getShared(m1)
|
||||
peersOfMachine1BeforeShare, err = app.getShared(machine1)
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(len(p1s), check.Equals, 1)
|
||||
c.Assert(len(peersOfMachine1BeforeShare), check.Equals, 1)
|
||||
|
||||
err = h.RemoveSharedMachineFromNamespace(m2, n1)
|
||||
err = app.RemoveSharedMachineFromNamespace(machine2, namespace1)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
p1s, err = h.getShared(m1)
|
||||
peersOfMachine1BeforeShare, err = app.getShared(machine1)
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(len(p1s), check.Equals, 0)
|
||||
c.Assert(len(peersOfMachine1BeforeShare), check.Equals, 0)
|
||||
|
||||
err = h.RemoveSharedMachineFromNamespace(m2, n1)
|
||||
c.Assert(err, check.Equals, errorMachineNotShared)
|
||||
err = app.RemoveSharedMachineFromNamespace(machine2, namespace1)
|
||||
c.Assert(err, check.Equals, errMachineNotShared)
|
||||
|
||||
err = h.RemoveSharedMachineFromNamespace(m1, n1)
|
||||
c.Assert(err, check.Equals, errorMachineNotShared)
|
||||
err = app.RemoveSharedMachineFromNamespace(machine1, namespace1)
|
||||
c.Assert(err, check.Equals, errMachineNotShared)
|
||||
}
|
||||
|
||||
func (s *Suite) TestAlreadyShared(c *check.C) {
|
||||
n1, m1 := CreateNodeNamespace(
|
||||
namespace1, machine1 := CreateNodeNamespace(
|
||||
c,
|
||||
"shared1",
|
||||
"test_get_shared_nodes_1",
|
||||
"686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66",
|
||||
"100.64.0.1",
|
||||
)
|
||||
_, m2 := CreateNodeNamespace(
|
||||
_, machine2 := CreateNodeNamespace(
|
||||
c,
|
||||
"shared2",
|
||||
"test_get_shared_nodes_2",
|
||||
|
@ -137,25 +140,25 @@ func (s *Suite) TestAlreadyShared(c *check.C) {
|
|||
"100.64.0.2",
|
||||
)
|
||||
|
||||
p1s, err := h.getPeers(m1)
|
||||
peersOfMachine1BeforeShare, err := app.getPeers(machine1)
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(len(p1s), check.Equals, 0)
|
||||
c.Assert(len(peersOfMachine1BeforeShare), check.Equals, 0)
|
||||
|
||||
err = h.AddSharedMachineToNamespace(m2, n1)
|
||||
err = app.AddSharedMachineToNamespace(machine2, namespace1)
|
||||
c.Assert(err, check.IsNil)
|
||||
err = h.AddSharedMachineToNamespace(m2, n1)
|
||||
c.Assert(err, check.Equals, errorMachineAlreadyShared)
|
||||
err = app.AddSharedMachineToNamespace(machine2, namespace1)
|
||||
c.Assert(err, check.Equals, errMachineAlreadyShared)
|
||||
}
|
||||
|
||||
func (s *Suite) TestDoNotIncludeRoutesOnShared(c *check.C) {
|
||||
n1, m1 := CreateNodeNamespace(
|
||||
namespace1, machine1 := CreateNodeNamespace(
|
||||
c,
|
||||
"shared1",
|
||||
"test_get_shared_nodes_1",
|
||||
"686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66",
|
||||
"100.64.0.1",
|
||||
)
|
||||
_, m2 := CreateNodeNamespace(
|
||||
_, machine2 := CreateNodeNamespace(
|
||||
c,
|
||||
"shared2",
|
||||
"test_get_shared_nodes_2",
|
||||
|
@ -163,35 +166,35 @@ func (s *Suite) TestDoNotIncludeRoutesOnShared(c *check.C) {
|
|||
"100.64.0.2",
|
||||
)
|
||||
|
||||
p1s, err := h.getPeers(m1)
|
||||
peersOfMachine1BeforeShare, err := app.getPeers(machine1)
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(len(p1s), check.Equals, 0)
|
||||
c.Assert(len(peersOfMachine1BeforeShare), check.Equals, 0)
|
||||
|
||||
err = h.AddSharedMachineToNamespace(m2, n1)
|
||||
err = app.AddSharedMachineToNamespace(machine2, namespace1)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
p1sAfter, err := h.getPeers(m1)
|
||||
peersOfMachine1AfterShare, err := app.getPeers(machine1)
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(len(p1sAfter), check.Equals, 1)
|
||||
c.Assert(p1sAfter[0].Name, check.Equals, "test_get_shared_nodes_2")
|
||||
c.Assert(len(peersOfMachine1AfterShare), check.Equals, 1)
|
||||
c.Assert(peersOfMachine1AfterShare[0].Name, check.Equals, "test_get_shared_nodes_2")
|
||||
}
|
||||
|
||||
func (s *Suite) TestComplexSharingAcrossNamespaces(c *check.C) {
|
||||
n1, m1 := CreateNodeNamespace(
|
||||
namespace1, machine1 := CreateNodeNamespace(
|
||||
c,
|
||||
"shared1",
|
||||
"test_get_shared_nodes_1",
|
||||
"686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66",
|
||||
"100.64.0.1",
|
||||
)
|
||||
_, m2 := CreateNodeNamespace(
|
||||
_, machine2 := CreateNodeNamespace(
|
||||
c,
|
||||
"shared2",
|
||||
"test_get_shared_nodes_2",
|
||||
"dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
||||
"100.64.0.2",
|
||||
)
|
||||
_, m3 := CreateNodeNamespace(
|
||||
_, machine3 := CreateNodeNamespace(
|
||||
c,
|
||||
"shared3",
|
||||
"test_get_shared_nodes_3",
|
||||
|
@ -199,76 +202,80 @@ func (s *Suite) TestComplexSharingAcrossNamespaces(c *check.C) {
|
|||
"100.64.0.3",
|
||||
)
|
||||
|
||||
pak4, err := h.CreatePreAuthKey(n1.Name, false, false, nil)
|
||||
pak4, err := app.CreatePreAuthKey(namespace1.Name, false, false, nil)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
m4 := &Machine{
|
||||
machine4 := &Machine{
|
||||
ID: 4,
|
||||
MachineKey: "4c3e07c3ecd40e9c945bb6797557c451850691c0409740578325e17009dd298f",
|
||||
NodeKey: "4c3e07c3ecd40e9c945bb6797557c451850691c0409740578325e17009dd298f",
|
||||
DiscoKey: "4c3e07c3ecd40e9c945bb6797557c451850691c0409740578325e17009dd298f",
|
||||
Name: "test_get_shared_nodes_4",
|
||||
NamespaceID: n1.ID,
|
||||
NamespaceID: namespace1.ID,
|
||||
Registered: true,
|
||||
RegisterMethod: "authKey",
|
||||
IPAddress: "100.64.0.4",
|
||||
AuthKeyID: uint(pak4.ID),
|
||||
}
|
||||
h.db.Save(m4)
|
||||
app.db.Save(machine4)
|
||||
|
||||
_, err = h.GetMachine(n1.Name, m4.Name)
|
||||
_, err = app.GetMachine(namespace1.Name, machine4.Name)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
p1s, err := h.getPeers(m1)
|
||||
peersOfMachine1BeforeShare, err := app.getPeers(machine1)
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(len(p1s), check.Equals, 1) // node1 can see node4
|
||||
c.Assert(p1s[0].Name, check.Equals, m4.Name)
|
||||
c.Assert(len(peersOfMachine1BeforeShare), check.Equals, 1) // node1 can see node4
|
||||
c.Assert(peersOfMachine1BeforeShare[0].Name, check.Equals, machine4.Name)
|
||||
|
||||
err = h.AddSharedMachineToNamespace(m2, n1)
|
||||
err = app.AddSharedMachineToNamespace(machine2, namespace1)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
p1sAfter, err := h.getPeers(m1)
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(len(p1sAfter), check.Equals, 2) // node1 can see node2 (shared) and node4 (same namespace)
|
||||
c.Assert(p1sAfter[0].Name, check.Equals, m2.Name)
|
||||
c.Assert(p1sAfter[1].Name, check.Equals, m4.Name)
|
||||
|
||||
node1shared, err := h.getShared(m1)
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(len(node1shared), check.Equals, 1) // node1 can see node2 as shared
|
||||
c.Assert(node1shared[0].Name, check.Equals, m2.Name)
|
||||
|
||||
pAlone, err := h.getPeers(m3)
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(len(pAlone), check.Equals, 0) // node3 is alone
|
||||
|
||||
pSharedTo, err := h.getPeers(m2)
|
||||
peersOfMachine1AfterShare, err := app.getPeers(machine1)
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(
|
||||
len(pSharedTo),
|
||||
len(peersOfMachine1AfterShare),
|
||||
check.Equals,
|
||||
2,
|
||||
) // node1 can see node2 (shared) and node4 (same namespace)
|
||||
c.Assert(peersOfMachine1AfterShare[0].Name, check.Equals, machine2.Name)
|
||||
c.Assert(peersOfMachine1AfterShare[1].Name, check.Equals, machine4.Name)
|
||||
|
||||
sharedOfMachine1, err := app.getShared(machine1)
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(len(sharedOfMachine1), check.Equals, 1) // node1 can see node2 as shared
|
||||
c.Assert(sharedOfMachine1[0].Name, check.Equals, machine2.Name)
|
||||
|
||||
peersOfMachine3, err := app.getPeers(machine3)
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(len(peersOfMachine3), check.Equals, 0) // node3 is alone
|
||||
|
||||
peersOfMachine2, err := app.getPeers(machine2)
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(
|
||||
len(peersOfMachine2),
|
||||
check.Equals,
|
||||
2,
|
||||
) // node2 should see node1 (sharedTo) and node4 (sharedTo), as is shared in namespace1
|
||||
c.Assert(pSharedTo[0].Name, check.Equals, m1.Name)
|
||||
c.Assert(pSharedTo[1].Name, check.Equals, m4.Name)
|
||||
c.Assert(peersOfMachine2[0].Name, check.Equals, machine1.Name)
|
||||
c.Assert(peersOfMachine2[1].Name, check.Equals, machine4.Name)
|
||||
}
|
||||
|
||||
func (s *Suite) TestDeleteSharedMachine(c *check.C) {
|
||||
n1, m1 := CreateNodeNamespace(
|
||||
namespace1, machine1 := CreateNodeNamespace(
|
||||
c,
|
||||
"shared1",
|
||||
"test_get_shared_nodes_1",
|
||||
"686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66",
|
||||
"100.64.0.1",
|
||||
)
|
||||
_, m2 := CreateNodeNamespace(
|
||||
_, machine2 := CreateNodeNamespace(
|
||||
c,
|
||||
"shared2",
|
||||
"test_get_shared_nodes_2",
|
||||
"dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
||||
"100.64.0.2",
|
||||
)
|
||||
_, m3 := CreateNodeNamespace(
|
||||
_, machine3 := CreateNodeNamespace(
|
||||
c,
|
||||
"shared3",
|
||||
"test_get_shared_nodes_3",
|
||||
|
@ -276,56 +283,58 @@ func (s *Suite) TestDeleteSharedMachine(c *check.C) {
|
|||
"100.64.0.3",
|
||||
)
|
||||
|
||||
pak4n1, err := h.CreatePreAuthKey(n1.Name, false, false, nil)
|
||||
pak4n1, err := app.CreatePreAuthKey(namespace1.Name, false, false, nil)
|
||||
c.Assert(err, check.IsNil)
|
||||
m4 := &Machine{
|
||||
machine4 := &Machine{
|
||||
ID: 4,
|
||||
MachineKey: "4c3e07c3ecd40e9c945bb6797557c451850691c0409740578325e17009dd298f",
|
||||
NodeKey: "4c3e07c3ecd40e9c945bb6797557c451850691c0409740578325e17009dd298f",
|
||||
DiscoKey: "4c3e07c3ecd40e9c945bb6797557c451850691c0409740578325e17009dd298f",
|
||||
Name: "test_get_shared_nodes_4",
|
||||
NamespaceID: n1.ID,
|
||||
NamespaceID: namespace1.ID,
|
||||
Registered: true,
|
||||
RegisterMethod: "authKey",
|
||||
IPAddress: "100.64.0.4",
|
||||
AuthKeyID: uint(pak4n1.ID),
|
||||
}
|
||||
h.db.Save(m4)
|
||||
app.db.Save(machine4)
|
||||
|
||||
_, err = h.GetMachine(n1.Name, m4.Name)
|
||||
_, err = app.GetMachine(namespace1.Name, machine4.Name)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
p1s, err := h.getPeers(m1)
|
||||
peersOfMachine1BeforeShare, err := app.getPeers(machine1)
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(len(p1s), check.Equals, 1) // nodes 1 and 4
|
||||
c.Assert(p1s[0].Name, check.Equals, m4.Name)
|
||||
c.Assert(len(peersOfMachine1BeforeShare), check.Equals, 1) // nodes 1 and 4
|
||||
c.Assert(peersOfMachine1BeforeShare[0].Name, check.Equals, machine4.Name)
|
||||
|
||||
err = h.AddSharedMachineToNamespace(m2, n1)
|
||||
err = app.AddSharedMachineToNamespace(machine2, namespace1)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
p1sAfter, err := h.getPeers(m1)
|
||||
peersOfMachine1AfterShare, err := app.getPeers(machine1)
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(len(p1sAfter), check.Equals, 2) // nodes 1, 2, 4
|
||||
c.Assert(p1sAfter[0].Name, check.Equals, m2.Name)
|
||||
c.Assert(p1sAfter[1].Name, check.Equals, m4.Name)
|
||||
c.Assert(len(peersOfMachine1AfterShare), check.Equals, 2) // nodes 1, 2, 4
|
||||
c.Assert(peersOfMachine1AfterShare[0].Name, check.Equals, machine2.Name)
|
||||
c.Assert(peersOfMachine1AfterShare[1].Name, check.Equals, machine4.Name)
|
||||
|
||||
node1shared, err := h.getShared(m1)
|
||||
sharedOfMachine1, err := app.getShared(machine1)
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(len(node1shared), check.Equals, 1) // nodes 1, 2, 4
|
||||
c.Assert(node1shared[0].Name, check.Equals, m2.Name)
|
||||
c.Assert(len(sharedOfMachine1), check.Equals, 1) // nodes 1, 2, 4
|
||||
c.Assert(sharedOfMachine1[0].Name, check.Equals, machine2.Name)
|
||||
|
||||
pAlone, err := h.getPeers(m3)
|
||||
peersOfMachine3, err := app.getPeers(machine3)
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(len(pAlone), check.Equals, 0) // node 3 is alone
|
||||
c.Assert(len(peersOfMachine3), check.Equals, 0) // node 3 is alone
|
||||
|
||||
sharedMachines, err := h.ListSharedMachinesInNamespace(n1.Name)
|
||||
sharedMachinesInNamespace1, err := app.ListSharedMachinesInNamespace(
|
||||
namespace1.Name,
|
||||
)
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(len(sharedMachines), check.Equals, 1)
|
||||
c.Assert(len(sharedMachinesInNamespace1), check.Equals, 1)
|
||||
|
||||
err = h.DeleteMachine(m2)
|
||||
err = app.DeleteMachine(machine2)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
sharedMachines, err = h.ListSharedMachinesInNamespace(n1.Name)
|
||||
sharedMachinesInNamespace1, err = app.ListSharedMachinesInNamespace(namespace1.Name)
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(len(sharedMachines), check.Equals, 0)
|
||||
c.Assert(len(sharedMachinesInNamespace1), check.Equals, 0)
|
||||
}
|
||||
|
|
22
swagger.go
22
swagger.go
|
@ -6,16 +6,15 @@ import (
|
|||
"net/http"
|
||||
"text/template"
|
||||
|
||||
"github.com/rs/zerolog/log"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
//go:embed gen/openapiv2/headscale/v1/headscale.swagger.json
|
||||
var apiV1JSON []byte
|
||||
|
||||
func SwaggerUI(c *gin.Context) {
|
||||
t := template.Must(template.New("swagger").Parse(`
|
||||
func SwaggerUI(ctx *gin.Context) {
|
||||
swaggerTemplate := template.Must(template.New("swagger").Parse(`
|
||||
<html>
|
||||
<head>
|
||||
<link rel="stylesheet" type="text/css" href="https://unpkg.com/swagger-ui-dist@3/swagger-ui.css">
|
||||
|
@ -48,18 +47,23 @@ func SwaggerUI(c *gin.Context) {
|
|||
</html>`))
|
||||
|
||||
var payload bytes.Buffer
|
||||
if err := t.Execute(&payload, struct{}{}); err != nil {
|
||||
if err := swaggerTemplate.Execute(&payload, struct{}{}); err != nil {
|
||||
log.Error().
|
||||
Caller().
|
||||
Err(err).
|
||||
Msg("Could not render Swagger")
|
||||
c.Data(http.StatusInternalServerError, "text/html; charset=utf-8", []byte("Could not render Swagger"))
|
||||
ctx.Data(
|
||||
http.StatusInternalServerError,
|
||||
"text/html; charset=utf-8",
|
||||
[]byte("Could not render Swagger"),
|
||||
)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
c.Data(http.StatusOK, "text/html; charset=utf-8", payload.Bytes())
|
||||
ctx.Data(http.StatusOK, "text/html; charset=utf-8", payload.Bytes())
|
||||
}
|
||||
|
||||
func SwaggerAPIv1(c *gin.Context) {
|
||||
c.Data(http.StatusOK, "application/json; charset=utf-8", apiV1JSON)
|
||||
func SwaggerAPIv1(ctx *gin.Context) {
|
||||
ctx.Data(http.StatusOK, "application/json; charset=utf-8", apiV1JSON)
|
||||
}
|
||||
|
|
57
utils.go
57
utils.go
|
@ -20,31 +20,48 @@ import (
|
|||
"tailscale.com/types/wgkey"
|
||||
)
|
||||
|
||||
const (
|
||||
errCannotDecryptReponse = Error("cannot decrypt response")
|
||||
errResponseMissingNonce = Error("response missing nonce")
|
||||
errCouldNotAllocateIP = Error("could not find any suitable IP")
|
||||
)
|
||||
|
||||
// Error is used to compare errors as per https://dave.cheney.net/2016/04/07/constant-errors
|
||||
type Error string
|
||||
|
||||
func (e Error) Error() string { return string(e) }
|
||||
|
||||
func decode(msg []byte, v interface{}, pubKey *wgkey.Key, privKey *wgkey.Private) error {
|
||||
func decode(
|
||||
msg []byte,
|
||||
v interface{},
|
||||
pubKey *wgkey.Key,
|
||||
privKey *wgkey.Private,
|
||||
) error {
|
||||
return decodeMsg(msg, v, pubKey, privKey)
|
||||
}
|
||||
|
||||
func decodeMsg(msg []byte, v interface{}, pubKey *wgkey.Key, privKey *wgkey.Private) error {
|
||||
func decodeMsg(
|
||||
msg []byte,
|
||||
output interface{},
|
||||
pubKey *wgkey.Key,
|
||||
privKey *wgkey.Private,
|
||||
) error {
|
||||
decrypted, err := decryptMsg(msg, pubKey, privKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// fmt.Println(string(decrypted))
|
||||
if err := json.Unmarshal(decrypted, v); err != nil {
|
||||
return fmt.Errorf("response: %v", err)
|
||||
if err := json.Unmarshal(decrypted, output); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func decryptMsg(msg []byte, pubKey *wgkey.Key, privKey *wgkey.Private) ([]byte, error) {
|
||||
var nonce [24]byte
|
||||
if len(msg) < len(nonce)+1 {
|
||||
return nil, fmt.Errorf("response missing nonce, len=%d", len(msg))
|
||||
return nil, errResponseMissingNonce
|
||||
}
|
||||
copy(nonce[:], msg)
|
||||
msg = msg[len(nonce):]
|
||||
|
@ -52,8 +69,9 @@ func decryptMsg(msg []byte, pubKey *wgkey.Key, privKey *wgkey.Private) ([]byte,
|
|||
pub, pri := (*[32]byte)(pubKey), (*[32]byte)(privKey)
|
||||
decrypted, ok := box.Open(nil, msg, &nonce, pub, pri)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("cannot decrypt response")
|
||||
return nil, errCannotDecryptReponse
|
||||
}
|
||||
|
||||
return decrypted, nil
|
||||
}
|
||||
|
||||
|
@ -66,13 +84,18 @@ func encode(v interface{}, pubKey *wgkey.Key, privKey *wgkey.Private) ([]byte, e
|
|||
return encodeMsg(b, pubKey, privKey)
|
||||
}
|
||||
|
||||
func encodeMsg(b []byte, pubKey *wgkey.Key, privKey *wgkey.Private) ([]byte, error) {
|
||||
func encodeMsg(
|
||||
payload []byte,
|
||||
pubKey *wgkey.Key,
|
||||
privKey *wgkey.Private,
|
||||
) ([]byte, error) {
|
||||
var nonce [24]byte
|
||||
if _, err := io.ReadFull(rand.Reader, nonce[:]); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
pub, pri := (*[32]byte)(pubKey), (*[32]byte)(privKey)
|
||||
msg := box.Seal(nonce[:], b, &nonce, pub, pri)
|
||||
msg := box.Seal(nonce[:], payload, &nonce, pub, pri)
|
||||
|
||||
return msg, nil
|
||||
}
|
||||
|
||||
|
@ -89,7 +112,7 @@ func (h *Headscale) getAvailableIP() (*netaddr.IP, error) {
|
|||
|
||||
for {
|
||||
if !ipPrefix.Contains(ip) {
|
||||
return nil, fmt.Errorf("could not find any suitable IP in %s", ipPrefix)
|
||||
return nil, errCouldNotAllocateIP
|
||||
}
|
||||
|
||||
// Some OS (including Linux) does not like when IPs ends with 0 or 255, which
|
||||
|
@ -98,13 +121,14 @@ func (h *Headscale) getAvailableIP() (*netaddr.IP, error) {
|
|||
ipRaw := ip.As4()
|
||||
if ipRaw[3] == 0 || ipRaw[3] == 255 {
|
||||
ip = ip.Next()
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
if ip.IsZero() &&
|
||||
ip.IsLoopback() {
|
||||
|
||||
ip = ip.Next()
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
|
@ -125,7 +149,7 @@ func (h *Headscale) getUsedIPs() ([]netaddr.IP, error) {
|
|||
if addr != "" {
|
||||
ip, err := netaddr.ParseIP(addr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse ip from database, %w", err)
|
||||
return nil, fmt.Errorf("failed to parse ip from database: %w", err)
|
||||
}
|
||||
|
||||
ips[index] = ip
|
||||
|
@ -156,11 +180,16 @@ func tailNodesToString(nodes []*tailcfg.Node) string {
|
|||
}
|
||||
|
||||
func tailMapResponseToString(resp tailcfg.MapResponse) string {
|
||||
return fmt.Sprintf("{ Node: %s, Peers: %s }", resp.Node.Name, tailNodesToString(resp.Peers))
|
||||
return fmt.Sprintf(
|
||||
"{ Node: %s, Peers: %s }",
|
||||
resp.Node.Name,
|
||||
tailNodesToString(resp.Peers),
|
||||
)
|
||||
}
|
||||
|
||||
func GrpcSocketDialer(ctx context.Context, addr string) (net.Conn, error) {
|
||||
var d net.Dialer
|
||||
|
||||
return d.DialContext(ctx, "unix", addr)
|
||||
}
|
||||
|
||||
|
@ -174,7 +203,7 @@ func ipPrefixToString(prefixes []netaddr.IPPrefix) []string {
|
|||
return result
|
||||
}
|
||||
|
||||
func stringToIpPrefix(prefixes []string) ([]netaddr.IPPrefix, error) {
|
||||
func stringToIPPrefix(prefixes []string) ([]netaddr.IPPrefix, error) {
|
||||
result := make([]netaddr.IPPrefix, len(prefixes))
|
||||
|
||||
for index, prefixStr := range prefixes {
|
||||
|
@ -189,7 +218,7 @@ func stringToIpPrefix(prefixes []string) ([]netaddr.IPPrefix, error) {
|
|||
return result, nil
|
||||
}
|
||||
|
||||
func containsIpPrefix(prefixes []netaddr.IPPrefix, prefix netaddr.IPPrefix) bool {
|
||||
func containsIPPrefix(prefixes []netaddr.IPPrefix, prefix netaddr.IPPrefix) bool {
|
||||
for _, p := range prefixes {
|
||||
if prefix == p {
|
||||
return true
|
||||
|
|
|
@ -6,7 +6,7 @@ import (
|
|||
)
|
||||
|
||||
func (s *Suite) TestGetAvailableIp(c *check.C) {
|
||||
ip, err := h.getAvailableIP()
|
||||
ip, err := app.getAvailableIP()
|
||||
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
|
@ -16,33 +16,33 @@ func (s *Suite) TestGetAvailableIp(c *check.C) {
|
|||
}
|
||||
|
||||
func (s *Suite) TestGetUsedIps(c *check.C) {
|
||||
ip, err := h.getAvailableIP()
|
||||
ip, err := app.getAvailableIP()
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
n, err := h.CreateNamespace("test_ip")
|
||||
namespace, err := app.CreateNamespace("test_ip")
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
pak, err := h.CreatePreAuthKey(n.Name, false, false, nil)
|
||||
pak, err := app.CreatePreAuthKey(namespace.Name, false, false, nil)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
_, err = h.GetMachine("test", "testmachine")
|
||||
_, err = app.GetMachine("test", "testmachine")
|
||||
c.Assert(err, check.NotNil)
|
||||
|
||||
m := Machine{
|
||||
machine := Machine{
|
||||
ID: 0,
|
||||
MachineKey: "foo",
|
||||
NodeKey: "bar",
|
||||
DiscoKey: "faa",
|
||||
Name: "testmachine",
|
||||
NamespaceID: n.ID,
|
||||
NamespaceID: namespace.ID,
|
||||
Registered: true,
|
||||
RegisterMethod: "authKey",
|
||||
AuthKeyID: uint(pak.ID),
|
||||
IPAddress: ip.String(),
|
||||
}
|
||||
h.db.Save(&m)
|
||||
app.db.Save(&machine)
|
||||
|
||||
ips, err := h.getUsedIPs()
|
||||
ips, err := app.getUsedIPs()
|
||||
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
|
@ -50,42 +50,42 @@ func (s *Suite) TestGetUsedIps(c *check.C) {
|
|||
|
||||
c.Assert(ips[0], check.Equals, expected)
|
||||
|
||||
m1, err := h.GetMachineByID(0)
|
||||
machine1, err := app.GetMachineByID(0)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
c.Assert(m1.IPAddress, check.Equals, expected.String())
|
||||
c.Assert(machine1.IPAddress, check.Equals, expected.String())
|
||||
}
|
||||
|
||||
func (s *Suite) TestGetMultiIp(c *check.C) {
|
||||
n, err := h.CreateNamespace("test-ip-multi")
|
||||
namespace, err := app.CreateNamespace("test-ip-multi")
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
for i := 1; i <= 350; i++ {
|
||||
ip, err := h.getAvailableIP()
|
||||
for index := 1; index <= 350; index++ {
|
||||
ip, err := app.getAvailableIP()
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
pak, err := h.CreatePreAuthKey(n.Name, false, false, nil)
|
||||
pak, err := app.CreatePreAuthKey(namespace.Name, false, false, nil)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
_, err = h.GetMachine("test", "testmachine")
|
||||
_, err = app.GetMachine("test", "testmachine")
|
||||
c.Assert(err, check.NotNil)
|
||||
|
||||
m := Machine{
|
||||
ID: uint64(i),
|
||||
machine := Machine{
|
||||
ID: uint64(index),
|
||||
MachineKey: "foo",
|
||||
NodeKey: "bar",
|
||||
DiscoKey: "faa",
|
||||
Name: "testmachine",
|
||||
NamespaceID: n.ID,
|
||||
NamespaceID: namespace.ID,
|
||||
Registered: true,
|
||||
RegisterMethod: "authKey",
|
||||
AuthKeyID: uint(pak.ID),
|
||||
IPAddress: ip.String(),
|
||||
}
|
||||
h.db.Save(&m)
|
||||
app.db.Save(&machine)
|
||||
}
|
||||
|
||||
ips, err := h.getUsedIPs()
|
||||
ips, err := app.getUsedIPs()
|
||||
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
|
@ -96,59 +96,67 @@ func (s *Suite) TestGetMultiIp(c *check.C) {
|
|||
c.Assert(ips[300], check.Equals, netaddr.MustParseIP("10.27.1.47"))
|
||||
|
||||
// Check that we can read back the IPs
|
||||
m1, err := h.GetMachineByID(1)
|
||||
machine1, err := app.GetMachineByID(1)
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(m1.IPAddress, check.Equals, netaddr.MustParseIP("10.27.0.1").String())
|
||||
c.Assert(
|
||||
machine1.IPAddress,
|
||||
check.Equals,
|
||||
netaddr.MustParseIP("10.27.0.1").String(),
|
||||
)
|
||||
|
||||
m50, err := h.GetMachineByID(50)
|
||||
machine50, err := app.GetMachineByID(50)
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(m50.IPAddress, check.Equals, netaddr.MustParseIP("10.27.0.50").String())
|
||||
c.Assert(
|
||||
machine50.IPAddress,
|
||||
check.Equals,
|
||||
netaddr.MustParseIP("10.27.0.50").String(),
|
||||
)
|
||||
|
||||
expectedNextIP := netaddr.MustParseIP("10.27.1.97")
|
||||
nextIP, err := h.getAvailableIP()
|
||||
nextIP, err := app.getAvailableIP()
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
c.Assert(nextIP.String(), check.Equals, expectedNextIP.String())
|
||||
|
||||
// If we call get Available again, we should receive
|
||||
// the same IP, as it has not been reserved.
|
||||
nextIP2, err := h.getAvailableIP()
|
||||
nextIP2, err := app.getAvailableIP()
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
c.Assert(nextIP2.String(), check.Equals, expectedNextIP.String())
|
||||
}
|
||||
|
||||
func (s *Suite) TestGetAvailableIpMachineWithoutIP(c *check.C) {
|
||||
ip, err := h.getAvailableIP()
|
||||
ip, err := app.getAvailableIP()
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
expected := netaddr.MustParseIP("10.27.0.1")
|
||||
|
||||
c.Assert(ip.String(), check.Equals, expected.String())
|
||||
|
||||
n, err := h.CreateNamespace("test_ip")
|
||||
namespace, err := app.CreateNamespace("test_ip")
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
pak, err := h.CreatePreAuthKey(n.Name, false, false, nil)
|
||||
pak, err := app.CreatePreAuthKey(namespace.Name, false, false, nil)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
_, err = h.GetMachine("test", "testmachine")
|
||||
_, err = app.GetMachine("test", "testmachine")
|
||||
c.Assert(err, check.NotNil)
|
||||
|
||||
m := Machine{
|
||||
machine := Machine{
|
||||
ID: 0,
|
||||
MachineKey: "foo",
|
||||
NodeKey: "bar",
|
||||
DiscoKey: "faa",
|
||||
Name: "testmachine",
|
||||
NamespaceID: n.ID,
|
||||
NamespaceID: namespace.ID,
|
||||
Registered: true,
|
||||
RegisterMethod: "authKey",
|
||||
AuthKeyID: uint(pak.ID),
|
||||
}
|
||||
h.db.Save(&m)
|
||||
app.db.Save(&machine)
|
||||
|
||||
ip2, err := h.getAvailableIP()
|
||||
ip2, err := app.getAvailableIP()
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
c.Assert(ip2.String(), check.Equals, expected.String())
|
||||
|
|
Loading…
Reference in a new issue