Initial work eliminating one/two letter variables
This commit is contained in:
parent
53ed749f45
commit
471c0b4993
19 changed files with 568 additions and 532 deletions
|
@ -28,6 +28,9 @@ linters:
|
||||||
# In progress
|
# In progress
|
||||||
- gocritic
|
- gocritic
|
||||||
|
|
||||||
|
# TODO: approve: ok, db, id
|
||||||
|
- varnamelen
|
||||||
|
|
||||||
# We should strive to enable these:
|
# We should strive to enable these:
|
||||||
- testpackage
|
- testpackage
|
||||||
- stylecheck
|
- stylecheck
|
||||||
|
@ -39,7 +42,6 @@ linters:
|
||||||
- gosec
|
- gosec
|
||||||
- forbidigo
|
- forbidigo
|
||||||
- dupl
|
- dupl
|
||||||
- varnamelen
|
|
||||||
- makezero
|
- makezero
|
||||||
- paralleltest
|
- paralleltest
|
||||||
|
|
||||||
|
|
74
acls.go
74
acls.go
|
@ -41,18 +41,18 @@ func (h *Headscale) LoadACLPolicy(path string) error {
|
||||||
defer policyFile.Close()
|
defer policyFile.Close()
|
||||||
|
|
||||||
var policy ACLPolicy
|
var policy ACLPolicy
|
||||||
b, err := io.ReadAll(policyFile)
|
policyBytes, err := io.ReadAll(policyFile)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
ast, err := hujson.Parse(b)
|
ast, err := hujson.Parse(policyBytes)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
ast.Standardize()
|
ast.Standardize()
|
||||||
b = ast.Pack()
|
policyBytes = ast.Pack()
|
||||||
err = json.Unmarshal(b, &policy)
|
err = json.Unmarshal(policyBytes, &policy)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -73,32 +73,32 @@ func (h *Headscale) LoadACLPolicy(path string) error {
|
||||||
func (h *Headscale) generateACLRules() ([]tailcfg.FilterRule, error) {
|
func (h *Headscale) generateACLRules() ([]tailcfg.FilterRule, error) {
|
||||||
rules := []tailcfg.FilterRule{}
|
rules := []tailcfg.FilterRule{}
|
||||||
|
|
||||||
for i, a := range h.aclPolicy.ACLs {
|
for index, acl := range h.aclPolicy.ACLs {
|
||||||
if a.Action != "accept" {
|
if acl.Action != "accept" {
|
||||||
return nil, errorInvalidAction
|
return nil, errorInvalidAction
|
||||||
}
|
}
|
||||||
|
|
||||||
r := tailcfg.FilterRule{}
|
filterRule := tailcfg.FilterRule{}
|
||||||
|
|
||||||
srcIPs := []string{}
|
srcIPs := []string{}
|
||||||
for j, u := range a.Users {
|
for innerIndex, user := range acl.Users {
|
||||||
srcs, err := h.generateACLPolicySrcIP(u)
|
srcs, err := h.generateACLPolicySrcIP(user)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().
|
log.Error().
|
||||||
Msgf("Error parsing ACL %d, User %d", i, j)
|
Msgf("Error parsing ACL %d, User %d", index, innerIndex)
|
||||||
|
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
srcIPs = append(srcIPs, srcs...)
|
srcIPs = append(srcIPs, srcs...)
|
||||||
}
|
}
|
||||||
r.SrcIPs = srcIPs
|
filterRule.SrcIPs = srcIPs
|
||||||
|
|
||||||
destPorts := []tailcfg.NetPortRange{}
|
destPorts := []tailcfg.NetPortRange{}
|
||||||
for j, d := range a.Ports {
|
for innerIndex, ports := range acl.Ports {
|
||||||
dests, err := h.generateACLPolicyDestPorts(d)
|
dests, err := h.generateACLPolicyDestPorts(ports)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().
|
log.Error().
|
||||||
Msgf("Error parsing ACL %d, Port %d", i, j)
|
Msgf("Error parsing ACL %d, Port %d", index, innerIndex)
|
||||||
|
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -162,17 +162,17 @@ func (h *Headscale) generateACLPolicyDestPorts(
|
||||||
return dests, nil
|
return dests, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Headscale) expandAlias(s string) ([]string, error) {
|
func (h *Headscale) expandAlias(alias string) ([]string, error) {
|
||||||
if s == "*" {
|
if alias == "*" {
|
||||||
return []string{"*"}, nil
|
return []string{"*"}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if strings.HasPrefix(s, "group:") {
|
if strings.HasPrefix(alias, "group:") {
|
||||||
if _, ok := h.aclPolicy.Groups[s]; !ok {
|
if _, ok := h.aclPolicy.Groups[alias]; !ok {
|
||||||
return nil, errorInvalidGroup
|
return nil, errorInvalidGroup
|
||||||
}
|
}
|
||||||
ips := []string{}
|
ips := []string{}
|
||||||
for _, n := range h.aclPolicy.Groups[s] {
|
for _, n := range h.aclPolicy.Groups[alias] {
|
||||||
nodes, err := h.ListMachinesInNamespace(n)
|
nodes, err := h.ListMachinesInNamespace(n)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errorInvalidNamespace
|
return nil, errorInvalidNamespace
|
||||||
|
@ -185,8 +185,8 @@ func (h *Headscale) expandAlias(s string) ([]string, error) {
|
||||||
return ips, nil
|
return ips, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if strings.HasPrefix(s, "tag:") {
|
if strings.HasPrefix(alias, "tag:") {
|
||||||
if _, ok := h.aclPolicy.TagOwners[s]; !ok {
|
if _, ok := h.aclPolicy.TagOwners[alias]; !ok {
|
||||||
return nil, errorInvalidTag
|
return nil, errorInvalidTag
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -197,10 +197,10 @@ func (h *Headscale) expandAlias(s string) ([]string, error) {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
ips := []string{}
|
ips := []string{}
|
||||||
for _, m := range machines {
|
for _, machine := range machines {
|
||||||
hostinfo := tailcfg.Hostinfo{}
|
hostinfo := tailcfg.Hostinfo{}
|
||||||
if len(m.HostInfo) != 0 {
|
if len(machine.HostInfo) != 0 {
|
||||||
hi, err := m.HostInfo.MarshalJSON()
|
hi, err := machine.HostInfo.MarshalJSON()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -211,8 +211,8 @@ func (h *Headscale) expandAlias(s string) ([]string, error) {
|
||||||
|
|
||||||
// FIXME: Check TagOwners allows this
|
// FIXME: Check TagOwners allows this
|
||||||
for _, t := range hostinfo.RequestTags {
|
for _, t := range hostinfo.RequestTags {
|
||||||
if s[4:] == t {
|
if alias[4:] == t {
|
||||||
ips = append(ips, m.IPAddress)
|
ips = append(ips, machine.IPAddress)
|
||||||
|
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
@ -223,7 +223,7 @@ func (h *Headscale) expandAlias(s string) ([]string, error) {
|
||||||
return ips, nil
|
return ips, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
n, err := h.GetNamespace(s)
|
n, err := h.GetNamespace(alias)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
nodes, err := h.ListMachinesInNamespace(n.Name)
|
nodes, err := h.ListMachinesInNamespace(n.Name)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -237,16 +237,16 @@ func (h *Headscale) expandAlias(s string) ([]string, error) {
|
||||||
return ips, nil
|
return ips, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if h, ok := h.aclPolicy.Hosts[s]; ok {
|
if h, ok := h.aclPolicy.Hosts[alias]; ok {
|
||||||
return []string{h.String()}, nil
|
return []string{h.String()}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
ip, err := netaddr.ParseIP(s)
|
ip, err := netaddr.ParseIP(alias)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
return []string{ip.String()}, nil
|
return []string{ip.String()}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
cidr, err := netaddr.ParseIPPrefix(s)
|
cidr, err := netaddr.ParseIPPrefix(alias)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
return []string{cidr.String()}, nil
|
return []string{cidr.String()}, nil
|
||||||
}
|
}
|
||||||
|
@ -254,25 +254,25 @@ func (h *Headscale) expandAlias(s string) ([]string, error) {
|
||||||
return nil, errorInvalidUserSection
|
return nil, errorInvalidUserSection
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Headscale) expandPorts(s string) (*[]tailcfg.PortRange, error) {
|
func (h *Headscale) expandPorts(portsStr string) (*[]tailcfg.PortRange, error) {
|
||||||
if s == "*" {
|
if portsStr == "*" {
|
||||||
return &[]tailcfg.PortRange{
|
return &[]tailcfg.PortRange{
|
||||||
{First: PORT_RANGE_BEGIN, Last: PORT_RANGE_END},
|
{First: PORT_RANGE_BEGIN, Last: PORT_RANGE_END},
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
ports := []tailcfg.PortRange{}
|
ports := []tailcfg.PortRange{}
|
||||||
for _, p := range strings.Split(s, ",") {
|
for _, portStr := range strings.Split(portsStr, ",") {
|
||||||
rang := strings.Split(p, "-")
|
rang := strings.Split(portStr, "-")
|
||||||
switch len(rang) {
|
switch len(rang) {
|
||||||
case 1:
|
case 1:
|
||||||
pi, err := strconv.ParseUint(rang[0], BASE_10, BIT_SIZE_16)
|
port, err := strconv.ParseUint(rang[0], BASE_10, BIT_SIZE_16)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
ports = append(ports, tailcfg.PortRange{
|
ports = append(ports, tailcfg.PortRange{
|
||||||
First: uint16(pi),
|
First: uint16(port),
|
||||||
Last: uint16(pi),
|
Last: uint16(port),
|
||||||
})
|
})
|
||||||
|
|
||||||
case EXPECTED_TOKEN_ITEMS:
|
case EXPECTED_TOKEN_ITEMS:
|
||||||
|
|
|
@ -41,37 +41,37 @@ type ACLTest struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
// UnmarshalJSON allows to parse the Hosts directly into netaddr objects.
|
// UnmarshalJSON allows to parse the Hosts directly into netaddr objects.
|
||||||
func (h *Hosts) UnmarshalJSON(data []byte) error {
|
func (hosts *Hosts) UnmarshalJSON(data []byte) error {
|
||||||
hosts := Hosts{}
|
newHosts := Hosts{}
|
||||||
hs := make(map[string]string)
|
hostIpPrefixMap := make(map[string]string)
|
||||||
ast, err := hujson.Parse(data)
|
ast, err := hujson.Parse(data)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
ast.Standardize()
|
ast.Standardize()
|
||||||
data = ast.Pack()
|
data = ast.Pack()
|
||||||
err = json.Unmarshal(data, &hs)
|
err = json.Unmarshal(data, &hostIpPrefixMap)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
for k, v := range hs {
|
for host, prefixStr := range hostIpPrefixMap {
|
||||||
if !strings.Contains(v, "/") {
|
if !strings.Contains(prefixStr, "/") {
|
||||||
v += "/32"
|
prefixStr += "/32"
|
||||||
}
|
}
|
||||||
prefix, err := netaddr.ParseIPPrefix(v)
|
prefix, err := netaddr.ParseIPPrefix(prefixStr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
hosts[k] = prefix
|
newHosts[host] = prefix
|
||||||
}
|
}
|
||||||
*h = hosts
|
*hosts = newHosts
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// IsZero is perhaps a bit naive here.
|
// IsZero is perhaps a bit naive here.
|
||||||
func (p ACLPolicy) IsZero() bool {
|
func (policy ACLPolicy) IsZero() bool {
|
||||||
if len(p.Groups) == 0 && len(p.Hosts) == 0 && len(p.ACLs) == 0 {
|
if len(policy.Groups) == 0 && len(policy.Hosts) == 0 && len(policy.ACLs) == 0 {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
229
api.go
229
api.go
|
@ -22,21 +22,25 @@ const RESERVED_RESPONSE_HEADER_SIZE = 4
|
||||||
|
|
||||||
// KeyHandler provides the Headscale pub key
|
// KeyHandler provides the Headscale pub key
|
||||||
// Listens in /key.
|
// Listens in /key.
|
||||||
func (h *Headscale) KeyHandler(c *gin.Context) {
|
func (h *Headscale) KeyHandler(ctx *gin.Context) {
|
||||||
c.Data(http.StatusOK, "text/plain; charset=utf-8", []byte(h.publicKey.HexString()))
|
ctx.Data(
|
||||||
|
http.StatusOK,
|
||||||
|
"text/plain; charset=utf-8",
|
||||||
|
[]byte(h.publicKey.HexString()),
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
// RegisterWebAPI shows a simple message in the browser to point to the CLI
|
// RegisterWebAPI shows a simple message in the browser to point to the CLI
|
||||||
// Listens in /register.
|
// Listens in /register.
|
||||||
func (h *Headscale) RegisterWebAPI(c *gin.Context) {
|
func (h *Headscale) RegisterWebAPI(ctx *gin.Context) {
|
||||||
mKeyStr := c.Query("key")
|
machineKeyStr := ctx.Query("key")
|
||||||
if mKeyStr == "" {
|
if machineKeyStr == "" {
|
||||||
c.String(http.StatusBadRequest, "Wrong params")
|
ctx.String(http.StatusBadRequest, "Wrong params")
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
c.Data(http.StatusOK, "text/html; charset=utf-8", []byte(fmt.Sprintf(`
|
ctx.Data(http.StatusOK, "text/html; charset=utf-8", []byte(fmt.Sprintf(`
|
||||||
<html>
|
<html>
|
||||||
<body>
|
<body>
|
||||||
<h1>headscale</h1>
|
<h1>headscale</h1>
|
||||||
|
@ -53,45 +57,45 @@ func (h *Headscale) RegisterWebAPI(c *gin.Context) {
|
||||||
</body>
|
</body>
|
||||||
</html>
|
</html>
|
||||||
|
|
||||||
`, mKeyStr)))
|
`, machineKeyStr)))
|
||||||
}
|
}
|
||||||
|
|
||||||
// RegistrationHandler handles the actual registration process of a machine
|
// RegistrationHandler handles the actual registration process of a machine
|
||||||
// Endpoint /machine/:id.
|
// Endpoint /machine/:id.
|
||||||
func (h *Headscale) RegistrationHandler(c *gin.Context) {
|
func (h *Headscale) RegistrationHandler(ctx *gin.Context) {
|
||||||
body, _ := io.ReadAll(c.Request.Body)
|
body, _ := io.ReadAll(ctx.Request.Body)
|
||||||
mKeyStr := c.Param("id")
|
machineKeyStr := ctx.Param("id")
|
||||||
mKey, err := wgkey.ParseHex(mKeyStr)
|
machineKey, err := wgkey.ParseHex(machineKeyStr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().
|
log.Error().
|
||||||
Str("handler", "Registration").
|
Str("handler", "Registration").
|
||||||
Err(err).
|
Err(err).
|
||||||
Msg("Cannot parse machine key")
|
Msg("Cannot parse machine key")
|
||||||
machineRegistrations.WithLabelValues("unknown", "web", "error", "unknown").Inc()
|
machineRegistrations.WithLabelValues("unknown", "web", "error", "unknown").Inc()
|
||||||
c.String(http.StatusInternalServerError, "Sad!")
|
ctx.String(http.StatusInternalServerError, "Sad!")
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
req := tailcfg.RegisterRequest{}
|
req := tailcfg.RegisterRequest{}
|
||||||
err = decode(body, &req, &mKey, h.privateKey)
|
err = decode(body, &req, &machineKey, h.privateKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().
|
log.Error().
|
||||||
Str("handler", "Registration").
|
Str("handler", "Registration").
|
||||||
Err(err).
|
Err(err).
|
||||||
Msg("Cannot decode message")
|
Msg("Cannot decode message")
|
||||||
machineRegistrations.WithLabelValues("unknown", "web", "error", "unknown").Inc()
|
machineRegistrations.WithLabelValues("unknown", "web", "error", "unknown").Inc()
|
||||||
c.String(http.StatusInternalServerError, "Very sad!")
|
ctx.String(http.StatusInternalServerError, "Very sad!")
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
now := time.Now().UTC()
|
now := time.Now().UTC()
|
||||||
m, err := h.GetMachineByMachineKey(mKey.HexString())
|
machine, err := h.GetMachineByMachineKey(machineKey.HexString())
|
||||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||||
log.Info().Str("machine", req.Hostinfo.Hostname).Msg("New machine")
|
log.Info().Str("machine", req.Hostinfo.Hostname).Msg("New machine")
|
||||||
newMachine := Machine{
|
newMachine := Machine{
|
||||||
Expiry: &time.Time{},
|
Expiry: &time.Time{},
|
||||||
MachineKey: mKey.HexString(),
|
MachineKey: machineKey.HexString(),
|
||||||
Name: req.Hostinfo.Hostname,
|
Name: req.Hostinfo.Hostname,
|
||||||
}
|
}
|
||||||
if err := h.db.Create(&newMachine).Error; err != nil {
|
if err := h.db.Create(&newMachine).Error; err != nil {
|
||||||
|
@ -99,16 +103,16 @@ func (h *Headscale) RegistrationHandler(c *gin.Context) {
|
||||||
Str("handler", "Registration").
|
Str("handler", "Registration").
|
||||||
Err(err).
|
Err(err).
|
||||||
Msg("Could not create row")
|
Msg("Could not create row")
|
||||||
machineRegistrations.WithLabelValues("unknown", "web", "error", m.Namespace.Name).
|
machineRegistrations.WithLabelValues("unknown", "web", "error", machine.Namespace.Name).
|
||||||
Inc()
|
Inc()
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
m = &newMachine
|
machine = &newMachine
|
||||||
}
|
}
|
||||||
|
|
||||||
if !m.Registered && req.Auth.AuthKey != "" {
|
if !machine.Registered && req.Auth.AuthKey != "" {
|
||||||
h.handleAuthKey(c, h.db, mKey, req, *m)
|
h.handleAuthKey(ctx, h.db, machineKey, req, *machine)
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -116,63 +120,63 @@ func (h *Headscale) RegistrationHandler(c *gin.Context) {
|
||||||
resp := tailcfg.RegisterResponse{}
|
resp := tailcfg.RegisterResponse{}
|
||||||
|
|
||||||
// We have the updated key!
|
// We have the updated key!
|
||||||
if m.NodeKey == wgkey.Key(req.NodeKey).HexString() {
|
if machine.NodeKey == wgkey.Key(req.NodeKey).HexString() {
|
||||||
// The client sends an Expiry in the past if the client is requesting to expire the key (aka logout)
|
// The client sends an Expiry in the past if the client is requesting to expire the key (aka logout)
|
||||||
// https://github.com/tailscale/tailscale/blob/main/tailcfg/tailcfg.go#L648
|
// https://github.com/tailscale/tailscale/blob/main/tailcfg/tailcfg.go#L648
|
||||||
if !req.Expiry.IsZero() && req.Expiry.UTC().Before(now) {
|
if !req.Expiry.IsZero() && req.Expiry.UTC().Before(now) {
|
||||||
log.Info().
|
log.Info().
|
||||||
Str("handler", "Registration").
|
Str("handler", "Registration").
|
||||||
Str("machine", m.Name).
|
Str("machine", machine.Name).
|
||||||
Msg("Client requested logout")
|
Msg("Client requested logout")
|
||||||
|
|
||||||
m.Expiry = &req.Expiry // save the expiry so that the machine is marked as expired
|
machine.Expiry = &req.Expiry // save the expiry so that the machine is marked as expired
|
||||||
h.db.Save(&m)
|
h.db.Save(&machine)
|
||||||
|
|
||||||
resp.AuthURL = ""
|
resp.AuthURL = ""
|
||||||
resp.MachineAuthorized = false
|
resp.MachineAuthorized = false
|
||||||
resp.User = *m.Namespace.toUser()
|
resp.User = *machine.Namespace.toUser()
|
||||||
respBody, err := encode(resp, &mKey, h.privateKey)
|
respBody, err := encode(resp, &machineKey, h.privateKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().
|
log.Error().
|
||||||
Str("handler", "Registration").
|
Str("handler", "Registration").
|
||||||
Err(err).
|
Err(err).
|
||||||
Msg("Cannot encode message")
|
Msg("Cannot encode message")
|
||||||
c.String(http.StatusInternalServerError, "")
|
ctx.String(http.StatusInternalServerError, "")
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
c.Data(http.StatusOK, "application/json; charset=utf-8", respBody)
|
ctx.Data(http.StatusOK, "application/json; charset=utf-8", respBody)
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if m.Registered && m.Expiry.UTC().After(now) {
|
if machine.Registered && machine.Expiry.UTC().After(now) {
|
||||||
// The machine registration is valid, respond with redirect to /map
|
// The machine registration is valid, respond with redirect to /map
|
||||||
log.Debug().
|
log.Debug().
|
||||||
Str("handler", "Registration").
|
Str("handler", "Registration").
|
||||||
Str("machine", m.Name).
|
Str("machine", machine.Name).
|
||||||
Msg("Client is registered and we have the current NodeKey. All clear to /map")
|
Msg("Client is registered and we have the current NodeKey. All clear to /map")
|
||||||
|
|
||||||
resp.AuthURL = ""
|
resp.AuthURL = ""
|
||||||
resp.MachineAuthorized = true
|
resp.MachineAuthorized = true
|
||||||
resp.User = *m.Namespace.toUser()
|
resp.User = *machine.Namespace.toUser()
|
||||||
resp.Login = *m.Namespace.toLogin()
|
resp.Login = *machine.Namespace.toLogin()
|
||||||
|
|
||||||
respBody, err := encode(resp, &mKey, h.privateKey)
|
respBody, err := encode(resp, &machineKey, h.privateKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().
|
log.Error().
|
||||||
Str("handler", "Registration").
|
Str("handler", "Registration").
|
||||||
Err(err).
|
Err(err).
|
||||||
Msg("Cannot encode message")
|
Msg("Cannot encode message")
|
||||||
machineRegistrations.WithLabelValues("update", "web", "error", m.Namespace.Name).
|
machineRegistrations.WithLabelValues("update", "web", "error", machine.Namespace.Name).
|
||||||
Inc()
|
Inc()
|
||||||
c.String(http.StatusInternalServerError, "")
|
ctx.String(http.StatusInternalServerError, "")
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
machineRegistrations.WithLabelValues("update", "web", "success", m.Namespace.Name).
|
machineRegistrations.WithLabelValues("update", "web", "success", machine.Namespace.Name).
|
||||||
Inc()
|
Inc()
|
||||||
c.Data(http.StatusOK, "application/json; charset=utf-8", respBody)
|
ctx.Data(http.StatusOK, "application/json; charset=utf-8", respBody)
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -180,15 +184,15 @@ func (h *Headscale) RegistrationHandler(c *gin.Context) {
|
||||||
// The client has registered before, but has expired
|
// The client has registered before, but has expired
|
||||||
log.Debug().
|
log.Debug().
|
||||||
Str("handler", "Registration").
|
Str("handler", "Registration").
|
||||||
Str("machine", m.Name).
|
Str("machine", machine.Name).
|
||||||
Msg("Machine registration has expired. Sending a authurl to register")
|
Msg("Machine registration has expired. Sending a authurl to register")
|
||||||
|
|
||||||
if h.cfg.OIDC.Issuer != "" {
|
if h.cfg.OIDC.Issuer != "" {
|
||||||
resp.AuthURL = fmt.Sprintf("%s/oidc/register/%s",
|
resp.AuthURL = fmt.Sprintf("%s/oidc/register/%s",
|
||||||
strings.TrimSuffix(h.cfg.ServerURL, "/"), mKey.HexString())
|
strings.TrimSuffix(h.cfg.ServerURL, "/"), machineKey.HexString())
|
||||||
} else {
|
} else {
|
||||||
resp.AuthURL = fmt.Sprintf("%s/register?key=%s",
|
resp.AuthURL = fmt.Sprintf("%s/register?key=%s",
|
||||||
strings.TrimSuffix(h.cfg.ServerURL, "/"), mKey.HexString())
|
strings.TrimSuffix(h.cfg.ServerURL, "/"), machineKey.HexString())
|
||||||
}
|
}
|
||||||
|
|
||||||
// When a client connects, it may request a specific expiry time in its
|
// When a client connects, it may request a specific expiry time in its
|
||||||
|
@ -197,51 +201,52 @@ func (h *Headscale) RegistrationHandler(c *gin.Context) {
|
||||||
// into two steps (which cant pass arbitrary data between them easily) and needs to be
|
// into two steps (which cant pass arbitrary data between them easily) and needs to be
|
||||||
// retrieved again after the user has authenticated. After the authentication flow
|
// retrieved again after the user has authenticated. After the authentication flow
|
||||||
// completes, RequestedExpiry is copied into Expiry.
|
// completes, RequestedExpiry is copied into Expiry.
|
||||||
m.RequestedExpiry = &req.Expiry
|
machine.RequestedExpiry = &req.Expiry
|
||||||
|
|
||||||
h.db.Save(&m)
|
h.db.Save(&machine)
|
||||||
|
|
||||||
respBody, err := encode(resp, &mKey, h.privateKey)
|
respBody, err := encode(resp, &machineKey, h.privateKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().
|
log.Error().
|
||||||
Str("handler", "Registration").
|
Str("handler", "Registration").
|
||||||
Err(err).
|
Err(err).
|
||||||
Msg("Cannot encode message")
|
Msg("Cannot encode message")
|
||||||
machineRegistrations.WithLabelValues("new", "web", "error", m.Namespace.Name).
|
machineRegistrations.WithLabelValues("new", "web", "error", machine.Namespace.Name).
|
||||||
Inc()
|
Inc()
|
||||||
c.String(http.StatusInternalServerError, "")
|
ctx.String(http.StatusInternalServerError, "")
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
machineRegistrations.WithLabelValues("new", "web", "success", m.Namespace.Name).
|
machineRegistrations.WithLabelValues("new", "web", "success", machine.Namespace.Name).
|
||||||
Inc()
|
Inc()
|
||||||
c.Data(http.StatusOK, "application/json; charset=utf-8", respBody)
|
ctx.Data(http.StatusOK, "application/json; charset=utf-8", respBody)
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// The NodeKey we have matches OldNodeKey, which means this is a refresh after a key expiration
|
// The NodeKey we have matches OldNodeKey, which means this is a refresh after a key expiration
|
||||||
if m.NodeKey == wgkey.Key(req.OldNodeKey).HexString() && m.Expiry.UTC().After(now) {
|
if machine.NodeKey == wgkey.Key(req.OldNodeKey).HexString() &&
|
||||||
|
machine.Expiry.UTC().After(now) {
|
||||||
log.Debug().
|
log.Debug().
|
||||||
Str("handler", "Registration").
|
Str("handler", "Registration").
|
||||||
Str("machine", m.Name).
|
Str("machine", machine.Name).
|
||||||
Msg("We have the OldNodeKey in the database. This is a key refresh")
|
Msg("We have the OldNodeKey in the database. This is a key refresh")
|
||||||
m.NodeKey = wgkey.Key(req.NodeKey).HexString()
|
machine.NodeKey = wgkey.Key(req.NodeKey).HexString()
|
||||||
h.db.Save(&m)
|
h.db.Save(&machine)
|
||||||
|
|
||||||
resp.AuthURL = ""
|
resp.AuthURL = ""
|
||||||
resp.User = *m.Namespace.toUser()
|
resp.User = *machine.Namespace.toUser()
|
||||||
respBody, err := encode(resp, &mKey, h.privateKey)
|
respBody, err := encode(resp, &machineKey, h.privateKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().
|
log.Error().
|
||||||
Str("handler", "Registration").
|
Str("handler", "Registration").
|
||||||
Err(err).
|
Err(err).
|
||||||
Msg("Cannot encode message")
|
Msg("Cannot encode message")
|
||||||
c.String(http.StatusInternalServerError, "Extremely sad!")
|
ctx.String(http.StatusInternalServerError, "Extremely sad!")
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
c.Data(http.StatusOK, "application/json; charset=utf-8", respBody)
|
ctx.Data(http.StatusOK, "application/json; charset=utf-8", respBody)
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -249,47 +254,47 @@ func (h *Headscale) RegistrationHandler(c *gin.Context) {
|
||||||
// The machine registration is new, redirect the client to the registration URL
|
// The machine registration is new, redirect the client to the registration URL
|
||||||
log.Debug().
|
log.Debug().
|
||||||
Str("handler", "Registration").
|
Str("handler", "Registration").
|
||||||
Str("machine", m.Name).
|
Str("machine", machine.Name).
|
||||||
Msg("The node is sending us a new NodeKey, sending auth url")
|
Msg("The node is sending us a new NodeKey, sending auth url")
|
||||||
if h.cfg.OIDC.Issuer != "" {
|
if h.cfg.OIDC.Issuer != "" {
|
||||||
resp.AuthURL = fmt.Sprintf(
|
resp.AuthURL = fmt.Sprintf(
|
||||||
"%s/oidc/register/%s",
|
"%s/oidc/register/%s",
|
||||||
strings.TrimSuffix(h.cfg.ServerURL, "/"),
|
strings.TrimSuffix(h.cfg.ServerURL, "/"),
|
||||||
mKey.HexString(),
|
machineKey.HexString(),
|
||||||
)
|
)
|
||||||
} else {
|
} else {
|
||||||
resp.AuthURL = fmt.Sprintf("%s/register?key=%s",
|
resp.AuthURL = fmt.Sprintf("%s/register?key=%s",
|
||||||
strings.TrimSuffix(h.cfg.ServerURL, "/"), mKey.HexString())
|
strings.TrimSuffix(h.cfg.ServerURL, "/"), machineKey.HexString())
|
||||||
}
|
}
|
||||||
|
|
||||||
// save the requested expiry time for retrieval later in the authentication flow
|
// save the requested expiry time for retrieval later in the authentication flow
|
||||||
m.RequestedExpiry = &req.Expiry
|
machine.RequestedExpiry = &req.Expiry
|
||||||
m.NodeKey = wgkey.Key(req.NodeKey).HexString() // save the NodeKey
|
machine.NodeKey = wgkey.Key(req.NodeKey).HexString() // save the NodeKey
|
||||||
h.db.Save(&m)
|
h.db.Save(&machine)
|
||||||
|
|
||||||
respBody, err := encode(resp, &mKey, h.privateKey)
|
respBody, err := encode(resp, &machineKey, h.privateKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().
|
log.Error().
|
||||||
Str("handler", "Registration").
|
Str("handler", "Registration").
|
||||||
Err(err).
|
Err(err).
|
||||||
Msg("Cannot encode message")
|
Msg("Cannot encode message")
|
||||||
c.String(http.StatusInternalServerError, "")
|
ctx.String(http.StatusInternalServerError, "")
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
c.Data(http.StatusOK, "application/json; charset=utf-8", respBody)
|
ctx.Data(http.StatusOK, "application/json; charset=utf-8", respBody)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Headscale) getMapResponse(
|
func (h *Headscale) getMapResponse(
|
||||||
mKey wgkey.Key,
|
machineKey wgkey.Key,
|
||||||
req tailcfg.MapRequest,
|
req tailcfg.MapRequest,
|
||||||
m *Machine,
|
machine *Machine,
|
||||||
) ([]byte, error) {
|
) ([]byte, error) {
|
||||||
log.Trace().
|
log.Trace().
|
||||||
Str("func", "getMapResponse").
|
Str("func", "getMapResponse").
|
||||||
Str("machine", req.Hostinfo.Hostname).
|
Str("machine", req.Hostinfo.Hostname).
|
||||||
Msg("Creating Map response")
|
Msg("Creating Map response")
|
||||||
node, err := m.toNode(h.cfg.BaseDomain, h.cfg.DNSConfig, true)
|
node, err := machine.toNode(h.cfg.BaseDomain, h.cfg.DNSConfig, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().
|
log.Error().
|
||||||
Str("func", "getMapResponse").
|
Str("func", "getMapResponse").
|
||||||
|
@ -299,7 +304,7 @@ func (h *Headscale) getMapResponse(
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
peers, err := h.getPeers(m)
|
peers, err := h.getPeers(machine)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().
|
log.Error().
|
||||||
Str("func", "getMapResponse").
|
Str("func", "getMapResponse").
|
||||||
|
@ -309,7 +314,7 @@ func (h *Headscale) getMapResponse(
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
profiles := getMapResponseUserProfiles(*m, peers)
|
profiles := getMapResponseUserProfiles(*machine, peers)
|
||||||
|
|
||||||
nodePeers, err := peers.toNodes(h.cfg.BaseDomain, h.cfg.DNSConfig, true)
|
nodePeers, err := peers.toNodes(h.cfg.BaseDomain, h.cfg.DNSConfig, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -324,7 +329,7 @@ func (h *Headscale) getMapResponse(
|
||||||
dnsConfig := getMapResponseDNSConfig(
|
dnsConfig := getMapResponseDNSConfig(
|
||||||
h.cfg.DNSConfig,
|
h.cfg.DNSConfig,
|
||||||
h.cfg.BaseDomain,
|
h.cfg.BaseDomain,
|
||||||
*m,
|
*machine,
|
||||||
peers,
|
peers,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -351,12 +356,12 @@ func (h *Headscale) getMapResponse(
|
||||||
|
|
||||||
encoder, _ := zstd.NewWriter(nil)
|
encoder, _ := zstd.NewWriter(nil)
|
||||||
srcCompressed := encoder.EncodeAll(src, nil)
|
srcCompressed := encoder.EncodeAll(src, nil)
|
||||||
respBody, err = encodeMsg(srcCompressed, &mKey, h.privateKey)
|
respBody, err = encodeMsg(srcCompressed, &machineKey, h.privateKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
respBody, err = encode(resp, &mKey, h.privateKey)
|
respBody, err = encode(resp, &machineKey, h.privateKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -370,24 +375,24 @@ func (h *Headscale) getMapResponse(
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Headscale) getMapKeepAliveResponse(
|
func (h *Headscale) getMapKeepAliveResponse(
|
||||||
mKey wgkey.Key,
|
machineKey wgkey.Key,
|
||||||
req tailcfg.MapRequest,
|
mapRequest tailcfg.MapRequest,
|
||||||
) ([]byte, error) {
|
) ([]byte, error) {
|
||||||
resp := tailcfg.MapResponse{
|
mapResponse := tailcfg.MapResponse{
|
||||||
KeepAlive: true,
|
KeepAlive: true,
|
||||||
}
|
}
|
||||||
var respBody []byte
|
var respBody []byte
|
||||||
var err error
|
var err error
|
||||||
if req.Compress == "zstd" {
|
if mapRequest.Compress == "zstd" {
|
||||||
src, _ := json.Marshal(resp)
|
src, _ := json.Marshal(mapResponse)
|
||||||
encoder, _ := zstd.NewWriter(nil)
|
encoder, _ := zstd.NewWriter(nil)
|
||||||
srcCompressed := encoder.EncodeAll(src, nil)
|
srcCompressed := encoder.EncodeAll(src, nil)
|
||||||
respBody, err = encodeMsg(srcCompressed, &mKey, h.privateKey)
|
respBody, err = encodeMsg(srcCompressed, &machineKey, h.privateKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
respBody, err = encode(resp, &mKey, h.privateKey)
|
respBody, err = encode(mapResponse, &machineKey, h.privateKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -400,22 +405,22 @@ func (h *Headscale) getMapKeepAliveResponse(
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Headscale) handleAuthKey(
|
func (h *Headscale) handleAuthKey(
|
||||||
c *gin.Context,
|
ctx *gin.Context,
|
||||||
db *gorm.DB,
|
db *gorm.DB,
|
||||||
idKey wgkey.Key,
|
idKey wgkey.Key,
|
||||||
req tailcfg.RegisterRequest,
|
reqisterRequest tailcfg.RegisterRequest,
|
||||||
m Machine,
|
machine Machine,
|
||||||
) {
|
) {
|
||||||
log.Debug().
|
log.Debug().
|
||||||
Str("func", "handleAuthKey").
|
Str("func", "handleAuthKey").
|
||||||
Str("machine", req.Hostinfo.Hostname).
|
Str("machine", reqisterRequest.Hostinfo.Hostname).
|
||||||
Msgf("Processing auth key for %s", req.Hostinfo.Hostname)
|
Msgf("Processing auth key for %s", reqisterRequest.Hostinfo.Hostname)
|
||||||
resp := tailcfg.RegisterResponse{}
|
resp := tailcfg.RegisterResponse{}
|
||||||
pak, err := h.checkKeyValidity(req.Auth.AuthKey)
|
pak, err := h.checkKeyValidity(reqisterRequest.Auth.AuthKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().
|
log.Error().
|
||||||
Str("func", "handleAuthKey").
|
Str("func", "handleAuthKey").
|
||||||
Str("machine", m.Name).
|
Str("machine", machine.Name).
|
||||||
Err(err).
|
Err(err).
|
||||||
Msg("Failed authentication via AuthKey")
|
Msg("Failed authentication via AuthKey")
|
||||||
resp.MachineAuthorized = false
|
resp.MachineAuthorized = false
|
||||||
|
@ -423,21 +428,21 @@ func (h *Headscale) handleAuthKey(
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().
|
log.Error().
|
||||||
Str("func", "handleAuthKey").
|
Str("func", "handleAuthKey").
|
||||||
Str("machine", m.Name).
|
Str("machine", machine.Name).
|
||||||
Err(err).
|
Err(err).
|
||||||
Msg("Cannot encode message")
|
Msg("Cannot encode message")
|
||||||
c.String(http.StatusInternalServerError, "")
|
ctx.String(http.StatusInternalServerError, "")
|
||||||
machineRegistrations.WithLabelValues("new", "authkey", "error", m.Namespace.Name).
|
machineRegistrations.WithLabelValues("new", "authkey", "error", machine.Namespace.Name).
|
||||||
Inc()
|
Inc()
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
c.Data(http.StatusUnauthorized, "application/json; charset=utf-8", respBody)
|
ctx.Data(http.StatusUnauthorized, "application/json; charset=utf-8", respBody)
|
||||||
log.Error().
|
log.Error().
|
||||||
Str("func", "handleAuthKey").
|
Str("func", "handleAuthKey").
|
||||||
Str("machine", m.Name).
|
Str("machine", machine.Name).
|
||||||
Msg("Failed authentication via AuthKey")
|
Msg("Failed authentication via AuthKey")
|
||||||
machineRegistrations.WithLabelValues("new", "authkey", "error", m.Namespace.Name).
|
machineRegistrations.WithLabelValues("new", "authkey", "error", machine.Namespace.Name).
|
||||||
Inc()
|
Inc()
|
||||||
|
|
||||||
return
|
return
|
||||||
|
@ -445,32 +450,34 @@ func (h *Headscale) handleAuthKey(
|
||||||
|
|
||||||
log.Debug().
|
log.Debug().
|
||||||
Str("func", "handleAuthKey").
|
Str("func", "handleAuthKey").
|
||||||
Str("machine", m.Name).
|
Str("machine", machine.Name).
|
||||||
Msg("Authentication key was valid, proceeding to acquire an IP address")
|
Msg("Authentication key was valid, proceeding to acquire an IP address")
|
||||||
ip, err := h.getAvailableIP()
|
ip, err := h.getAvailableIP()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().
|
log.Error().
|
||||||
Str("func", "handleAuthKey").
|
Str("func", "handleAuthKey").
|
||||||
Str("machine", m.Name).
|
Str("machine", machine.Name).
|
||||||
Msg("Failed to find an available IP")
|
Msg("Failed to find an available IP")
|
||||||
machineRegistrations.WithLabelValues("new", "authkey", "error", m.Namespace.Name).
|
machineRegistrations.WithLabelValues("new", "authkey", "error", machine.Namespace.Name).
|
||||||
Inc()
|
Inc()
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
log.Info().
|
log.Info().
|
||||||
Str("func", "handleAuthKey").
|
Str("func", "handleAuthKey").
|
||||||
Str("machine", m.Name).
|
Str("machine", machine.Name).
|
||||||
Str("ip", ip.String()).
|
Str("ip", ip.String()).
|
||||||
Msgf("Assigning %s to %s", ip, m.Name)
|
Msgf("Assigning %s to %s", ip, machine.Name)
|
||||||
|
|
||||||
m.AuthKeyID = uint(pak.ID)
|
machine.AuthKeyID = uint(pak.ID)
|
||||||
m.IPAddress = ip.String()
|
machine.IPAddress = ip.String()
|
||||||
m.NamespaceID = pak.NamespaceID
|
machine.NamespaceID = pak.NamespaceID
|
||||||
m.NodeKey = wgkey.Key(req.NodeKey).HexString() // we update it just in case
|
machine.NodeKey = wgkey.Key(reqisterRequest.NodeKey).
|
||||||
m.Registered = true
|
HexString()
|
||||||
m.RegisterMethod = "authKey"
|
// we update it just in case
|
||||||
db.Save(&m)
|
machine.Registered = true
|
||||||
|
machine.RegisterMethod = "authKey"
|
||||||
|
db.Save(&machine)
|
||||||
|
|
||||||
pak.Used = true
|
pak.Used = true
|
||||||
db.Save(&pak)
|
db.Save(&pak)
|
||||||
|
@ -481,21 +488,21 @@ func (h *Headscale) handleAuthKey(
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().
|
log.Error().
|
||||||
Str("func", "handleAuthKey").
|
Str("func", "handleAuthKey").
|
||||||
Str("machine", m.Name).
|
Str("machine", machine.Name).
|
||||||
Err(err).
|
Err(err).
|
||||||
Msg("Cannot encode message")
|
Msg("Cannot encode message")
|
||||||
machineRegistrations.WithLabelValues("new", "authkey", "error", m.Namespace.Name).
|
machineRegistrations.WithLabelValues("new", "authkey", "error", machine.Namespace.Name).
|
||||||
Inc()
|
Inc()
|
||||||
c.String(http.StatusInternalServerError, "Extremely sad!")
|
ctx.String(http.StatusInternalServerError, "Extremely sad!")
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
machineRegistrations.WithLabelValues("new", "authkey", "success", m.Namespace.Name).
|
machineRegistrations.WithLabelValues("new", "authkey", "success", machine.Namespace.Name).
|
||||||
Inc()
|
Inc()
|
||||||
c.Data(http.StatusOK, "application/json; charset=utf-8", respBody)
|
ctx.Data(http.StatusOK, "application/json; charset=utf-8", respBody)
|
||||||
log.Info().
|
log.Info().
|
||||||
Str("func", "handleAuthKey").
|
Str("func", "handleAuthKey").
|
||||||
Str("machine", m.Name).
|
Str("machine", machine.Name).
|
||||||
Str("ip", ip.String()).
|
Str("ip", ip.String()).
|
||||||
Msg("Successfully authenticated via AuthKey")
|
Msg("Successfully authenticated via AuthKey")
|
||||||
}
|
}
|
||||||
|
|
140
app.go
140
app.go
|
@ -169,7 +169,7 @@ func NewHeadscale(cfg Config) (*Headscale, error) {
|
||||||
return nil, errors.New("unsupported DB")
|
return nil, errors.New("unsupported DB")
|
||||||
}
|
}
|
||||||
|
|
||||||
h := Headscale{
|
app := Headscale{
|
||||||
cfg: cfg,
|
cfg: cfg,
|
||||||
dbType: cfg.DBtype,
|
dbType: cfg.DBtype,
|
||||||
dbString: dbString,
|
dbString: dbString,
|
||||||
|
@ -178,32 +178,32 @@ func NewHeadscale(cfg Config) (*Headscale, error) {
|
||||||
aclRules: tailcfg.FilterAllowAll, // default allowall
|
aclRules: tailcfg.FilterAllowAll, // default allowall
|
||||||
}
|
}
|
||||||
|
|
||||||
err = h.initDB()
|
err = app.initDB()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if cfg.OIDC.Issuer != "" {
|
if cfg.OIDC.Issuer != "" {
|
||||||
err = h.initOIDC()
|
err = app.initOIDC()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if h.cfg.DNSConfig != nil && h.cfg.DNSConfig.Proxied { // if MagicDNS
|
if app.cfg.DNSConfig != nil && app.cfg.DNSConfig.Proxied { // if MagicDNS
|
||||||
magicDNSDomains := generateMagicDNSRootDomains(
|
magicDNSDomains := generateMagicDNSRootDomains(
|
||||||
h.cfg.IPPrefix,
|
app.cfg.IPPrefix,
|
||||||
)
|
)
|
||||||
// we might have routes already from Split DNS
|
// we might have routes already from Split DNS
|
||||||
if h.cfg.DNSConfig.Routes == nil {
|
if app.cfg.DNSConfig.Routes == nil {
|
||||||
h.cfg.DNSConfig.Routes = make(map[string][]dnstype.Resolver)
|
app.cfg.DNSConfig.Routes = make(map[string][]dnstype.Resolver)
|
||||||
}
|
}
|
||||||
for _, d := range magicDNSDomains {
|
for _, d := range magicDNSDomains {
|
||||||
h.cfg.DNSConfig.Routes[d.WithoutTrailingDot()] = nil
|
app.cfg.DNSConfig.Routes[d.WithoutTrailingDot()] = nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return &h, nil
|
return &app, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Redirect to our TLS url.
|
// Redirect to our TLS url.
|
||||||
|
@ -229,35 +229,37 @@ func (h *Headscale) expireEphemeralNodesWorker() {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, ns := range namespaces {
|
for _, namespace := range namespaces {
|
||||||
machines, err := h.ListMachinesInNamespace(ns.Name)
|
machines, err := h.ListMachinesInNamespace(namespace.Name)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().
|
log.Error().
|
||||||
Err(err).
|
Err(err).
|
||||||
Str("namespace", ns.Name).
|
Str("namespace", namespace.Name).
|
||||||
Msg("Error listing machines in namespace")
|
Msg("Error listing machines in namespace")
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, m := range machines {
|
for _, machine := range machines {
|
||||||
if m.AuthKey != nil && m.LastSeen != nil && m.AuthKey.Ephemeral &&
|
if machine.AuthKey != nil && machine.LastSeen != nil &&
|
||||||
time.Now().After(m.LastSeen.Add(h.cfg.EphemeralNodeInactivityTimeout)) {
|
machine.AuthKey.Ephemeral &&
|
||||||
|
time.Now().
|
||||||
|
After(machine.LastSeen.Add(h.cfg.EphemeralNodeInactivityTimeout)) {
|
||||||
log.Info().
|
log.Info().
|
||||||
Str("machine", m.Name).
|
Str("machine", machine.Name).
|
||||||
Msg("Ephemeral client removed from database")
|
Msg("Ephemeral client removed from database")
|
||||||
|
|
||||||
err = h.db.Unscoped().Delete(m).Error
|
err = h.db.Unscoped().Delete(machine).Error
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().
|
log.Error().
|
||||||
Err(err).
|
Err(err).
|
||||||
Str("machine", m.Name).
|
Str("machine", machine.Name).
|
||||||
Msg("🤮 Cannot delete ephemeral machine from the database")
|
Msg("🤮 Cannot delete ephemeral machine from the database")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
h.setLastStateChangeToNow(ns.Name)
|
h.setLastStateChangeToNow(namespace.Name)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -284,18 +286,18 @@ func (h *Headscale) grpcAuthenticationInterceptor(ctx context.Context,
|
||||||
// with the "legacy" database-based client
|
// with the "legacy" database-based client
|
||||||
// It is also neede for grpc-gateway to be able to connect to
|
// It is also neede for grpc-gateway to be able to connect to
|
||||||
// the server
|
// the server
|
||||||
p, _ := peer.FromContext(ctx)
|
client, _ := peer.FromContext(ctx)
|
||||||
|
|
||||||
log.Trace().
|
log.Trace().
|
||||||
Caller().
|
Caller().
|
||||||
Str("client_address", p.Addr.String()).
|
Str("client_address", client.Addr.String()).
|
||||||
Msg("Client is trying to authenticate")
|
Msg("Client is trying to authenticate")
|
||||||
|
|
||||||
md, ok := metadata.FromIncomingContext(ctx)
|
meta, ok := metadata.FromIncomingContext(ctx)
|
||||||
if !ok {
|
if !ok {
|
||||||
log.Error().
|
log.Error().
|
||||||
Caller().
|
Caller().
|
||||||
Str("client_address", p.Addr.String()).
|
Str("client_address", client.Addr.String()).
|
||||||
Msg("Retrieving metadata is failed")
|
Msg("Retrieving metadata is failed")
|
||||||
|
|
||||||
return ctx, status.Errorf(
|
return ctx, status.Errorf(
|
||||||
|
@ -304,11 +306,11 @@ func (h *Headscale) grpcAuthenticationInterceptor(ctx context.Context,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
authHeader, ok := md["authorization"]
|
authHeader, ok := meta["authorization"]
|
||||||
if !ok {
|
if !ok {
|
||||||
log.Error().
|
log.Error().
|
||||||
Caller().
|
Caller().
|
||||||
Str("client_address", p.Addr.String()).
|
Str("client_address", client.Addr.String()).
|
||||||
Msg("Authorization token is not supplied")
|
Msg("Authorization token is not supplied")
|
||||||
|
|
||||||
return ctx, status.Errorf(
|
return ctx, status.Errorf(
|
||||||
|
@ -322,7 +324,7 @@ func (h *Headscale) grpcAuthenticationInterceptor(ctx context.Context,
|
||||||
if !strings.HasPrefix(token, AUTH_PREFIX) {
|
if !strings.HasPrefix(token, AUTH_PREFIX) {
|
||||||
log.Error().
|
log.Error().
|
||||||
Caller().
|
Caller().
|
||||||
Str("client_address", p.Addr.String()).
|
Str("client_address", client.Addr.String()).
|
||||||
Msg(`missing "Bearer " prefix in "Authorization" header`)
|
Msg(`missing "Bearer " prefix in "Authorization" header`)
|
||||||
|
|
||||||
return ctx, status.Error(
|
return ctx, status.Error(
|
||||||
|
@ -353,25 +355,25 @@ func (h *Headscale) grpcAuthenticationInterceptor(ctx context.Context,
|
||||||
// return handler(ctx, req)
|
// return handler(ctx, req)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Headscale) httpAuthenticationMiddleware(c *gin.Context) {
|
func (h *Headscale) httpAuthenticationMiddleware(ctx *gin.Context) {
|
||||||
log.Trace().
|
log.Trace().
|
||||||
Caller().
|
Caller().
|
||||||
Str("client_address", c.ClientIP()).
|
Str("client_address", ctx.ClientIP()).
|
||||||
Msg("HTTP authentication invoked")
|
Msg("HTTP authentication invoked")
|
||||||
|
|
||||||
authHeader := c.GetHeader("authorization")
|
authHeader := ctx.GetHeader("authorization")
|
||||||
|
|
||||||
if !strings.HasPrefix(authHeader, AUTH_PREFIX) {
|
if !strings.HasPrefix(authHeader, AUTH_PREFIX) {
|
||||||
log.Error().
|
log.Error().
|
||||||
Caller().
|
Caller().
|
||||||
Str("client_address", c.ClientIP()).
|
Str("client_address", ctx.ClientIP()).
|
||||||
Msg(`missing "Bearer " prefix in "Authorization" header`)
|
Msg(`missing "Bearer " prefix in "Authorization" header`)
|
||||||
c.AbortWithStatus(http.StatusUnauthorized)
|
ctx.AbortWithStatus(http.StatusUnauthorized)
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
c.AbortWithStatus(http.StatusUnauthorized)
|
ctx.AbortWithStatus(http.StatusUnauthorized)
|
||||||
|
|
||||||
// TODO(kradalby): Implement API key backend
|
// TODO(kradalby): Implement API key backend
|
||||||
// Currently all traffic is unauthorized, this is intentional to allow
|
// Currently all traffic is unauthorized, this is intentional to allow
|
||||||
|
@ -438,9 +440,9 @@ func (h *Headscale) Serve() error {
|
||||||
|
|
||||||
// Create the cmux object that will multiplex 2 protocols on the same port.
|
// Create the cmux object that will multiplex 2 protocols on the same port.
|
||||||
// The two following listeners will be served on the same port below gracefully.
|
// The two following listeners will be served on the same port below gracefully.
|
||||||
m := cmux.New(networkListener)
|
networkMutex := cmux.New(networkListener)
|
||||||
// Match gRPC requests here
|
// Match gRPC requests here
|
||||||
grpcListener := m.MatchWithWriters(
|
grpcListener := networkMutex.MatchWithWriters(
|
||||||
cmux.HTTP2MatchHeaderFieldSendSettings("content-type", "application/grpc"),
|
cmux.HTTP2MatchHeaderFieldSendSettings("content-type", "application/grpc"),
|
||||||
cmux.HTTP2MatchHeaderFieldSendSettings(
|
cmux.HTTP2MatchHeaderFieldSendSettings(
|
||||||
"content-type",
|
"content-type",
|
||||||
|
@ -448,7 +450,7 @@ func (h *Headscale) Serve() error {
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
// Otherwise match regular http requests.
|
// Otherwise match regular http requests.
|
||||||
httpListener := m.Match(cmux.Any())
|
httpListener := networkMutex.Match(cmux.Any())
|
||||||
|
|
||||||
grpcGatewayMux := runtime.NewServeMux()
|
grpcGatewayMux := runtime.NewServeMux()
|
||||||
|
|
||||||
|
@ -471,33 +473,33 @@ func (h *Headscale) Serve() error {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
r := gin.Default()
|
router := gin.Default()
|
||||||
|
|
||||||
p := ginprometheus.NewPrometheus("gin")
|
prometheus := ginprometheus.NewPrometheus("gin")
|
||||||
p.Use(r)
|
prometheus.Use(router)
|
||||||
|
|
||||||
r.GET(
|
router.GET(
|
||||||
"/health",
|
"/health",
|
||||||
func(c *gin.Context) { c.JSON(http.StatusOK, gin.H{"healthy": "ok"}) },
|
func(c *gin.Context) { c.JSON(http.StatusOK, gin.H{"healthy": "ok"}) },
|
||||||
)
|
)
|
||||||
r.GET("/key", h.KeyHandler)
|
router.GET("/key", h.KeyHandler)
|
||||||
r.GET("/register", h.RegisterWebAPI)
|
router.GET("/register", h.RegisterWebAPI)
|
||||||
r.POST("/machine/:id/map", h.PollNetMapHandler)
|
router.POST("/machine/:id/map", h.PollNetMapHandler)
|
||||||
r.POST("/machine/:id", h.RegistrationHandler)
|
router.POST("/machine/:id", h.RegistrationHandler)
|
||||||
r.GET("/oidc/register/:mkey", h.RegisterOIDC)
|
router.GET("/oidc/register/:mkey", h.RegisterOIDC)
|
||||||
r.GET("/oidc/callback", h.OIDCCallback)
|
router.GET("/oidc/callback", h.OIDCCallback)
|
||||||
r.GET("/apple", h.AppleMobileConfig)
|
router.GET("/apple", h.AppleMobileConfig)
|
||||||
r.GET("/apple/:platform", h.ApplePlatformConfig)
|
router.GET("/apple/:platform", h.ApplePlatformConfig)
|
||||||
r.GET("/swagger", SwaggerUI)
|
router.GET("/swagger", SwaggerUI)
|
||||||
r.GET("/swagger/v1/openapiv2.json", SwaggerAPIv1)
|
router.GET("/swagger/v1/openapiv2.json", SwaggerAPIv1)
|
||||||
|
|
||||||
api := r.Group("/api")
|
api := router.Group("/api")
|
||||||
api.Use(h.httpAuthenticationMiddleware)
|
api.Use(h.httpAuthenticationMiddleware)
|
||||||
{
|
{
|
||||||
api.Any("/v1/*any", gin.WrapF(grpcGatewayMux.ServeHTTP))
|
api.Any("/v1/*any", gin.WrapF(grpcGatewayMux.ServeHTTP))
|
||||||
}
|
}
|
||||||
|
|
||||||
r.NoRoute(stdoutHandler)
|
router.NoRoute(stdoutHandler)
|
||||||
|
|
||||||
// Fetch an initial DERP Map before we start serving
|
// Fetch an initial DERP Map before we start serving
|
||||||
h.DERPMap = GetDERPMap(h.cfg.DERP)
|
h.DERPMap = GetDERPMap(h.cfg.DERP)
|
||||||
|
@ -514,7 +516,7 @@ func (h *Headscale) Serve() error {
|
||||||
|
|
||||||
httpServer := &http.Server{
|
httpServer := &http.Server{
|
||||||
Addr: h.cfg.Addr,
|
Addr: h.cfg.Addr,
|
||||||
Handler: r,
|
Handler: router,
|
||||||
ReadTimeout: HTTP_READ_TIMEOUT,
|
ReadTimeout: HTTP_READ_TIMEOUT,
|
||||||
// Go does not handle timeouts in HTTP very well, and there is
|
// Go does not handle timeouts in HTTP very well, and there is
|
||||||
// no good way to handle streaming timeouts, therefore we need to
|
// no good way to handle streaming timeouts, therefore we need to
|
||||||
|
@ -561,29 +563,29 @@ func (h *Headscale) Serve() error {
|
||||||
reflection.Register(grpcServer)
|
reflection.Register(grpcServer)
|
||||||
reflection.Register(grpcSocket)
|
reflection.Register(grpcSocket)
|
||||||
|
|
||||||
g := new(errgroup.Group)
|
errorGroup := new(errgroup.Group)
|
||||||
|
|
||||||
g.Go(func() error { return grpcSocket.Serve(socketListener) })
|
errorGroup.Go(func() error { return grpcSocket.Serve(socketListener) })
|
||||||
|
|
||||||
// TODO(kradalby): Verify if we need the same TLS setup for gRPC as HTTP
|
// TODO(kradalby): Verify if we need the same TLS setup for gRPC as HTTP
|
||||||
g.Go(func() error { return grpcServer.Serve(grpcListener) })
|
errorGroup.Go(func() error { return grpcServer.Serve(grpcListener) })
|
||||||
|
|
||||||
if tlsConfig != nil {
|
if tlsConfig != nil {
|
||||||
g.Go(func() error {
|
errorGroup.Go(func() error {
|
||||||
tlsl := tls.NewListener(httpListener, tlsConfig)
|
tlsl := tls.NewListener(httpListener, tlsConfig)
|
||||||
|
|
||||||
return httpServer.Serve(tlsl)
|
return httpServer.Serve(tlsl)
|
||||||
})
|
})
|
||||||
} else {
|
} else {
|
||||||
g.Go(func() error { return httpServer.Serve(httpListener) })
|
errorGroup.Go(func() error { return httpServer.Serve(httpListener) })
|
||||||
}
|
}
|
||||||
|
|
||||||
g.Go(func() error { return m.Serve() })
|
errorGroup.Go(func() error { return networkMutex.Serve() })
|
||||||
|
|
||||||
log.Info().
|
log.Info().
|
||||||
Msgf("listening and serving (multiplexed HTTP and gRPC) on: %s", h.cfg.Addr)
|
Msgf("listening and serving (multiplexed HTTP and gRPC) on: %s", h.cfg.Addr)
|
||||||
|
|
||||||
return g.Wait()
|
return errorGroup.Wait()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Headscale) getTLSSettings() (*tls.Config, error) {
|
func (h *Headscale) getTLSSettings() (*tls.Config, error) {
|
||||||
|
@ -594,7 +596,7 @@ func (h *Headscale) getTLSSettings() (*tls.Config, error) {
|
||||||
Msg("Listening with TLS but ServerURL does not start with https://")
|
Msg("Listening with TLS but ServerURL does not start with https://")
|
||||||
}
|
}
|
||||||
|
|
||||||
m := autocert.Manager{
|
certManager := autocert.Manager{
|
||||||
Prompt: autocert.AcceptTOS,
|
Prompt: autocert.AcceptTOS,
|
||||||
HostPolicy: autocert.HostWhitelist(h.cfg.TLSLetsEncryptHostname),
|
HostPolicy: autocert.HostWhitelist(h.cfg.TLSLetsEncryptHostname),
|
||||||
Cache: autocert.DirCache(h.cfg.TLSLetsEncryptCacheDir),
|
Cache: autocert.DirCache(h.cfg.TLSLetsEncryptCacheDir),
|
||||||
|
@ -609,7 +611,7 @@ func (h *Headscale) getTLSSettings() (*tls.Config, error) {
|
||||||
// Configuration via autocert with TLS-ALPN-01 (https://tools.ietf.org/html/rfc8737)
|
// Configuration via autocert with TLS-ALPN-01 (https://tools.ietf.org/html/rfc8737)
|
||||||
// The RFC requires that the validation is done on port 443; in other words, headscale
|
// The RFC requires that the validation is done on port 443; in other words, headscale
|
||||||
// must be reachable on port 443.
|
// must be reachable on port 443.
|
||||||
return m.TLSConfig(), nil
|
return certManager.TLSConfig(), nil
|
||||||
|
|
||||||
case "HTTP-01":
|
case "HTTP-01":
|
||||||
// Configuration via autocert with HTTP-01. This requires listening on
|
// Configuration via autocert with HTTP-01. This requires listening on
|
||||||
|
@ -617,11 +619,11 @@ func (h *Headscale) getTLSSettings() (*tls.Config, error) {
|
||||||
// service, which can be configured to run on any other port.
|
// service, which can be configured to run on any other port.
|
||||||
go func() {
|
go func() {
|
||||||
log.Fatal().
|
log.Fatal().
|
||||||
Err(http.ListenAndServe(h.cfg.TLSLetsEncryptListen, m.HTTPHandler(http.HandlerFunc(h.redirect)))).
|
Err(http.ListenAndServe(h.cfg.TLSLetsEncryptListen, certManager.HTTPHandler(http.HandlerFunc(h.redirect)))).
|
||||||
Msg("failed to set up a HTTP server")
|
Msg("failed to set up a HTTP server")
|
||||||
}()
|
}()
|
||||||
|
|
||||||
return m.TLSConfig(), nil
|
return certManager.TLSConfig(), nil
|
||||||
|
|
||||||
default:
|
default:
|
||||||
return nil, errors.New("unknown value for TLSLetsEncryptChallengeType")
|
return nil, errors.New("unknown value for TLSLetsEncryptChallengeType")
|
||||||
|
@ -676,13 +678,13 @@ func (h *Headscale) getLastStateChange(namespaces ...string) time.Time {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func stdoutHandler(c *gin.Context) {
|
func stdoutHandler(ctx *gin.Context) {
|
||||||
b, _ := io.ReadAll(c.Request.Body)
|
body, _ := io.ReadAll(ctx.Request.Body)
|
||||||
|
|
||||||
log.Trace().
|
log.Trace().
|
||||||
Interface("header", c.Request.Header).
|
Interface("header", ctx.Request.Header).
|
||||||
Interface("proto", c.Request.Proto).
|
Interface("proto", ctx.Request.Proto).
|
||||||
Interface("url", c.Request.URL).
|
Interface("url", ctx.Request.URL).
|
||||||
Bytes("body", b).
|
Bytes("body", body).
|
||||||
Msg("Request did not match")
|
Msg("Request did not match")
|
||||||
}
|
}
|
||||||
|
|
|
@ -12,8 +12,8 @@ import (
|
||||||
|
|
||||||
// AppleMobileConfig shows a simple message in the browser to point to the CLI
|
// AppleMobileConfig shows a simple message in the browser to point to the CLI
|
||||||
// Listens in /register.
|
// Listens in /register.
|
||||||
func (h *Headscale) AppleMobileConfig(c *gin.Context) {
|
func (h *Headscale) AppleMobileConfig(ctx *gin.Context) {
|
||||||
t := template.Must(template.New("apple").Parse(`
|
appleTemplate := template.Must(template.New("apple").Parse(`
|
||||||
<html>
|
<html>
|
||||||
<body>
|
<body>
|
||||||
<h1>Apple configuration profiles</h1>
|
<h1>Apple configuration profiles</h1>
|
||||||
|
@ -67,12 +67,12 @@ func (h *Headscale) AppleMobileConfig(c *gin.Context) {
|
||||||
}
|
}
|
||||||
|
|
||||||
var payload bytes.Buffer
|
var payload bytes.Buffer
|
||||||
if err := t.Execute(&payload, config); err != nil {
|
if err := appleTemplate.Execute(&payload, config); err != nil {
|
||||||
log.Error().
|
log.Error().
|
||||||
Str("handler", "AppleMobileConfig").
|
Str("handler", "AppleMobileConfig").
|
||||||
Err(err).
|
Err(err).
|
||||||
Msg("Could not render Apple index template")
|
Msg("Could not render Apple index template")
|
||||||
c.Data(
|
ctx.Data(
|
||||||
http.StatusInternalServerError,
|
http.StatusInternalServerError,
|
||||||
"text/html; charset=utf-8",
|
"text/html; charset=utf-8",
|
||||||
[]byte("Could not render Apple index template"),
|
[]byte("Could not render Apple index template"),
|
||||||
|
@ -81,11 +81,11 @@ func (h *Headscale) AppleMobileConfig(c *gin.Context) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
c.Data(http.StatusOK, "text/html; charset=utf-8", payload.Bytes())
|
ctx.Data(http.StatusOK, "text/html; charset=utf-8", payload.Bytes())
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Headscale) ApplePlatformConfig(c *gin.Context) {
|
func (h *Headscale) ApplePlatformConfig(ctx *gin.Context) {
|
||||||
platform := c.Param("platform")
|
platform := ctx.Param("platform")
|
||||||
|
|
||||||
id, err := uuid.NewV4()
|
id, err := uuid.NewV4()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -93,7 +93,7 @@ func (h *Headscale) ApplePlatformConfig(c *gin.Context) {
|
||||||
Str("handler", "ApplePlatformConfig").
|
Str("handler", "ApplePlatformConfig").
|
||||||
Err(err).
|
Err(err).
|
||||||
Msg("Failed not create UUID")
|
Msg("Failed not create UUID")
|
||||||
c.Data(
|
ctx.Data(
|
||||||
http.StatusInternalServerError,
|
http.StatusInternalServerError,
|
||||||
"text/html; charset=utf-8",
|
"text/html; charset=utf-8",
|
||||||
[]byte("Failed to create UUID"),
|
[]byte("Failed to create UUID"),
|
||||||
|
@ -108,7 +108,7 @@ func (h *Headscale) ApplePlatformConfig(c *gin.Context) {
|
||||||
Str("handler", "ApplePlatformConfig").
|
Str("handler", "ApplePlatformConfig").
|
||||||
Err(err).
|
Err(err).
|
||||||
Msg("Failed not create UUID")
|
Msg("Failed not create UUID")
|
||||||
c.Data(
|
ctx.Data(
|
||||||
http.StatusInternalServerError,
|
http.StatusInternalServerError,
|
||||||
"text/html; charset=utf-8",
|
"text/html; charset=utf-8",
|
||||||
[]byte("Failed to create UUID"),
|
[]byte("Failed to create UUID"),
|
||||||
|
@ -131,7 +131,7 @@ func (h *Headscale) ApplePlatformConfig(c *gin.Context) {
|
||||||
Str("handler", "ApplePlatformConfig").
|
Str("handler", "ApplePlatformConfig").
|
||||||
Err(err).
|
Err(err).
|
||||||
Msg("Could not render Apple macOS template")
|
Msg("Could not render Apple macOS template")
|
||||||
c.Data(
|
ctx.Data(
|
||||||
http.StatusInternalServerError,
|
http.StatusInternalServerError,
|
||||||
"text/html; charset=utf-8",
|
"text/html; charset=utf-8",
|
||||||
[]byte("Could not render Apple macOS template"),
|
[]byte("Could not render Apple macOS template"),
|
||||||
|
@ -145,7 +145,7 @@ func (h *Headscale) ApplePlatformConfig(c *gin.Context) {
|
||||||
Str("handler", "ApplePlatformConfig").
|
Str("handler", "ApplePlatformConfig").
|
||||||
Err(err).
|
Err(err).
|
||||||
Msg("Could not render Apple iOS template")
|
Msg("Could not render Apple iOS template")
|
||||||
c.Data(
|
ctx.Data(
|
||||||
http.StatusInternalServerError,
|
http.StatusInternalServerError,
|
||||||
"text/html; charset=utf-8",
|
"text/html; charset=utf-8",
|
||||||
[]byte("Could not render Apple iOS template"),
|
[]byte("Could not render Apple iOS template"),
|
||||||
|
@ -154,7 +154,7 @@ func (h *Headscale) ApplePlatformConfig(c *gin.Context) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
c.Data(
|
ctx.Data(
|
||||||
http.StatusOK,
|
http.StatusOK,
|
||||||
"text/html; charset=utf-8",
|
"text/html; charset=utf-8",
|
||||||
[]byte("Invalid platform, only ios and macos is supported"),
|
[]byte("Invalid platform, only ios and macos is supported"),
|
||||||
|
@ -175,7 +175,7 @@ func (h *Headscale) ApplePlatformConfig(c *gin.Context) {
|
||||||
Str("handler", "ApplePlatformConfig").
|
Str("handler", "ApplePlatformConfig").
|
||||||
Err(err).
|
Err(err).
|
||||||
Msg("Could not render Apple platform template")
|
Msg("Could not render Apple platform template")
|
||||||
c.Data(
|
ctx.Data(
|
||||||
http.StatusInternalServerError,
|
http.StatusInternalServerError,
|
||||||
"text/html; charset=utf-8",
|
"text/html; charset=utf-8",
|
||||||
[]byte("Could not render Apple platform template"),
|
[]byte("Could not render Apple platform template"),
|
||||||
|
@ -184,7 +184,7 @@ func (h *Headscale) ApplePlatformConfig(c *gin.Context) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
c.Data(
|
ctx.Data(
|
||||||
http.StatusOK,
|
http.StatusOK,
|
||||||
"application/x-apple-aspen-config; charset=utf-8",
|
"application/x-apple-aspen-config; charset=utf-8",
|
||||||
content.Bytes(),
|
content.Bytes(),
|
||||||
|
|
|
@ -167,10 +167,10 @@ var listNamespacesCmd = &cobra.Command{
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
d := pterm.TableData{{"ID", "Name", "Created"}}
|
tableData := pterm.TableData{{"ID", "Name", "Created"}}
|
||||||
for _, namespace := range response.GetNamespaces() {
|
for _, namespace := range response.GetNamespaces() {
|
||||||
d = append(
|
tableData = append(
|
||||||
d,
|
tableData,
|
||||||
[]string{
|
[]string{
|
||||||
namespace.GetId(),
|
namespace.GetId(),
|
||||||
namespace.GetName(),
|
namespace.GetName(),
|
||||||
|
@ -178,7 +178,7 @@ var listNamespacesCmd = &cobra.Command{
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
err = pterm.DefaultTable.WithHasHeader().WithData(d).Render()
|
err = pterm.DefaultTable.WithHasHeader().WithData(tableData).Render()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
ErrorOutput(
|
ErrorOutput(
|
||||||
err,
|
err,
|
||||||
|
|
|
@ -157,14 +157,14 @@ var listNodesCmd = &cobra.Command{
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
d, err := nodesToPtables(namespace, response.Machines)
|
tableData, err := nodesToPtables(namespace, response.Machines)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
ErrorOutput(err, fmt.Sprintf("Error converting to table: %s", err), output)
|
ErrorOutput(err, fmt.Sprintf("Error converting to table: %s", err), output)
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
err = pterm.DefaultTable.WithHasHeader().WithData(d).Render()
|
err = pterm.DefaultTable.WithHasHeader().WithData(tableData).Render()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
ErrorOutput(
|
ErrorOutput(
|
||||||
err,
|
err,
|
||||||
|
@ -183,7 +183,7 @@ var deleteNodeCmd = &cobra.Command{
|
||||||
Run: func(cmd *cobra.Command, args []string) {
|
Run: func(cmd *cobra.Command, args []string) {
|
||||||
output, _ := cmd.Flags().GetString("output")
|
output, _ := cmd.Flags().GetString("output")
|
||||||
|
|
||||||
id, err := cmd.Flags().GetInt("identifier")
|
identifier, err := cmd.Flags().GetInt("identifier")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
ErrorOutput(
|
ErrorOutput(
|
||||||
err,
|
err,
|
||||||
|
@ -199,7 +199,7 @@ var deleteNodeCmd = &cobra.Command{
|
||||||
defer conn.Close()
|
defer conn.Close()
|
||||||
|
|
||||||
getRequest := &v1.GetMachineRequest{
|
getRequest := &v1.GetMachineRequest{
|
||||||
MachineId: uint64(id),
|
MachineId: uint64(identifier),
|
||||||
}
|
}
|
||||||
|
|
||||||
getResponse, err := client.GetMachine(ctx, getRequest)
|
getResponse, err := client.GetMachine(ctx, getRequest)
|
||||||
|
@ -217,7 +217,7 @@ var deleteNodeCmd = &cobra.Command{
|
||||||
}
|
}
|
||||||
|
|
||||||
deleteRequest := &v1.DeleteMachineRequest{
|
deleteRequest := &v1.DeleteMachineRequest{
|
||||||
MachineId: uint64(id),
|
MachineId: uint64(identifier),
|
||||||
}
|
}
|
||||||
|
|
||||||
confirm := false
|
confirm := false
|
||||||
|
@ -280,7 +280,7 @@ func sharingWorker(
|
||||||
defer cancel()
|
defer cancel()
|
||||||
defer conn.Close()
|
defer conn.Close()
|
||||||
|
|
||||||
id, err := cmd.Flags().GetInt("identifier")
|
identifier, err := cmd.Flags().GetInt("identifier")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
ErrorOutput(err, fmt.Sprintf("Error converting ID to integer: %s", err), output)
|
ErrorOutput(err, fmt.Sprintf("Error converting ID to integer: %s", err), output)
|
||||||
|
|
||||||
|
@ -288,7 +288,7 @@ func sharingWorker(
|
||||||
}
|
}
|
||||||
|
|
||||||
machineRequest := &v1.GetMachineRequest{
|
machineRequest := &v1.GetMachineRequest{
|
||||||
MachineId: uint64(id),
|
MachineId: uint64(identifier),
|
||||||
}
|
}
|
||||||
|
|
||||||
machineResponse, err := client.GetMachine(ctx, machineRequest)
|
machineResponse, err := client.GetMachine(ctx, machineRequest)
|
||||||
|
@ -402,7 +402,7 @@ func nodesToPtables(
|
||||||
currentNamespace string,
|
currentNamespace string,
|
||||||
machines []*v1.Machine,
|
machines []*v1.Machine,
|
||||||
) (pterm.TableData, error) {
|
) (pterm.TableData, error) {
|
||||||
d := pterm.TableData{
|
tableData := pterm.TableData{
|
||||||
{
|
{
|
||||||
"ID",
|
"ID",
|
||||||
"Name",
|
"Name",
|
||||||
|
@ -448,8 +448,8 @@ func nodesToPtables(
|
||||||
// Shared into this namespace
|
// Shared into this namespace
|
||||||
namespace = pterm.LightYellow(machine.Namespace.Name)
|
namespace = pterm.LightYellow(machine.Namespace.Name)
|
||||||
}
|
}
|
||||||
d = append(
|
tableData = append(
|
||||||
d,
|
tableData,
|
||||||
[]string{
|
[]string{
|
||||||
strconv.FormatUint(machine.Id, headscale.BASE_10),
|
strconv.FormatUint(machine.Id, headscale.BASE_10),
|
||||||
machine.Name,
|
machine.Name,
|
||||||
|
@ -463,5 +463,5 @@ func nodesToPtables(
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
return d, nil
|
return tableData, nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -45,7 +45,7 @@ var listPreAuthKeys = &cobra.Command{
|
||||||
Run: func(cmd *cobra.Command, args []string) {
|
Run: func(cmd *cobra.Command, args []string) {
|
||||||
output, _ := cmd.Flags().GetString("output")
|
output, _ := cmd.Flags().GetString("output")
|
||||||
|
|
||||||
n, err := cmd.Flags().GetString("namespace")
|
namespace, err := cmd.Flags().GetString("namespace")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
ErrorOutput(err, fmt.Sprintf("Error getting namespace: %s", err), output)
|
ErrorOutput(err, fmt.Sprintf("Error getting namespace: %s", err), output)
|
||||||
|
|
||||||
|
@ -57,7 +57,7 @@ var listPreAuthKeys = &cobra.Command{
|
||||||
defer conn.Close()
|
defer conn.Close()
|
||||||
|
|
||||||
request := &v1.ListPreAuthKeysRequest{
|
request := &v1.ListPreAuthKeysRequest{
|
||||||
Namespace: n,
|
Namespace: namespace,
|
||||||
}
|
}
|
||||||
|
|
||||||
response, err := client.ListPreAuthKeys(ctx, request)
|
response, err := client.ListPreAuthKeys(ctx, request)
|
||||||
|
@ -77,34 +77,34 @@ var listPreAuthKeys = &cobra.Command{
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
d := pterm.TableData{
|
tableData := pterm.TableData{
|
||||||
{"ID", "Key", "Reusable", "Ephemeral", "Used", "Expiration", "Created"},
|
{"ID", "Key", "Reusable", "Ephemeral", "Used", "Expiration", "Created"},
|
||||||
}
|
}
|
||||||
for _, k := range response.PreAuthKeys {
|
for _, key := range response.PreAuthKeys {
|
||||||
expiration := "-"
|
expiration := "-"
|
||||||
if k.GetExpiration() != nil {
|
if key.GetExpiration() != nil {
|
||||||
expiration = k.Expiration.AsTime().Format("2006-01-02 15:04:05")
|
expiration = key.Expiration.AsTime().Format("2006-01-02 15:04:05")
|
||||||
}
|
}
|
||||||
|
|
||||||
var reusable string
|
var reusable string
|
||||||
if k.GetEphemeral() {
|
if key.GetEphemeral() {
|
||||||
reusable = "N/A"
|
reusable = "N/A"
|
||||||
} else {
|
} else {
|
||||||
reusable = fmt.Sprintf("%v", k.GetReusable())
|
reusable = fmt.Sprintf("%v", key.GetReusable())
|
||||||
}
|
}
|
||||||
|
|
||||||
d = append(d, []string{
|
tableData = append(tableData, []string{
|
||||||
k.GetId(),
|
key.GetId(),
|
||||||
k.GetKey(),
|
key.GetKey(),
|
||||||
reusable,
|
reusable,
|
||||||
strconv.FormatBool(k.GetEphemeral()),
|
strconv.FormatBool(key.GetEphemeral()),
|
||||||
strconv.FormatBool(k.GetUsed()),
|
strconv.FormatBool(key.GetUsed()),
|
||||||
expiration,
|
expiration,
|
||||||
k.GetCreatedAt().AsTime().Format("2006-01-02 15:04:05"),
|
key.GetCreatedAt().AsTime().Format("2006-01-02 15:04:05"),
|
||||||
})
|
})
|
||||||
|
|
||||||
}
|
}
|
||||||
err = pterm.DefaultTable.WithHasHeader().WithData(d).Render()
|
err = pterm.DefaultTable.WithHasHeader().WithData(tableData).Render()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
ErrorOutput(
|
ErrorOutput(
|
||||||
err,
|
err,
|
||||||
|
|
|
@ -81,14 +81,14 @@ var listRoutesCmd = &cobra.Command{
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
d := routesToPtables(response.Routes)
|
tableData := routesToPtables(response.Routes)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
ErrorOutput(err, fmt.Sprintf("Error converting to table: %s", err), output)
|
ErrorOutput(err, fmt.Sprintf("Error converting to table: %s", err), output)
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
err = pterm.DefaultTable.WithHasHeader().WithData(d).Render()
|
err = pterm.DefaultTable.WithHasHeader().WithData(tableData).Render()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
ErrorOutput(
|
ErrorOutput(
|
||||||
err,
|
err,
|
||||||
|
@ -162,14 +162,14 @@ omit the route you do not want to enable.
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
d := routesToPtables(response.Routes)
|
tableData := routesToPtables(response.Routes)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
ErrorOutput(err, fmt.Sprintf("Error converting to table: %s", err), output)
|
ErrorOutput(err, fmt.Sprintf("Error converting to table: %s", err), output)
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
err = pterm.DefaultTable.WithHasHeader().WithData(d).Render()
|
err = pterm.DefaultTable.WithHasHeader().WithData(tableData).Render()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
ErrorOutput(
|
ErrorOutput(
|
||||||
err,
|
err,
|
||||||
|
@ -184,15 +184,15 @@ omit the route you do not want to enable.
|
||||||
|
|
||||||
// routesToPtables converts the list of routes to a nice table.
|
// routesToPtables converts the list of routes to a nice table.
|
||||||
func routesToPtables(routes *v1.Routes) pterm.TableData {
|
func routesToPtables(routes *v1.Routes) pterm.TableData {
|
||||||
d := pterm.TableData{{"Route", "Enabled"}}
|
tableData := pterm.TableData{{"Route", "Enabled"}}
|
||||||
|
|
||||||
for _, route := range routes.GetAdvertisedRoutes() {
|
for _, route := range routes.GetAdvertisedRoutes() {
|
||||||
enabled := isStringInSlice(routes.EnabledRoutes, route)
|
enabled := isStringInSlice(routes.EnabledRoutes, route)
|
||||||
|
|
||||||
d = append(d, []string{route, strconv.FormatBool(enabled)})
|
tableData = append(tableData, []string{route, strconv.FormatBool(enabled)})
|
||||||
}
|
}
|
||||||
|
|
||||||
return d
|
return tableData
|
||||||
}
|
}
|
||||||
|
|
||||||
func isStringInSlice(strs []string, s string) bool {
|
func isStringInSlice(strs []string, s string) bool {
|
||||||
|
|
|
@ -318,7 +318,7 @@ func getHeadscaleApp() (*headscale.Headscale, error) {
|
||||||
|
|
||||||
cfg.OIDC.MatchMap = loadOIDCMatchMap()
|
cfg.OIDC.MatchMap = loadOIDCMatchMap()
|
||||||
|
|
||||||
h, err := headscale.NewHeadscale(cfg)
|
app, err := headscale.NewHeadscale(cfg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -327,7 +327,7 @@ func getHeadscaleApp() (*headscale.Headscale, error) {
|
||||||
|
|
||||||
if viper.GetString("acl_policy_path") != "" {
|
if viper.GetString("acl_policy_path") != "" {
|
||||||
aclPath := absPath(viper.GetString("acl_policy_path"))
|
aclPath := absPath(viper.GetString("acl_policy_path"))
|
||||||
err = h.LoadACLPolicy(aclPath)
|
err = app.LoadACLPolicy(aclPath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().
|
log.Error().
|
||||||
Str("path", aclPath).
|
Str("path", aclPath).
|
||||||
|
@ -336,7 +336,7 @@ func getHeadscaleApp() (*headscale.Headscale, error) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return h, nil
|
return app, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func getHeadscaleCLIClient() (context.Context, v1.HeadscaleServiceClient, *grpc.ClientConn, context.CancelFunc) {
|
func getHeadscaleCLIClient() (context.Context, v1.HeadscaleServiceClient, *grpc.ClientConn, context.CancelFunc) {
|
||||||
|
|
6
dns.go
6
dns.go
|
@ -79,7 +79,7 @@ func generateMagicDNSRootDomains(
|
||||||
func getMapResponseDNSConfig(
|
func getMapResponseDNSConfig(
|
||||||
dnsConfigOrig *tailcfg.DNSConfig,
|
dnsConfigOrig *tailcfg.DNSConfig,
|
||||||
baseDomain string,
|
baseDomain string,
|
||||||
m Machine,
|
machine Machine,
|
||||||
peers Machines,
|
peers Machines,
|
||||||
) *tailcfg.DNSConfig {
|
) *tailcfg.DNSConfig {
|
||||||
var dnsConfig *tailcfg.DNSConfig
|
var dnsConfig *tailcfg.DNSConfig
|
||||||
|
@ -88,11 +88,11 @@ func getMapResponseDNSConfig(
|
||||||
dnsConfig = dnsConfigOrig.Clone()
|
dnsConfig = dnsConfigOrig.Clone()
|
||||||
dnsConfig.Domains = append(
|
dnsConfig.Domains = append(
|
||||||
dnsConfig.Domains,
|
dnsConfig.Domains,
|
||||||
fmt.Sprintf("%s.%s", m.Namespace.Name, baseDomain),
|
fmt.Sprintf("%s.%s", machine.Namespace.Name, baseDomain),
|
||||||
)
|
)
|
||||||
|
|
||||||
namespaceSet := set.New(set.ThreadSafe)
|
namespaceSet := set.New(set.ThreadSafe)
|
||||||
namespaceSet.Add(m.Namespace)
|
namespaceSet.Add(machine.Namespace)
|
||||||
for _, p := range peers {
|
for _, p := range peers {
|
||||||
namespaceSet.Add(p.Namespace)
|
namespaceSet.Add(p.Namespace)
|
||||||
}
|
}
|
||||||
|
|
299
machine.go
299
machine.go
|
@ -56,21 +56,21 @@ type (
|
||||||
)
|
)
|
||||||
|
|
||||||
// For the time being this method is rather naive.
|
// For the time being this method is rather naive.
|
||||||
func (m Machine) isAlreadyRegistered() bool {
|
func (machine Machine) isAlreadyRegistered() bool {
|
||||||
return m.Registered
|
return machine.Registered
|
||||||
}
|
}
|
||||||
|
|
||||||
// isExpired returns whether the machine registration has expired.
|
// isExpired returns whether the machine registration has expired.
|
||||||
func (m Machine) isExpired() bool {
|
func (machine Machine) isExpired() bool {
|
||||||
return time.Now().UTC().After(*m.Expiry)
|
return time.Now().UTC().After(*machine.Expiry)
|
||||||
}
|
}
|
||||||
|
|
||||||
// If the Machine is expired, updateMachineExpiry updates the Machine Expiry time to the maximum allowed duration,
|
// If the Machine is expired, updateMachineExpiry updates the Machine Expiry time to the maximum allowed duration,
|
||||||
// or the default duration if no Expiry time was requested by the client. The expiry time here does not (yet) cause
|
// or the default duration if no Expiry time was requested by the client. The expiry time here does not (yet) cause
|
||||||
// a client to be disconnected, however they will have to re-auth the machine if they attempt to reconnect after the
|
// a client to be disconnected, however they will have to re-auth the machine if they attempt to reconnect after the
|
||||||
// expiry time.
|
// expiry time.
|
||||||
func (h *Headscale) updateMachineExpiry(m *Machine) {
|
func (h *Headscale) updateMachineExpiry(machine *Machine) {
|
||||||
if m.isExpired() {
|
if machine.isExpired() {
|
||||||
now := time.Now().UTC()
|
now := time.Now().UTC()
|
||||||
maxExpiry := now.Add(
|
maxExpiry := now.Add(
|
||||||
h.cfg.MaxMachineRegistrationDuration,
|
h.cfg.MaxMachineRegistrationDuration,
|
||||||
|
@ -80,31 +80,31 @@ func (h *Headscale) updateMachineExpiry(m *Machine) {
|
||||||
) // calculate the default expiry
|
) // calculate the default expiry
|
||||||
|
|
||||||
// clamp the expiry time of the machine registration to the maximum allowed, or use the default if none supplied
|
// clamp the expiry time of the machine registration to the maximum allowed, or use the default if none supplied
|
||||||
if maxExpiry.Before(*m.RequestedExpiry) {
|
if maxExpiry.Before(*machine.RequestedExpiry) {
|
||||||
log.Debug().
|
log.Debug().
|
||||||
Msgf("Clamping registration expiry time to maximum: %v (%v)", maxExpiry, h.cfg.MaxMachineRegistrationDuration)
|
Msgf("Clamping registration expiry time to maximum: %v (%v)", maxExpiry, h.cfg.MaxMachineRegistrationDuration)
|
||||||
m.Expiry = &maxExpiry
|
machine.Expiry = &maxExpiry
|
||||||
} else if m.RequestedExpiry.IsZero() {
|
} else if machine.RequestedExpiry.IsZero() {
|
||||||
log.Debug().Msgf("Using default machine registration expiry time: %v (%v)", defaultExpiry, h.cfg.DefaultMachineRegistrationDuration)
|
log.Debug().Msgf("Using default machine registration expiry time: %v (%v)", defaultExpiry, h.cfg.DefaultMachineRegistrationDuration)
|
||||||
m.Expiry = &defaultExpiry
|
machine.Expiry = &defaultExpiry
|
||||||
} else {
|
} else {
|
||||||
log.Debug().Msgf("Using requested machine registration expiry time: %v", m.RequestedExpiry)
|
log.Debug().Msgf("Using requested machine registration expiry time: %v", machine.RequestedExpiry)
|
||||||
m.Expiry = m.RequestedExpiry
|
machine.Expiry = machine.RequestedExpiry
|
||||||
}
|
}
|
||||||
|
|
||||||
h.db.Save(&m)
|
h.db.Save(&machine)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Headscale) getDirectPeers(m *Machine) (Machines, error) {
|
func (h *Headscale) getDirectPeers(machine *Machine) (Machines, error) {
|
||||||
log.Trace().
|
log.Trace().
|
||||||
Caller().
|
Caller().
|
||||||
Str("machine", m.Name).
|
Str("machine", machine.Name).
|
||||||
Msg("Finding direct peers")
|
Msg("Finding direct peers")
|
||||||
|
|
||||||
machines := Machines{}
|
machines := Machines{}
|
||||||
if err := h.db.Preload("Namespace").Where("namespace_id = ? AND machine_key <> ? AND registered",
|
if err := h.db.Preload("Namespace").Where("namespace_id = ? AND machine_key <> ? AND registered",
|
||||||
m.NamespaceID, m.MachineKey).Find(&machines).Error; err != nil {
|
machine.NamespaceID, machine.MachineKey).Find(&machines).Error; err != nil {
|
||||||
log.Error().Err(err).Msg("Error accessing db")
|
log.Error().Err(err).Msg("Error accessing db")
|
||||||
|
|
||||||
return Machines{}, err
|
return Machines{}, err
|
||||||
|
@ -114,22 +114,22 @@ func (h *Headscale) getDirectPeers(m *Machine) (Machines, error) {
|
||||||
|
|
||||||
log.Trace().
|
log.Trace().
|
||||||
Caller().
|
Caller().
|
||||||
Str("machine", m.Name).
|
Str("machine", machine.Name).
|
||||||
Msgf("Found direct machines: %s", machines.String())
|
Msgf("Found direct machines: %s", machines.String())
|
||||||
|
|
||||||
return machines, nil
|
return machines, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// getShared fetches machines that are shared to the `Namespace` of the machine we are getting peers for.
|
// getShared fetches machines that are shared to the `Namespace` of the machine we are getting peers for.
|
||||||
func (h *Headscale) getShared(m *Machine) (Machines, error) {
|
func (h *Headscale) getShared(machine *Machine) (Machines, error) {
|
||||||
log.Trace().
|
log.Trace().
|
||||||
Caller().
|
Caller().
|
||||||
Str("machine", m.Name).
|
Str("machine", machine.Name).
|
||||||
Msg("Finding shared peers")
|
Msg("Finding shared peers")
|
||||||
|
|
||||||
sharedMachines := []SharedMachine{}
|
sharedMachines := []SharedMachine{}
|
||||||
if err := h.db.Preload("Namespace").Preload("Machine").Preload("Machine.Namespace").Where("namespace_id = ?",
|
if err := h.db.Preload("Namespace").Preload("Machine").Preload("Machine.Namespace").Where("namespace_id = ?",
|
||||||
m.NamespaceID).Find(&sharedMachines).Error; err != nil {
|
machine.NamespaceID).Find(&sharedMachines).Error; err != nil {
|
||||||
return Machines{}, err
|
return Machines{}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -142,22 +142,22 @@ func (h *Headscale) getShared(m *Machine) (Machines, error) {
|
||||||
|
|
||||||
log.Trace().
|
log.Trace().
|
||||||
Caller().
|
Caller().
|
||||||
Str("machine", m.Name).
|
Str("machine", machine.Name).
|
||||||
Msgf("Found shared peers: %s", peers.String())
|
Msgf("Found shared peers: %s", peers.String())
|
||||||
|
|
||||||
return peers, nil
|
return peers, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// getSharedTo fetches the machines of the namespaces this machine is shared in.
|
// getSharedTo fetches the machines of the namespaces this machine is shared in.
|
||||||
func (h *Headscale) getSharedTo(m *Machine) (Machines, error) {
|
func (h *Headscale) getSharedTo(machine *Machine) (Machines, error) {
|
||||||
log.Trace().
|
log.Trace().
|
||||||
Caller().
|
Caller().
|
||||||
Str("machine", m.Name).
|
Str("machine", machine.Name).
|
||||||
Msg("Finding peers in namespaces this machine is shared with")
|
Msg("Finding peers in namespaces this machine is shared with")
|
||||||
|
|
||||||
sharedMachines := []SharedMachine{}
|
sharedMachines := []SharedMachine{}
|
||||||
if err := h.db.Preload("Namespace").Preload("Machine").Preload("Machine.Namespace").Where("machine_id = ?",
|
if err := h.db.Preload("Namespace").Preload("Machine").Preload("Machine.Namespace").Where("machine_id = ?",
|
||||||
m.ID).Find(&sharedMachines).Error; err != nil {
|
machine.ID).Find(&sharedMachines).Error; err != nil {
|
||||||
return Machines{}, err
|
return Machines{}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -176,14 +176,14 @@ func (h *Headscale) getSharedTo(m *Machine) (Machines, error) {
|
||||||
|
|
||||||
log.Trace().
|
log.Trace().
|
||||||
Caller().
|
Caller().
|
||||||
Str("machine", m.Name).
|
Str("machine", machine.Name).
|
||||||
Msgf("Found peers we are shared with: %s", peers.String())
|
Msgf("Found peers we are shared with: %s", peers.String())
|
||||||
|
|
||||||
return peers, nil
|
return peers, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Headscale) getPeers(m *Machine) (Machines, error) {
|
func (h *Headscale) getPeers(machine *Machine) (Machines, error) {
|
||||||
direct, err := h.getDirectPeers(m)
|
direct, err := h.getDirectPeers(machine)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().
|
log.Error().
|
||||||
Caller().
|
Caller().
|
||||||
|
@ -193,7 +193,7 @@ func (h *Headscale) getPeers(m *Machine) (Machines, error) {
|
||||||
return Machines{}, err
|
return Machines{}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
shared, err := h.getShared(m)
|
shared, err := h.getShared(machine)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().
|
log.Error().
|
||||||
Caller().
|
Caller().
|
||||||
|
@ -203,7 +203,7 @@ func (h *Headscale) getPeers(m *Machine) (Machines, error) {
|
||||||
return Machines{}, err
|
return Machines{}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
sharedTo, err := h.getSharedTo(m)
|
sharedTo, err := h.getSharedTo(machine)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().
|
log.Error().
|
||||||
Caller().
|
Caller().
|
||||||
|
@ -220,7 +220,7 @@ func (h *Headscale) getPeers(m *Machine) (Machines, error) {
|
||||||
|
|
||||||
log.Trace().
|
log.Trace().
|
||||||
Caller().
|
Caller().
|
||||||
Str("machine", m.Name).
|
Str("machine", machine.Name).
|
||||||
Msgf("Found total peers: %s", peers.String())
|
Msgf("Found total peers: %s", peers.String())
|
||||||
|
|
||||||
return peers, nil
|
return peers, nil
|
||||||
|
@ -262,9 +262,9 @@ func (h *Headscale) GetMachineByID(id uint64) (*Machine, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetMachineByMachineKey finds a Machine by ID and returns the Machine struct.
|
// GetMachineByMachineKey finds a Machine by ID and returns the Machine struct.
|
||||||
func (h *Headscale) GetMachineByMachineKey(mKey string) (*Machine, error) {
|
func (h *Headscale) GetMachineByMachineKey(machineKey string) (*Machine, error) {
|
||||||
m := Machine{}
|
m := Machine{}
|
||||||
if result := h.db.Preload("Namespace").First(&m, "machine_key = ?", mKey); result.Error != nil {
|
if result := h.db.Preload("Namespace").First(&m, "machine_key = ?", machineKey); result.Error != nil {
|
||||||
return nil, result.Error
|
return nil, result.Error
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -273,8 +273,8 @@ func (h *Headscale) GetMachineByMachineKey(mKey string) (*Machine, error) {
|
||||||
|
|
||||||
// UpdateMachine takes a Machine struct pointer (typically already loaded from database
|
// UpdateMachine takes a Machine struct pointer (typically already loaded from database
|
||||||
// and updates it with the latest data from the database.
|
// and updates it with the latest data from the database.
|
||||||
func (h *Headscale) UpdateMachine(m *Machine) error {
|
func (h *Headscale) UpdateMachine(machine *Machine) error {
|
||||||
if result := h.db.Find(m).First(&m); result.Error != nil {
|
if result := h.db.Find(machine).First(&machine); result.Error != nil {
|
||||||
return result.Error
|
return result.Error
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -282,16 +282,16 @@ func (h *Headscale) UpdateMachine(m *Machine) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeleteMachine softs deletes a Machine from the database.
|
// DeleteMachine softs deletes a Machine from the database.
|
||||||
func (h *Headscale) DeleteMachine(m *Machine) error {
|
func (h *Headscale) DeleteMachine(machine *Machine) error {
|
||||||
err := h.RemoveSharedMachineFromAllNamespaces(m)
|
err := h.RemoveSharedMachineFromAllNamespaces(machine)
|
||||||
if err != nil && err != errorMachineNotShared {
|
if err != nil && err != errorMachineNotShared {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
m.Registered = false
|
machine.Registered = false
|
||||||
namespaceID := m.NamespaceID
|
namespaceID := machine.NamespaceID
|
||||||
h.db.Save(&m) // we mark it as unregistered, just in case
|
h.db.Save(&machine) // we mark it as unregistered, just in case
|
||||||
if err := h.db.Delete(&m).Error; err != nil {
|
if err := h.db.Delete(&machine).Error; err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -299,14 +299,14 @@ func (h *Headscale) DeleteMachine(m *Machine) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
// HardDeleteMachine hard deletes a Machine from the database.
|
// HardDeleteMachine hard deletes a Machine from the database.
|
||||||
func (h *Headscale) HardDeleteMachine(m *Machine) error {
|
func (h *Headscale) HardDeleteMachine(machine *Machine) error {
|
||||||
err := h.RemoveSharedMachineFromAllNamespaces(m)
|
err := h.RemoveSharedMachineFromAllNamespaces(machine)
|
||||||
if err != nil && err != errorMachineNotShared {
|
if err != nil && err != errorMachineNotShared {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
namespaceID := m.NamespaceID
|
namespaceID := machine.NamespaceID
|
||||||
if err := h.db.Unscoped().Delete(&m).Error; err != nil {
|
if err := h.db.Unscoped().Delete(&machine).Error; err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -314,10 +314,10 @@ func (h *Headscale) HardDeleteMachine(m *Machine) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetHostInfo returns a Hostinfo struct for the machine.
|
// GetHostInfo returns a Hostinfo struct for the machine.
|
||||||
func (m *Machine) GetHostInfo() (*tailcfg.Hostinfo, error) {
|
func (machine *Machine) GetHostInfo() (*tailcfg.Hostinfo, error) {
|
||||||
hostinfo := tailcfg.Hostinfo{}
|
hostinfo := tailcfg.Hostinfo{}
|
||||||
if len(m.HostInfo) != 0 {
|
if len(machine.HostInfo) != 0 {
|
||||||
hi, err := m.HostInfo.MarshalJSON()
|
hi, err := machine.HostInfo.MarshalJSON()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -330,17 +330,17 @@ func (m *Machine) GetHostInfo() (*tailcfg.Hostinfo, error) {
|
||||||
return &hostinfo, nil
|
return &hostinfo, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Headscale) isOutdated(m *Machine) bool {
|
func (h *Headscale) isOutdated(machine *Machine) bool {
|
||||||
if err := h.UpdateMachine(m); err != nil {
|
if err := h.UpdateMachine(machine); err != nil {
|
||||||
// It does not seem meaningful to propagate this error as the end result
|
// It does not seem meaningful to propagate this error as the end result
|
||||||
// will have to be that the machine has to be considered outdated.
|
// will have to be that the machine has to be considered outdated.
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
sharedMachines, _ := h.getShared(m)
|
sharedMachines, _ := h.getShared(machine)
|
||||||
|
|
||||||
namespaceSet := set.New(set.ThreadSafe)
|
namespaceSet := set.New(set.ThreadSafe)
|
||||||
namespaceSet.Add(m.Namespace.Name)
|
namespaceSet.Add(machine.Namespace.Name)
|
||||||
|
|
||||||
// Check if any of our shared namespaces has updates that we have
|
// Check if any of our shared namespaces has updates that we have
|
||||||
// not propagated.
|
// not propagated.
|
||||||
|
@ -356,22 +356,22 @@ func (h *Headscale) isOutdated(m *Machine) bool {
|
||||||
lastChange := h.getLastStateChange(namespaces...)
|
lastChange := h.getLastStateChange(namespaces...)
|
||||||
log.Trace().
|
log.Trace().
|
||||||
Caller().
|
Caller().
|
||||||
Str("machine", m.Name).
|
Str("machine", machine.Name).
|
||||||
Time("last_successful_update", *m.LastSuccessfulUpdate).
|
Time("last_successful_update", *machine.LastSuccessfulUpdate).
|
||||||
Time("last_state_change", lastChange).
|
Time("last_state_change", lastChange).
|
||||||
Msgf("Checking if %s is missing updates", m.Name)
|
Msgf("Checking if %s is missing updates", machine.Name)
|
||||||
|
|
||||||
return m.LastSuccessfulUpdate.Before(lastChange)
|
return machine.LastSuccessfulUpdate.Before(lastChange)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m Machine) String() string {
|
func (machine Machine) String() string {
|
||||||
return m.Name
|
return machine.Name
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ms Machines) String() string {
|
func (machines Machines) String() string {
|
||||||
temp := make([]string, len(ms))
|
temp := make([]string, len(machines))
|
||||||
|
|
||||||
for index, machine := range ms {
|
for index, machine := range machines {
|
||||||
temp[index] = machine.Name
|
temp[index] = machine.Name
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -379,24 +379,24 @@ func (ms Machines) String() string {
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO(kradalby): Remove when we have generics...
|
// TODO(kradalby): Remove when we have generics...
|
||||||
func (ms MachinesP) String() string {
|
func (machines MachinesP) String() string {
|
||||||
temp := make([]string, len(ms))
|
temp := make([]string, len(machines))
|
||||||
|
|
||||||
for index, machine := range ms {
|
for index, machine := range machines {
|
||||||
temp[index] = machine.Name
|
temp[index] = machine.Name
|
||||||
}
|
}
|
||||||
|
|
||||||
return fmt.Sprintf("[ %s ](%d)", strings.Join(temp, ", "), len(temp))
|
return fmt.Sprintf("[ %s ](%d)", strings.Join(temp, ", "), len(temp))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ms Machines) toNodes(
|
func (machines Machines) toNodes(
|
||||||
baseDomain string,
|
baseDomain string,
|
||||||
dnsConfig *tailcfg.DNSConfig,
|
dnsConfig *tailcfg.DNSConfig,
|
||||||
includeRoutes bool,
|
includeRoutes bool,
|
||||||
) ([]*tailcfg.Node, error) {
|
) ([]*tailcfg.Node, error) {
|
||||||
nodes := make([]*tailcfg.Node, len(ms))
|
nodes := make([]*tailcfg.Node, len(machines))
|
||||||
|
|
||||||
for index, machine := range ms {
|
for index, machine := range machines {
|
||||||
node, err := machine.toNode(baseDomain, dnsConfig, includeRoutes)
|
node, err := machine.toNode(baseDomain, dnsConfig, includeRoutes)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -410,23 +410,24 @@ func (ms Machines) toNodes(
|
||||||
|
|
||||||
// toNode converts a Machine into a Tailscale Node. includeRoutes is false for shared nodes
|
// toNode converts a Machine into a Tailscale Node. includeRoutes is false for shared nodes
|
||||||
// as per the expected behaviour in the official SaaS.
|
// as per the expected behaviour in the official SaaS.
|
||||||
func (m Machine) toNode(
|
func (machine Machine) toNode(
|
||||||
baseDomain string,
|
baseDomain string,
|
||||||
dnsConfig *tailcfg.DNSConfig,
|
dnsConfig *tailcfg.DNSConfig,
|
||||||
includeRoutes bool,
|
includeRoutes bool,
|
||||||
) (*tailcfg.Node, error) {
|
) (*tailcfg.Node, error) {
|
||||||
nKey, err := wgkey.ParseHex(m.NodeKey)
|
nodeKey, err := wgkey.ParseHex(machine.NodeKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
mKey, err := wgkey.ParseHex(m.MachineKey)
|
|
||||||
|
machineKey, err := wgkey.ParseHex(machine.MachineKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
var discoKey tailcfg.DiscoKey
|
var discoKey tailcfg.DiscoKey
|
||||||
if m.DiscoKey != "" {
|
if machine.DiscoKey != "" {
|
||||||
dKey, err := wgkey.ParseHex(m.DiscoKey)
|
dKey, err := wgkey.ParseHex(machine.DiscoKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -436,12 +437,12 @@ func (m Machine) toNode(
|
||||||
}
|
}
|
||||||
|
|
||||||
addrs := []netaddr.IPPrefix{}
|
addrs := []netaddr.IPPrefix{}
|
||||||
ip, err := netaddr.ParseIPPrefix(fmt.Sprintf("%s/32", m.IPAddress))
|
ip, err := netaddr.ParseIPPrefix(fmt.Sprintf("%s/32", machine.IPAddress))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Trace().
|
log.Trace().
|
||||||
Caller().
|
Caller().
|
||||||
Str("ip", m.IPAddress).
|
Str("ip", machine.IPAddress).
|
||||||
Msgf("Failed to parse IP Prefix from IP: %s", m.IPAddress)
|
Msgf("Failed to parse IP Prefix from IP: %s", machine.IPAddress)
|
||||||
|
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -455,8 +456,8 @@ func (m Machine) toNode(
|
||||||
|
|
||||||
if includeRoutes {
|
if includeRoutes {
|
||||||
routesStr := []string{}
|
routesStr := []string{}
|
||||||
if len(m.EnabledRoutes) != 0 {
|
if len(machine.EnabledRoutes) != 0 {
|
||||||
allwIps, err := m.EnabledRoutes.MarshalJSON()
|
allwIps, err := machine.EnabledRoutes.MarshalJSON()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -476,8 +477,8 @@ func (m Machine) toNode(
|
||||||
}
|
}
|
||||||
|
|
||||||
endpoints := []string{}
|
endpoints := []string{}
|
||||||
if len(m.Endpoints) != 0 {
|
if len(machine.Endpoints) != 0 {
|
||||||
be, err := m.Endpoints.MarshalJSON()
|
be, err := machine.Endpoints.MarshalJSON()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -488,8 +489,8 @@ func (m Machine) toNode(
|
||||||
}
|
}
|
||||||
|
|
||||||
hostinfo := tailcfg.Hostinfo{}
|
hostinfo := tailcfg.Hostinfo{}
|
||||||
if len(m.HostInfo) != 0 {
|
if len(machine.HostInfo) != 0 {
|
||||||
hi, err := m.HostInfo.MarshalJSON()
|
hi, err := machine.HostInfo.MarshalJSON()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -507,29 +508,34 @@ func (m Machine) toNode(
|
||||||
}
|
}
|
||||||
|
|
||||||
var keyExpiry time.Time
|
var keyExpiry time.Time
|
||||||
if m.Expiry != nil {
|
if machine.Expiry != nil {
|
||||||
keyExpiry = *m.Expiry
|
keyExpiry = *machine.Expiry
|
||||||
} else {
|
} else {
|
||||||
keyExpiry = time.Time{}
|
keyExpiry = time.Time{}
|
||||||
}
|
}
|
||||||
|
|
||||||
var hostname string
|
var hostname string
|
||||||
if dnsConfig != nil && dnsConfig.Proxied { // MagicDNS
|
if dnsConfig != nil && dnsConfig.Proxied { // MagicDNS
|
||||||
hostname = fmt.Sprintf("%s.%s.%s", m.Name, m.Namespace.Name, baseDomain)
|
hostname = fmt.Sprintf(
|
||||||
|
"%s.%s.%s",
|
||||||
|
machine.Name,
|
||||||
|
machine.Namespace.Name,
|
||||||
|
baseDomain,
|
||||||
|
)
|
||||||
} else {
|
} else {
|
||||||
hostname = m.Name
|
hostname = machine.Name
|
||||||
}
|
}
|
||||||
|
|
||||||
n := tailcfg.Node{
|
n := tailcfg.Node{
|
||||||
ID: tailcfg.NodeID(m.ID), // this is the actual ID
|
ID: tailcfg.NodeID(machine.ID), // this is the actual ID
|
||||||
StableID: tailcfg.StableNodeID(
|
StableID: tailcfg.StableNodeID(
|
||||||
strconv.FormatUint(m.ID, BASE_10),
|
strconv.FormatUint(machine.ID, BASE_10),
|
||||||
), // in headscale, unlike tailcontrol server, IDs are permanent
|
), // in headscale, unlike tailcontrol server, IDs are permanent
|
||||||
Name: hostname,
|
Name: hostname,
|
||||||
User: tailcfg.UserID(m.NamespaceID),
|
User: tailcfg.UserID(machine.NamespaceID),
|
||||||
Key: tailcfg.NodeKey(nKey),
|
Key: tailcfg.NodeKey(nodeKey),
|
||||||
KeyExpiry: keyExpiry,
|
KeyExpiry: keyExpiry,
|
||||||
Machine: tailcfg.MachineKey(mKey),
|
Machine: tailcfg.MachineKey(machineKey),
|
||||||
DiscoKey: discoKey,
|
DiscoKey: discoKey,
|
||||||
Addresses: addrs,
|
Addresses: addrs,
|
||||||
AllowedIPs: allowedIPs,
|
AllowedIPs: allowedIPs,
|
||||||
|
@ -537,68 +543,73 @@ func (m Machine) toNode(
|
||||||
DERP: derp,
|
DERP: derp,
|
||||||
|
|
||||||
Hostinfo: hostinfo,
|
Hostinfo: hostinfo,
|
||||||
Created: m.CreatedAt,
|
Created: machine.CreatedAt,
|
||||||
LastSeen: m.LastSeen,
|
LastSeen: machine.LastSeen,
|
||||||
|
|
||||||
KeepAlive: true,
|
KeepAlive: true,
|
||||||
MachineAuthorized: m.Registered,
|
MachineAuthorized: machine.Registered,
|
||||||
Capabilities: []string{tailcfg.CapabilityFileSharing},
|
Capabilities: []string{tailcfg.CapabilityFileSharing},
|
||||||
}
|
}
|
||||||
|
|
||||||
return &n, nil
|
return &n, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Machine) toProto() *v1.Machine {
|
func (machine *Machine) toProto() *v1.Machine {
|
||||||
machine := &v1.Machine{
|
machineProto := &v1.Machine{
|
||||||
Id: m.ID,
|
Id: machine.ID,
|
||||||
MachineKey: m.MachineKey,
|
MachineKey: machine.MachineKey,
|
||||||
|
|
||||||
NodeKey: m.NodeKey,
|
NodeKey: machine.NodeKey,
|
||||||
DiscoKey: m.DiscoKey,
|
DiscoKey: machine.DiscoKey,
|
||||||
IpAddress: m.IPAddress,
|
IpAddress: machine.IPAddress,
|
||||||
Name: m.Name,
|
Name: machine.Name,
|
||||||
Namespace: m.Namespace.toProto(),
|
Namespace: machine.Namespace.toProto(),
|
||||||
|
|
||||||
Registered: m.Registered,
|
Registered: machine.Registered,
|
||||||
|
|
||||||
// TODO(kradalby): Implement register method enum converter
|
// TODO(kradalby): Implement register method enum converter
|
||||||
// RegisterMethod: ,
|
// RegisterMethod: ,
|
||||||
|
|
||||||
CreatedAt: timestamppb.New(m.CreatedAt),
|
CreatedAt: timestamppb.New(machine.CreatedAt),
|
||||||
}
|
}
|
||||||
|
|
||||||
if m.AuthKey != nil {
|
if machine.AuthKey != nil {
|
||||||
machine.PreAuthKey = m.AuthKey.toProto()
|
machineProto.PreAuthKey = machine.AuthKey.toProto()
|
||||||
}
|
}
|
||||||
|
|
||||||
if m.LastSeen != nil {
|
if machine.LastSeen != nil {
|
||||||
machine.LastSeen = timestamppb.New(*m.LastSeen)
|
machineProto.LastSeen = timestamppb.New(*machine.LastSeen)
|
||||||
}
|
}
|
||||||
|
|
||||||
if m.LastSuccessfulUpdate != nil {
|
if machine.LastSuccessfulUpdate != nil {
|
||||||
machine.LastSuccessfulUpdate = timestamppb.New(*m.LastSuccessfulUpdate)
|
machineProto.LastSuccessfulUpdate = timestamppb.New(
|
||||||
|
*machine.LastSuccessfulUpdate,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
if m.Expiry != nil {
|
if machine.Expiry != nil {
|
||||||
machine.Expiry = timestamppb.New(*m.Expiry)
|
machineProto.Expiry = timestamppb.New(*machine.Expiry)
|
||||||
}
|
}
|
||||||
|
|
||||||
return machine
|
return machineProto
|
||||||
}
|
}
|
||||||
|
|
||||||
// RegisterMachine is executed from the CLI to register a new Machine using its MachineKey.
|
// RegisterMachine is executed from the CLI to register a new Machine using its MachineKey.
|
||||||
func (h *Headscale) RegisterMachine(key string, namespace string) (*Machine, error) {
|
func (h *Headscale) RegisterMachine(
|
||||||
ns, err := h.GetNamespace(namespace)
|
key string,
|
||||||
|
namespaceName string,
|
||||||
|
) (*Machine, error) {
|
||||||
|
namespace, err := h.GetNamespace(namespaceName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
mKey, err := wgkey.ParseHex(key)
|
machineKey, err := wgkey.ParseHex(key)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
m := Machine{}
|
machine := Machine{}
|
||||||
if result := h.db.First(&m, "machine_key = ?", mKey.HexString()); errors.Is(
|
if result := h.db.First(&machine, "machine_key = ?", machineKey.HexString()); errors.Is(
|
||||||
result.Error,
|
result.Error,
|
||||||
gorm.ErrRecordNotFound,
|
gorm.ErrRecordNotFound,
|
||||||
) {
|
) {
|
||||||
|
@ -607,15 +618,15 @@ func (h *Headscale) RegisterMachine(key string, namespace string) (*Machine, err
|
||||||
|
|
||||||
log.Trace().
|
log.Trace().
|
||||||
Caller().
|
Caller().
|
||||||
Str("machine", m.Name).
|
Str("machine", machine.Name).
|
||||||
Msg("Attempting to register machine")
|
Msg("Attempting to register machine")
|
||||||
|
|
||||||
if m.isAlreadyRegistered() {
|
if machine.isAlreadyRegistered() {
|
||||||
err := errors.New("Machine already registered")
|
err := errors.New("Machine already registered")
|
||||||
log.Error().
|
log.Error().
|
||||||
Caller().
|
Caller().
|
||||||
Err(err).
|
Err(err).
|
||||||
Str("machine", m.Name).
|
Str("machine", machine.Name).
|
||||||
Msg("Attempting to register machine")
|
Msg("Attempting to register machine")
|
||||||
|
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -626,7 +637,7 @@ func (h *Headscale) RegisterMachine(key string, namespace string) (*Machine, err
|
||||||
log.Error().
|
log.Error().
|
||||||
Caller().
|
Caller().
|
||||||
Err(err).
|
Err(err).
|
||||||
Str("machine", m.Name).
|
Str("machine", machine.Name).
|
||||||
Msg("Could not find IP for the new machine")
|
Msg("Could not find IP for the new machine")
|
||||||
|
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -634,27 +645,27 @@ func (h *Headscale) RegisterMachine(key string, namespace string) (*Machine, err
|
||||||
|
|
||||||
log.Trace().
|
log.Trace().
|
||||||
Caller().
|
Caller().
|
||||||
Str("machine", m.Name).
|
Str("machine", machine.Name).
|
||||||
Str("ip", ip.String()).
|
Str("ip", ip.String()).
|
||||||
Msg("Found IP for host")
|
Msg("Found IP for host")
|
||||||
|
|
||||||
m.IPAddress = ip.String()
|
machine.IPAddress = ip.String()
|
||||||
m.NamespaceID = ns.ID
|
machine.NamespaceID = namespace.ID
|
||||||
m.Registered = true
|
machine.Registered = true
|
||||||
m.RegisterMethod = "cli"
|
machine.RegisterMethod = "cli"
|
||||||
h.db.Save(&m)
|
h.db.Save(&machine)
|
||||||
|
|
||||||
log.Trace().
|
log.Trace().
|
||||||
Caller().
|
Caller().
|
||||||
Str("machine", m.Name).
|
Str("machine", machine.Name).
|
||||||
Str("ip", ip.String()).
|
Str("ip", ip.String()).
|
||||||
Msg("Machine registered with the database")
|
Msg("Machine registered with the database")
|
||||||
|
|
||||||
return &m, nil
|
return &machine, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Machine) GetAdvertisedRoutes() ([]netaddr.IPPrefix, error) {
|
func (machine *Machine) GetAdvertisedRoutes() ([]netaddr.IPPrefix, error) {
|
||||||
hostInfo, err := m.GetHostInfo()
|
hostInfo, err := machine.GetHostInfo()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -662,8 +673,8 @@ func (m *Machine) GetAdvertisedRoutes() ([]netaddr.IPPrefix, error) {
|
||||||
return hostInfo.RoutableIPs, nil
|
return hostInfo.RoutableIPs, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Machine) GetEnabledRoutes() ([]netaddr.IPPrefix, error) {
|
func (machine *Machine) GetEnabledRoutes() ([]netaddr.IPPrefix, error) {
|
||||||
data, err := m.EnabledRoutes.MarshalJSON()
|
data, err := machine.EnabledRoutes.MarshalJSON()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -686,13 +697,13 @@ func (m *Machine) GetEnabledRoutes() ([]netaddr.IPPrefix, error) {
|
||||||
return routes, nil
|
return routes, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Machine) IsRoutesEnabled(routeStr string) bool {
|
func (machine *Machine) IsRoutesEnabled(routeStr string) bool {
|
||||||
route, err := netaddr.ParseIPPrefix(routeStr)
|
route, err := netaddr.ParseIPPrefix(routeStr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
enabledRoutes, err := m.GetEnabledRoutes()
|
enabledRoutes, err := machine.GetEnabledRoutes()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
@ -708,7 +719,7 @@ func (m *Machine) IsRoutesEnabled(routeStr string) bool {
|
||||||
|
|
||||||
// EnableNodeRoute enables new routes based on a list of new routes. It will _replace_ the
|
// EnableNodeRoute enables new routes based on a list of new routes. It will _replace_ the
|
||||||
// previous list of routes.
|
// previous list of routes.
|
||||||
func (h *Headscale) EnableRoutes(m *Machine, routeStrs ...string) error {
|
func (h *Headscale) EnableRoutes(machine *Machine, routeStrs ...string) error {
|
||||||
newRoutes := make([]netaddr.IPPrefix, len(routeStrs))
|
newRoutes := make([]netaddr.IPPrefix, len(routeStrs))
|
||||||
for index, routeStr := range routeStrs {
|
for index, routeStr := range routeStrs {
|
||||||
route, err := netaddr.ParseIPPrefix(routeStr)
|
route, err := netaddr.ParseIPPrefix(routeStr)
|
||||||
|
@ -719,7 +730,7 @@ func (h *Headscale) EnableRoutes(m *Machine, routeStrs ...string) error {
|
||||||
newRoutes[index] = route
|
newRoutes[index] = route
|
||||||
}
|
}
|
||||||
|
|
||||||
availableRoutes, err := m.GetAdvertisedRoutes()
|
availableRoutes, err := machine.GetAdvertisedRoutes()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -728,7 +739,7 @@ func (h *Headscale) EnableRoutes(m *Machine, routeStrs ...string) error {
|
||||||
if !containsIpPrefix(availableRoutes, newRoute) {
|
if !containsIpPrefix(availableRoutes, newRoute) {
|
||||||
return fmt.Errorf(
|
return fmt.Errorf(
|
||||||
"route (%s) is not available on node %s",
|
"route (%s) is not available on node %s",
|
||||||
m.Name,
|
machine.Name,
|
||||||
newRoute,
|
newRoute,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
@ -739,10 +750,10 @@ func (h *Headscale) EnableRoutes(m *Machine, routeStrs ...string) error {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
m.EnabledRoutes = datatypes.JSON(routes)
|
machine.EnabledRoutes = datatypes.JSON(routes)
|
||||||
h.db.Save(&m)
|
h.db.Save(&machine)
|
||||||
|
|
||||||
err = h.RequestMapUpdates(m.NamespaceID)
|
err = h.RequestMapUpdates(machine.NamespaceID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -750,13 +761,13 @@ func (h *Headscale) EnableRoutes(m *Machine, routeStrs ...string) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Machine) RoutesToProto() (*v1.Routes, error) {
|
func (machine *Machine) RoutesToProto() (*v1.Routes, error) {
|
||||||
availableRoutes, err := m.GetAdvertisedRoutes()
|
availableRoutes, err := machine.GetAdvertisedRoutes()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
enabledRoutes, err := m.GetEnabledRoutes()
|
enabledRoutes, err := machine.GetEnabledRoutes()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
|
@ -32,12 +32,12 @@ type Namespace struct {
|
||||||
// CreateNamespace creates a new Namespace. Returns error if could not be created
|
// CreateNamespace creates a new Namespace. Returns error if could not be created
|
||||||
// or another namespace already exists.
|
// or another namespace already exists.
|
||||||
func (h *Headscale) CreateNamespace(name string) (*Namespace, error) {
|
func (h *Headscale) CreateNamespace(name string) (*Namespace, error) {
|
||||||
n := Namespace{}
|
namespace := Namespace{}
|
||||||
if err := h.db.Where("name = ?", name).First(&n).Error; err == nil {
|
if err := h.db.Where("name = ?", name).First(&namespace).Error; err == nil {
|
||||||
return nil, errorNamespaceExists
|
return nil, errorNamespaceExists
|
||||||
}
|
}
|
||||||
n.Name = name
|
namespace.Name = name
|
||||||
if err := h.db.Create(&n).Error; err != nil {
|
if err := h.db.Create(&namespace).Error; err != nil {
|
||||||
log.Error().
|
log.Error().
|
||||||
Str("func", "CreateNamespace").
|
Str("func", "CreateNamespace").
|
||||||
Err(err).
|
Err(err).
|
||||||
|
@ -46,22 +46,22 @@ func (h *Headscale) CreateNamespace(name string) (*Namespace, error) {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return &n, nil
|
return &namespace, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// DestroyNamespace destroys a Namespace. Returns error if the Namespace does
|
// DestroyNamespace destroys a Namespace. Returns error if the Namespace does
|
||||||
// not exist or if there are machines associated with it.
|
// not exist or if there are machines associated with it.
|
||||||
func (h *Headscale) DestroyNamespace(name string) error {
|
func (h *Headscale) DestroyNamespace(name string) error {
|
||||||
n, err := h.GetNamespace(name)
|
namespace, err := h.GetNamespace(name)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errorNamespaceNotFound
|
return errorNamespaceNotFound
|
||||||
}
|
}
|
||||||
|
|
||||||
m, err := h.ListMachinesInNamespace(name)
|
machines, err := h.ListMachinesInNamespace(name)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if len(m) > 0 {
|
if len(machines) > 0 {
|
||||||
return errorNamespaceNotEmptyOfNodes
|
return errorNamespaceNotEmptyOfNodes
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -69,14 +69,14 @@ func (h *Headscale) DestroyNamespace(name string) error {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
for _, p := range keys {
|
for _, key := range keys {
|
||||||
err = h.DestroyPreAuthKey(&p)
|
err = h.DestroyPreAuthKey(&key)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if result := h.db.Unscoped().Delete(&n); result.Error != nil {
|
if result := h.db.Unscoped().Delete(&namespace); result.Error != nil {
|
||||||
return result.Error
|
return result.Error
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -86,7 +86,7 @@ func (h *Headscale) DestroyNamespace(name string) error {
|
||||||
// RenameNamespace renames a Namespace. Returns error if the Namespace does
|
// RenameNamespace renames a Namespace. Returns error if the Namespace does
|
||||||
// not exist or if another Namespace exists with the new name.
|
// not exist or if another Namespace exists with the new name.
|
||||||
func (h *Headscale) RenameNamespace(oldName, newName string) error {
|
func (h *Headscale) RenameNamespace(oldName, newName string) error {
|
||||||
n, err := h.GetNamespace(oldName)
|
oldNamespace, err := h.GetNamespace(oldName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -98,13 +98,13 @@ func (h *Headscale) RenameNamespace(oldName, newName string) error {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
n.Name = newName
|
oldNamespace.Name = newName
|
||||||
|
|
||||||
if result := h.db.Save(&n); result.Error != nil {
|
if result := h.db.Save(&oldNamespace); result.Error != nil {
|
||||||
return result.Error
|
return result.Error
|
||||||
}
|
}
|
||||||
|
|
||||||
err = h.RequestMapUpdates(n.ID)
|
err = h.RequestMapUpdates(oldNamespace.ID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -114,15 +114,15 @@ func (h *Headscale) RenameNamespace(oldName, newName string) error {
|
||||||
|
|
||||||
// GetNamespace fetches a namespace by name.
|
// GetNamespace fetches a namespace by name.
|
||||||
func (h *Headscale) GetNamespace(name string) (*Namespace, error) {
|
func (h *Headscale) GetNamespace(name string) (*Namespace, error) {
|
||||||
n := Namespace{}
|
namespace := Namespace{}
|
||||||
if result := h.db.First(&n, "name = ?", name); errors.Is(
|
if result := h.db.First(&namespace, "name = ?", name); errors.Is(
|
||||||
result.Error,
|
result.Error,
|
||||||
gorm.ErrRecordNotFound,
|
gorm.ErrRecordNotFound,
|
||||||
) {
|
) {
|
||||||
return nil, errorNamespaceNotFound
|
return nil, errorNamespaceNotFound
|
||||||
}
|
}
|
||||||
|
|
||||||
return &n, nil
|
return &namespace, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// ListNamespaces gets all the existing namespaces.
|
// ListNamespaces gets all the existing namespaces.
|
||||||
|
@ -137,13 +137,13 @@ func (h *Headscale) ListNamespaces() ([]Namespace, error) {
|
||||||
|
|
||||||
// ListMachinesInNamespace gets all the nodes in a given namespace.
|
// ListMachinesInNamespace gets all the nodes in a given namespace.
|
||||||
func (h *Headscale) ListMachinesInNamespace(name string) ([]Machine, error) {
|
func (h *Headscale) ListMachinesInNamespace(name string) ([]Machine, error) {
|
||||||
n, err := h.GetNamespace(name)
|
namespace, err := h.GetNamespace(name)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
machines := []Machine{}
|
machines := []Machine{}
|
||||||
if err := h.db.Preload("AuthKey").Preload("AuthKey.Namespace").Preload("Namespace").Where(&Machine{NamespaceID: n.ID}).Find(&machines).Error; err != nil {
|
if err := h.db.Preload("AuthKey").Preload("AuthKey.Namespace").Preload("Namespace").Where(&Machine{NamespaceID: namespace.ID}).Find(&machines).Error; err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -176,17 +176,18 @@ func (h *Headscale) ListSharedMachinesInNamespace(name string) ([]Machine, error
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetMachineNamespace assigns a Machine to a namespace.
|
// SetMachineNamespace assigns a Machine to a namespace.
|
||||||
func (h *Headscale) SetMachineNamespace(m *Machine, namespaceName string) error {
|
func (h *Headscale) SetMachineNamespace(machine *Machine, namespaceName string) error {
|
||||||
n, err := h.GetNamespace(namespaceName)
|
namespace, err := h.GetNamespace(namespaceName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
m.NamespaceID = n.ID
|
machine.NamespaceID = namespace.ID
|
||||||
h.db.Save(&m)
|
h.db.Save(&machine)
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TODO(kradalby): Remove the need for this.
|
||||||
// RequestMapUpdates signals the KV worker to update the maps for this namespace.
|
// RequestMapUpdates signals the KV worker to update the maps for this namespace.
|
||||||
func (h *Headscale) RequestMapUpdates(namespaceID uint) error {
|
func (h *Headscale) RequestMapUpdates(namespaceID uint) error {
|
||||||
namespace := Namespace{}
|
namespace := Namespace{}
|
||||||
|
@ -194,8 +195,8 @@ func (h *Headscale) RequestMapUpdates(namespaceID uint) error {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
v, err := h.getValue("namespaces_pending_updates")
|
namespacesPendingUpdates, err := h.getValue("namespaces_pending_updates")
|
||||||
if err != nil || v == "" {
|
if err != nil || namespacesPendingUpdates == "" {
|
||||||
err = h.setValue(
|
err = h.setValue(
|
||||||
"namespaces_pending_updates",
|
"namespaces_pending_updates",
|
||||||
fmt.Sprintf(`["%s"]`, namespace.Name),
|
fmt.Sprintf(`["%s"]`, namespace.Name),
|
||||||
|
@ -207,7 +208,7 @@ func (h *Headscale) RequestMapUpdates(namespaceID uint) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
names := []string{}
|
names := []string{}
|
||||||
err = json.Unmarshal([]byte(v), &names)
|
err = json.Unmarshal([]byte(namespacesPendingUpdates), &names)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
err = h.setValue(
|
err = h.setValue(
|
||||||
"namespaces_pending_updates",
|
"namespaces_pending_updates",
|
||||||
|
@ -235,16 +236,16 @@ func (h *Headscale) RequestMapUpdates(namespaceID uint) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Headscale) checkForNamespacesPendingUpdates() {
|
func (h *Headscale) checkForNamespacesPendingUpdates() {
|
||||||
v, err := h.getValue("namespaces_pending_updates")
|
namespacesPendingUpdates, err := h.getValue("namespaces_pending_updates")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if v == "" {
|
if namespacesPendingUpdates == "" {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
namespaces := []string{}
|
namespaces := []string{}
|
||||||
err = json.Unmarshal([]byte(v), &namespaces)
|
err = json.Unmarshal([]byte(namespacesPendingUpdates), &namespaces)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -255,11 +256,11 @@ func (h *Headscale) checkForNamespacesPendingUpdates() {
|
||||||
Msg("Sending updates to nodes in namespacespace")
|
Msg("Sending updates to nodes in namespacespace")
|
||||||
h.setLastStateChangeToNow(namespace)
|
h.setLastStateChangeToNow(namespace)
|
||||||
}
|
}
|
||||||
newV, err := h.getValue("namespaces_pending_updates")
|
newPendingUpdateValue, err := h.getValue("namespaces_pending_updates")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if v == newV { // only clear when no changes, so we notified everybody
|
if namespacesPendingUpdates == newPendingUpdateValue { // only clear when no changes, so we notified everybody
|
||||||
err = h.setValue("namespaces_pending_updates", "")
|
err = h.setValue("namespaces_pending_updates", "")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().
|
log.Error().
|
||||||
|
@ -273,7 +274,7 @@ func (h *Headscale) checkForNamespacesPendingUpdates() {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (n *Namespace) toUser() *tailcfg.User {
|
func (n *Namespace) toUser() *tailcfg.User {
|
||||||
u := tailcfg.User{
|
user := tailcfg.User{
|
||||||
ID: tailcfg.UserID(n.ID),
|
ID: tailcfg.UserID(n.ID),
|
||||||
LoginName: n.Name,
|
LoginName: n.Name,
|
||||||
DisplayName: n.Name,
|
DisplayName: n.Name,
|
||||||
|
@ -283,11 +284,11 @@ func (n *Namespace) toUser() *tailcfg.User {
|
||||||
Created: time.Time{},
|
Created: time.Time{},
|
||||||
}
|
}
|
||||||
|
|
||||||
return &u
|
return &user
|
||||||
}
|
}
|
||||||
|
|
||||||
func (n *Namespace) toLogin() *tailcfg.Login {
|
func (n *Namespace) toLogin() *tailcfg.Login {
|
||||||
l := tailcfg.Login{
|
login := tailcfg.Login{
|
||||||
ID: tailcfg.LoginID(n.ID),
|
ID: tailcfg.LoginID(n.ID),
|
||||||
LoginName: n.Name,
|
LoginName: n.Name,
|
||||||
DisplayName: n.Name,
|
DisplayName: n.Name,
|
||||||
|
@ -295,14 +296,14 @@ func (n *Namespace) toLogin() *tailcfg.Login {
|
||||||
Domain: "headscale.net",
|
Domain: "headscale.net",
|
||||||
}
|
}
|
||||||
|
|
||||||
return &l
|
return &login
|
||||||
}
|
}
|
||||||
|
|
||||||
func getMapResponseUserProfiles(m Machine, peers Machines) []tailcfg.UserProfile {
|
func getMapResponseUserProfiles(machine Machine, peers Machines) []tailcfg.UserProfile {
|
||||||
namespaceMap := make(map[string]Namespace)
|
namespaceMap := make(map[string]Namespace)
|
||||||
namespaceMap[m.Namespace.Name] = m.Namespace
|
namespaceMap[machine.Namespace.Name] = machine.Namespace
|
||||||
for _, p := range peers {
|
for _, peer := range peers {
|
||||||
namespaceMap[p.Namespace.Name] = p.Namespace // not worth checking if already is there
|
namespaceMap[peer.Namespace.Name] = peer.Namespace // not worth checking if already is there
|
||||||
}
|
}
|
||||||
|
|
||||||
profiles := []tailcfg.UserProfile{}
|
profiles := []tailcfg.UserProfile{}
|
||||||
|
|
69
oidc.go
69
oidc.go
|
@ -68,10 +68,10 @@ func (h *Headscale) initOIDC() error {
|
||||||
// RegisterOIDC redirects to the OIDC provider for authentication
|
// RegisterOIDC redirects to the OIDC provider for authentication
|
||||||
// Puts machine key in cache so the callback can retrieve it using the oidc state param
|
// Puts machine key in cache so the callback can retrieve it using the oidc state param
|
||||||
// Listens in /oidc/register/:mKey.
|
// Listens in /oidc/register/:mKey.
|
||||||
func (h *Headscale) RegisterOIDC(c *gin.Context) {
|
func (h *Headscale) RegisterOIDC(ctx *gin.Context) {
|
||||||
mKeyStr := c.Param("mkey")
|
mKeyStr := ctx.Param("mkey")
|
||||||
if mKeyStr == "" {
|
if mKeyStr == "" {
|
||||||
c.String(http.StatusBadRequest, "Wrong params")
|
ctx.String(http.StatusBadRequest, "Wrong params")
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -79,7 +79,7 @@ func (h *Headscale) RegisterOIDC(c *gin.Context) {
|
||||||
b := make([]byte, RANDOM_BYTE_SIZE)
|
b := make([]byte, RANDOM_BYTE_SIZE)
|
||||||
if _, err := rand.Read(b); err != nil {
|
if _, err := rand.Read(b); err != nil {
|
||||||
log.Error().Msg("could not read 16 bytes from rand")
|
log.Error().Msg("could not read 16 bytes from rand")
|
||||||
c.String(http.StatusInternalServerError, "could not read 16 bytes from rand")
|
ctx.String(http.StatusInternalServerError, "could not read 16 bytes from rand")
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -92,7 +92,7 @@ func (h *Headscale) RegisterOIDC(c *gin.Context) {
|
||||||
authUrl := h.oauth2Config.AuthCodeURL(stateStr)
|
authUrl := h.oauth2Config.AuthCodeURL(stateStr)
|
||||||
log.Debug().Msgf("Redirecting to %s for authentication", authUrl)
|
log.Debug().Msgf("Redirecting to %s for authentication", authUrl)
|
||||||
|
|
||||||
c.Redirect(http.StatusFound, authUrl)
|
ctx.Redirect(http.StatusFound, authUrl)
|
||||||
}
|
}
|
||||||
|
|
||||||
// OIDCCallback handles the callback from the OIDC endpoint
|
// OIDCCallback handles the callback from the OIDC endpoint
|
||||||
|
@ -100,19 +100,19 @@ func (h *Headscale) RegisterOIDC(c *gin.Context) {
|
||||||
// TODO: A confirmation page for new machines should be added to avoid phishing vulnerabilities
|
// TODO: A confirmation page for new machines should be added to avoid phishing vulnerabilities
|
||||||
// TODO: Add groups information from OIDC tokens into machine HostInfo
|
// TODO: Add groups information from OIDC tokens into machine HostInfo
|
||||||
// Listens in /oidc/callback.
|
// Listens in /oidc/callback.
|
||||||
func (h *Headscale) OIDCCallback(c *gin.Context) {
|
func (h *Headscale) OIDCCallback(ctx *gin.Context) {
|
||||||
code := c.Query("code")
|
code := ctx.Query("code")
|
||||||
state := c.Query("state")
|
state := ctx.Query("state")
|
||||||
|
|
||||||
if code == "" || state == "" {
|
if code == "" || state == "" {
|
||||||
c.String(http.StatusBadRequest, "Wrong params")
|
ctx.String(http.StatusBadRequest, "Wrong params")
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
oauth2Token, err := h.oauth2Config.Exchange(context.Background(), code)
|
oauth2Token, err := h.oauth2Config.Exchange(context.Background(), code)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.String(http.StatusBadRequest, "Could not exchange code for token")
|
ctx.String(http.StatusBadRequest, "Could not exchange code for token")
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -121,7 +121,7 @@ func (h *Headscale) OIDCCallback(c *gin.Context) {
|
||||||
|
|
||||||
rawIDToken, rawIDTokenOK := oauth2Token.Extra("id_token").(string)
|
rawIDToken, rawIDTokenOK := oauth2Token.Extra("id_token").(string)
|
||||||
if !rawIDTokenOK {
|
if !rawIDTokenOK {
|
||||||
c.String(http.StatusBadRequest, "Could not extract ID Token")
|
ctx.String(http.StatusBadRequest, "Could not extract ID Token")
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -130,7 +130,7 @@ func (h *Headscale) OIDCCallback(c *gin.Context) {
|
||||||
|
|
||||||
idToken, err := verifier.Verify(context.Background(), rawIDToken)
|
idToken, err := verifier.Verify(context.Background(), rawIDToken)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.String(http.StatusBadRequest, "Failed to verify id token: %s", err.Error())
|
ctx.String(http.StatusBadRequest, "Failed to verify id token: %s", err.Error())
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -145,7 +145,7 @@ func (h *Headscale) OIDCCallback(c *gin.Context) {
|
||||||
// Extract custom claims
|
// Extract custom claims
|
||||||
var claims IDTokenClaims
|
var claims IDTokenClaims
|
||||||
if err = idToken.Claims(&claims); err != nil {
|
if err = idToken.Claims(&claims); err != nil {
|
||||||
c.String(
|
ctx.String(
|
||||||
http.StatusBadRequest,
|
http.StatusBadRequest,
|
||||||
fmt.Sprintf("Failed to decode id token claims: %s", err),
|
fmt.Sprintf("Failed to decode id token claims: %s", err),
|
||||||
)
|
)
|
||||||
|
@ -159,7 +159,7 @@ func (h *Headscale) OIDCCallback(c *gin.Context) {
|
||||||
if !mKeyFound {
|
if !mKeyFound {
|
||||||
log.Error().
|
log.Error().
|
||||||
Msg("requested machine state key expired before authorisation completed")
|
Msg("requested machine state key expired before authorisation completed")
|
||||||
c.String(http.StatusBadRequest, "state has expired")
|
ctx.String(http.StatusBadRequest, "state has expired")
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -167,16 +167,19 @@ func (h *Headscale) OIDCCallback(c *gin.Context) {
|
||||||
|
|
||||||
if !mKeyOK {
|
if !mKeyOK {
|
||||||
log.Error().Msg("could not get machine key from cache")
|
log.Error().Msg("could not get machine key from cache")
|
||||||
c.String(http.StatusInternalServerError, "could not get machine key from cache")
|
ctx.String(
|
||||||
|
http.StatusInternalServerError,
|
||||||
|
"could not get machine key from cache",
|
||||||
|
)
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// retrieve machine information
|
// retrieve machine information
|
||||||
m, err := h.GetMachineByMachineKey(mKeyStr)
|
machine, err := h.GetMachineByMachineKey(mKeyStr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().Msg("machine key not found in database")
|
log.Error().Msg("machine key not found in database")
|
||||||
c.String(
|
ctx.String(
|
||||||
http.StatusInternalServerError,
|
http.StatusInternalServerError,
|
||||||
"could not get machine info from database",
|
"could not get machine info from database",
|
||||||
)
|
)
|
||||||
|
@ -186,19 +189,19 @@ func (h *Headscale) OIDCCallback(c *gin.Context) {
|
||||||
|
|
||||||
now := time.Now().UTC()
|
now := time.Now().UTC()
|
||||||
|
|
||||||
if nsName, ok := h.getNamespaceFromEmail(claims.Email); ok {
|
if namespaceName, ok := h.getNamespaceFromEmail(claims.Email); ok {
|
||||||
// register the machine if it's new
|
// register the machine if it's new
|
||||||
if !m.Registered {
|
if !machine.Registered {
|
||||||
log.Debug().Msg("Registering new machine after successful callback")
|
log.Debug().Msg("Registering new machine after successful callback")
|
||||||
|
|
||||||
ns, err := h.GetNamespace(nsName)
|
namespace, err := h.GetNamespace(namespaceName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
ns, err = h.CreateNamespace(nsName)
|
namespace, err = h.CreateNamespace(namespaceName)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().
|
log.Error().
|
||||||
Msgf("could not create new namespace '%s'", claims.Email)
|
Msgf("could not create new namespace '%s'", claims.Email)
|
||||||
c.String(
|
ctx.String(
|
||||||
http.StatusInternalServerError,
|
http.StatusInternalServerError,
|
||||||
"could not create new namespace",
|
"could not create new namespace",
|
||||||
)
|
)
|
||||||
|
@ -209,7 +212,7 @@ func (h *Headscale) OIDCCallback(c *gin.Context) {
|
||||||
|
|
||||||
ip, err := h.getAvailableIP()
|
ip, err := h.getAvailableIP()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.String(
|
ctx.String(
|
||||||
http.StatusInternalServerError,
|
http.StatusInternalServerError,
|
||||||
"could not get an IP from the pool",
|
"could not get an IP from the pool",
|
||||||
)
|
)
|
||||||
|
@ -217,17 +220,17 @@ func (h *Headscale) OIDCCallback(c *gin.Context) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
m.IPAddress = ip.String()
|
machine.IPAddress = ip.String()
|
||||||
m.NamespaceID = ns.ID
|
machine.NamespaceID = namespace.ID
|
||||||
m.Registered = true
|
machine.Registered = true
|
||||||
m.RegisterMethod = "oidc"
|
machine.RegisterMethod = "oidc"
|
||||||
m.LastSuccessfulUpdate = &now
|
machine.LastSuccessfulUpdate = &now
|
||||||
h.db.Save(&m)
|
h.db.Save(&machine)
|
||||||
}
|
}
|
||||||
|
|
||||||
h.updateMachineExpiry(m)
|
h.updateMachineExpiry(machine)
|
||||||
|
|
||||||
c.Data(http.StatusOK, "text/html; charset=utf-8", []byte(fmt.Sprintf(`
|
ctx.Data(http.StatusOK, "text/html; charset=utf-8", []byte(fmt.Sprintf(`
|
||||||
<html>
|
<html>
|
||||||
<body>
|
<body>
|
||||||
<h1>headscale</h1>
|
<h1>headscale</h1>
|
||||||
|
@ -243,9 +246,9 @@ func (h *Headscale) OIDCCallback(c *gin.Context) {
|
||||||
log.Error().
|
log.Error().
|
||||||
Str("email", claims.Email).
|
Str("email", claims.Email).
|
||||||
Str("username", claims.Username).
|
Str("username", claims.Username).
|
||||||
Str("machine", m.Name).
|
Str("machine", machine.Name).
|
||||||
Msg("Email could not be mapped to a namespace")
|
Msg("Email could not be mapped to a namespace")
|
||||||
c.String(
|
ctx.String(
|
||||||
http.StatusBadRequest,
|
http.StatusBadRequest,
|
||||||
"email from claim could not be mapped to a namespace",
|
"email from claim could not be mapped to a namespace",
|
||||||
)
|
)
|
||||||
|
|
8
poll.go
8
poll.go
|
@ -233,7 +233,7 @@ func (h *Headscale) PollNetMapStream(
|
||||||
) {
|
) {
|
||||||
go h.scheduledPollWorker(cancelKeepAlive, updateChan, keepAliveChan, mKey, req, m)
|
go h.scheduledPollWorker(cancelKeepAlive, updateChan, keepAliveChan, mKey, req, m)
|
||||||
|
|
||||||
c.Stream(func(w io.Writer) bool {
|
c.Stream(func(writer io.Writer) bool {
|
||||||
log.Trace().
|
log.Trace().
|
||||||
Str("handler", "PollNetMapStream").
|
Str("handler", "PollNetMapStream").
|
||||||
Str("machine", m.Name).
|
Str("machine", m.Name).
|
||||||
|
@ -252,7 +252,7 @@ func (h *Headscale) PollNetMapStream(
|
||||||
Str("channel", "pollData").
|
Str("channel", "pollData").
|
||||||
Int("bytes", len(data)).
|
Int("bytes", len(data)).
|
||||||
Msg("Sending data received via pollData channel")
|
Msg("Sending data received via pollData channel")
|
||||||
_, err := w.Write(data)
|
_, err := writer.Write(data)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().
|
log.Error().
|
||||||
Str("handler", "PollNetMapStream").
|
Str("handler", "PollNetMapStream").
|
||||||
|
@ -305,7 +305,7 @@ func (h *Headscale) PollNetMapStream(
|
||||||
Str("channel", "keepAlive").
|
Str("channel", "keepAlive").
|
||||||
Int("bytes", len(data)).
|
Int("bytes", len(data)).
|
||||||
Msg("Sending keep alive message")
|
Msg("Sending keep alive message")
|
||||||
_, err := w.Write(data)
|
_, err := writer.Write(data)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().
|
log.Error().
|
||||||
Str("handler", "PollNetMapStream").
|
Str("handler", "PollNetMapStream").
|
||||||
|
@ -370,7 +370,7 @@ func (h *Headscale) PollNetMapStream(
|
||||||
Err(err).
|
Err(err).
|
||||||
Msg("Could not get the map update")
|
Msg("Could not get the map update")
|
||||||
}
|
}
|
||||||
_, err = w.Write(data)
|
_, err = writer.Write(data)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().
|
log.Error().
|
||||||
Str("handler", "PollNetMapStream").
|
Str("handler", "PollNetMapStream").
|
||||||
|
|
32
sharing.go
32
sharing.go
|
@ -18,13 +18,16 @@ type SharedMachine struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
// AddSharedMachineToNamespace adds a machine as a shared node to a namespace.
|
// AddSharedMachineToNamespace adds a machine as a shared node to a namespace.
|
||||||
func (h *Headscale) AddSharedMachineToNamespace(m *Machine, ns *Namespace) error {
|
func (h *Headscale) AddSharedMachineToNamespace(
|
||||||
if m.NamespaceID == ns.ID {
|
machine *Machine,
|
||||||
|
namespace *Namespace,
|
||||||
|
) error {
|
||||||
|
if machine.NamespaceID == namespace.ID {
|
||||||
return errorSameNamespace
|
return errorSameNamespace
|
||||||
}
|
}
|
||||||
|
|
||||||
sharedMachines := []SharedMachine{}
|
sharedMachines := []SharedMachine{}
|
||||||
if err := h.db.Where("machine_id = ? AND namespace_id = ?", m.ID, ns.ID).Find(&sharedMachines).Error; err != nil {
|
if err := h.db.Where("machine_id = ? AND namespace_id = ?", machine.ID, namespace.ID).Find(&sharedMachines).Error; err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if len(sharedMachines) > 0 {
|
if len(sharedMachines) > 0 {
|
||||||
|
@ -32,10 +35,10 @@ func (h *Headscale) AddSharedMachineToNamespace(m *Machine, ns *Namespace) error
|
||||||
}
|
}
|
||||||
|
|
||||||
sharedMachine := SharedMachine{
|
sharedMachine := SharedMachine{
|
||||||
MachineID: m.ID,
|
MachineID: machine.ID,
|
||||||
Machine: *m,
|
Machine: *machine,
|
||||||
NamespaceID: ns.ID,
|
NamespaceID: namespace.ID,
|
||||||
Namespace: *ns,
|
Namespace: *namespace,
|
||||||
}
|
}
|
||||||
h.db.Save(&sharedMachine)
|
h.db.Save(&sharedMachine)
|
||||||
|
|
||||||
|
@ -43,14 +46,17 @@ func (h *Headscale) AddSharedMachineToNamespace(m *Machine, ns *Namespace) error
|
||||||
}
|
}
|
||||||
|
|
||||||
// RemoveSharedMachineFromNamespace removes a shared machine from a namespace.
|
// RemoveSharedMachineFromNamespace removes a shared machine from a namespace.
|
||||||
func (h *Headscale) RemoveSharedMachineFromNamespace(m *Machine, ns *Namespace) error {
|
func (h *Headscale) RemoveSharedMachineFromNamespace(
|
||||||
if m.NamespaceID == ns.ID {
|
machine *Machine,
|
||||||
|
namespace *Namespace,
|
||||||
|
) error {
|
||||||
|
if machine.NamespaceID == namespace.ID {
|
||||||
// Can't unshare from primary namespace
|
// Can't unshare from primary namespace
|
||||||
return errorMachineNotShared
|
return errorMachineNotShared
|
||||||
}
|
}
|
||||||
|
|
||||||
sharedMachine := SharedMachine{}
|
sharedMachine := SharedMachine{}
|
||||||
result := h.db.Where("machine_id = ? AND namespace_id = ?", m.ID, ns.ID).
|
result := h.db.Where("machine_id = ? AND namespace_id = ?", machine.ID, namespace.ID).
|
||||||
Unscoped().
|
Unscoped().
|
||||||
Delete(&sharedMachine)
|
Delete(&sharedMachine)
|
||||||
if result.Error != nil {
|
if result.Error != nil {
|
||||||
|
@ -61,7 +67,7 @@ func (h *Headscale) RemoveSharedMachineFromNamespace(m *Machine, ns *Namespace)
|
||||||
return errorMachineNotShared
|
return errorMachineNotShared
|
||||||
}
|
}
|
||||||
|
|
||||||
err := h.RequestMapUpdates(ns.ID)
|
err := h.RequestMapUpdates(namespace.ID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -70,9 +76,9 @@ func (h *Headscale) RemoveSharedMachineFromNamespace(m *Machine, ns *Namespace)
|
||||||
}
|
}
|
||||||
|
|
||||||
// RemoveSharedMachineFromAllNamespaces removes a machine as a shared node from all namespaces.
|
// RemoveSharedMachineFromAllNamespaces removes a machine as a shared node from all namespaces.
|
||||||
func (h *Headscale) RemoveSharedMachineFromAllNamespaces(m *Machine) error {
|
func (h *Headscale) RemoveSharedMachineFromAllNamespaces(machine *Machine) error {
|
||||||
sharedMachine := SharedMachine{}
|
sharedMachine := SharedMachine{}
|
||||||
if result := h.db.Where("machine_id = ?", m.ID).Unscoped().Delete(&sharedMachine); result.Error != nil {
|
if result := h.db.Where("machine_id = ?", machine.ID).Unscoped().Delete(&sharedMachine); result.Error != nil {
|
||||||
return result.Error
|
return result.Error
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
14
swagger.go
14
swagger.go
|
@ -13,8 +13,8 @@ import (
|
||||||
//go:embed gen/openapiv2/headscale/v1/headscale.swagger.json
|
//go:embed gen/openapiv2/headscale/v1/headscale.swagger.json
|
||||||
var apiV1JSON []byte
|
var apiV1JSON []byte
|
||||||
|
|
||||||
func SwaggerUI(c *gin.Context) {
|
func SwaggerUI(ctx *gin.Context) {
|
||||||
t := template.Must(template.New("swagger").Parse(`
|
swaggerTemplate := template.Must(template.New("swagger").Parse(`
|
||||||
<html>
|
<html>
|
||||||
<head>
|
<head>
|
||||||
<link rel="stylesheet" type="text/css" href="https://unpkg.com/swagger-ui-dist@3/swagger-ui.css">
|
<link rel="stylesheet" type="text/css" href="https://unpkg.com/swagger-ui-dist@3/swagger-ui.css">
|
||||||
|
@ -47,12 +47,12 @@ func SwaggerUI(c *gin.Context) {
|
||||||
</html>`))
|
</html>`))
|
||||||
|
|
||||||
var payload bytes.Buffer
|
var payload bytes.Buffer
|
||||||
if err := t.Execute(&payload, struct{}{}); err != nil {
|
if err := swaggerTemplate.Execute(&payload, struct{}{}); err != nil {
|
||||||
log.Error().
|
log.Error().
|
||||||
Caller().
|
Caller().
|
||||||
Err(err).
|
Err(err).
|
||||||
Msg("Could not render Swagger")
|
Msg("Could not render Swagger")
|
||||||
c.Data(
|
ctx.Data(
|
||||||
http.StatusInternalServerError,
|
http.StatusInternalServerError,
|
||||||
"text/html; charset=utf-8",
|
"text/html; charset=utf-8",
|
||||||
[]byte("Could not render Swagger"),
|
[]byte("Could not render Swagger"),
|
||||||
|
@ -61,9 +61,9 @@ func SwaggerUI(c *gin.Context) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
c.Data(http.StatusOK, "text/html; charset=utf-8", payload.Bytes())
|
ctx.Data(http.StatusOK, "text/html; charset=utf-8", payload.Bytes())
|
||||||
}
|
}
|
||||||
|
|
||||||
func SwaggerAPIv1(c *gin.Context) {
|
func SwaggerAPIv1(ctx *gin.Context) {
|
||||||
c.Data(http.StatusOK, "application/json; charset=utf-8", apiV1JSON)
|
ctx.Data(http.StatusOK, "application/json; charset=utf-8", apiV1JSON)
|
||||||
}
|
}
|
||||||
|
|
12
utils.go
12
utils.go
|
@ -36,7 +36,7 @@ func decode(
|
||||||
|
|
||||||
func decodeMsg(
|
func decodeMsg(
|
||||||
msg []byte,
|
msg []byte,
|
||||||
v interface{},
|
output interface{},
|
||||||
pubKey *wgkey.Key,
|
pubKey *wgkey.Key,
|
||||||
privKey *wgkey.Private,
|
privKey *wgkey.Private,
|
||||||
) error {
|
) error {
|
||||||
|
@ -45,7 +45,7 @@ func decodeMsg(
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
// fmt.Println(string(decrypted))
|
// fmt.Println(string(decrypted))
|
||||||
if err := json.Unmarshal(decrypted, v); err != nil {
|
if err := json.Unmarshal(decrypted, output); err != nil {
|
||||||
return fmt.Errorf("response: %v", err)
|
return fmt.Errorf("response: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -78,13 +78,17 @@ func encode(v interface{}, pubKey *wgkey.Key, privKey *wgkey.Private) ([]byte, e
|
||||||
return encodeMsg(b, pubKey, privKey)
|
return encodeMsg(b, pubKey, privKey)
|
||||||
}
|
}
|
||||||
|
|
||||||
func encodeMsg(b []byte, pubKey *wgkey.Key, privKey *wgkey.Private) ([]byte, error) {
|
func encodeMsg(
|
||||||
|
payload []byte,
|
||||||
|
pubKey *wgkey.Key,
|
||||||
|
privKey *wgkey.Private,
|
||||||
|
) ([]byte, error) {
|
||||||
var nonce [24]byte
|
var nonce [24]byte
|
||||||
if _, err := io.ReadFull(rand.Reader, nonce[:]); err != nil {
|
if _, err := io.ReadFull(rand.Reader, nonce[:]); err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
pub, pri := (*[32]byte)(pubKey), (*[32]byte)(privKey)
|
pub, pri := (*[32]byte)(pubKey), (*[32]byte)(privKey)
|
||||||
msg := box.Seal(nonce[:], b, &nonce, pub, pri)
|
msg := box.Seal(nonce[:], payload, &nonce, pub, pri)
|
||||||
|
|
||||||
return msg, nil
|
return msg, nil
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue