diff --git a/.golangci.yaml b/.golangci.yaml
index b4ad089..f7fc0c6 100644
--- a/.golangci.yaml
+++ b/.golangci.yaml
@@ -28,6 +28,9 @@ linters:
# In progress
- gocritic
+ # TODO: approve: ok, db, id
+ - varnamelen
+
# We should strive to enable these:
- testpackage
- stylecheck
@@ -39,7 +42,6 @@ linters:
- gosec
- forbidigo
- dupl
- - varnamelen
- makezero
- paralleltest
diff --git a/acls.go b/acls.go
index fdcb098..01de114 100644
--- a/acls.go
+++ b/acls.go
@@ -41,18 +41,18 @@ func (h *Headscale) LoadACLPolicy(path string) error {
defer policyFile.Close()
var policy ACLPolicy
- b, err := io.ReadAll(policyFile)
+ policyBytes, err := io.ReadAll(policyFile)
if err != nil {
return err
}
- ast, err := hujson.Parse(b)
+ ast, err := hujson.Parse(policyBytes)
if err != nil {
return err
}
ast.Standardize()
- b = ast.Pack()
- err = json.Unmarshal(b, &policy)
+ policyBytes = ast.Pack()
+ err = json.Unmarshal(policyBytes, &policy)
if err != nil {
return err
}
@@ -73,32 +73,32 @@ func (h *Headscale) LoadACLPolicy(path string) error {
func (h *Headscale) generateACLRules() ([]tailcfg.FilterRule, error) {
rules := []tailcfg.FilterRule{}
- for i, a := range h.aclPolicy.ACLs {
- if a.Action != "accept" {
+ for index, acl := range h.aclPolicy.ACLs {
+ if acl.Action != "accept" {
return nil, errorInvalidAction
}
- r := tailcfg.FilterRule{}
+ filterRule := tailcfg.FilterRule{}
srcIPs := []string{}
- for j, u := range a.Users {
- srcs, err := h.generateACLPolicySrcIP(u)
+ for innerIndex, user := range acl.Users {
+ srcs, err := h.generateACLPolicySrcIP(user)
if err != nil {
log.Error().
- Msgf("Error parsing ACL %d, User %d", i, j)
+ Msgf("Error parsing ACL %d, User %d", index, innerIndex)
return nil, err
}
srcIPs = append(srcIPs, srcs...)
}
- r.SrcIPs = srcIPs
+ filterRule.SrcIPs = srcIPs
destPorts := []tailcfg.NetPortRange{}
- for j, d := range a.Ports {
- dests, err := h.generateACLPolicyDestPorts(d)
+ for innerIndex, ports := range acl.Ports {
+ dests, err := h.generateACLPolicyDestPorts(ports)
if err != nil {
log.Error().
- Msgf("Error parsing ACL %d, Port %d", i, j)
+ Msgf("Error parsing ACL %d, Port %d", index, innerIndex)
return nil, err
}
@@ -162,17 +162,17 @@ func (h *Headscale) generateACLPolicyDestPorts(
return dests, nil
}
-func (h *Headscale) expandAlias(s string) ([]string, error) {
- if s == "*" {
+func (h *Headscale) expandAlias(alias string) ([]string, error) {
+ if alias == "*" {
return []string{"*"}, nil
}
- if strings.HasPrefix(s, "group:") {
- if _, ok := h.aclPolicy.Groups[s]; !ok {
+ if strings.HasPrefix(alias, "group:") {
+ if _, ok := h.aclPolicy.Groups[alias]; !ok {
return nil, errorInvalidGroup
}
ips := []string{}
- for _, n := range h.aclPolicy.Groups[s] {
+ for _, n := range h.aclPolicy.Groups[alias] {
nodes, err := h.ListMachinesInNamespace(n)
if err != nil {
return nil, errorInvalidNamespace
@@ -185,8 +185,8 @@ func (h *Headscale) expandAlias(s string) ([]string, error) {
return ips, nil
}
- if strings.HasPrefix(s, "tag:") {
- if _, ok := h.aclPolicy.TagOwners[s]; !ok {
+ if strings.HasPrefix(alias, "tag:") {
+ if _, ok := h.aclPolicy.TagOwners[alias]; !ok {
return nil, errorInvalidTag
}
@@ -197,10 +197,10 @@ func (h *Headscale) expandAlias(s string) ([]string, error) {
return nil, err
}
ips := []string{}
- for _, m := range machines {
+ for _, machine := range machines {
hostinfo := tailcfg.Hostinfo{}
- if len(m.HostInfo) != 0 {
- hi, err := m.HostInfo.MarshalJSON()
+ if len(machine.HostInfo) != 0 {
+ hi, err := machine.HostInfo.MarshalJSON()
if err != nil {
return nil, err
}
@@ -211,8 +211,8 @@ func (h *Headscale) expandAlias(s string) ([]string, error) {
// FIXME: Check TagOwners allows this
for _, t := range hostinfo.RequestTags {
- if s[4:] == t {
- ips = append(ips, m.IPAddress)
+ if alias[4:] == t {
+ ips = append(ips, machine.IPAddress)
break
}
@@ -223,7 +223,7 @@ func (h *Headscale) expandAlias(s string) ([]string, error) {
return ips, nil
}
- n, err := h.GetNamespace(s)
+ n, err := h.GetNamespace(alias)
if err == nil {
nodes, err := h.ListMachinesInNamespace(n.Name)
if err != nil {
@@ -237,16 +237,16 @@ func (h *Headscale) expandAlias(s string) ([]string, error) {
return ips, nil
}
- if h, ok := h.aclPolicy.Hosts[s]; ok {
+ if h, ok := h.aclPolicy.Hosts[alias]; ok {
return []string{h.String()}, nil
}
- ip, err := netaddr.ParseIP(s)
+ ip, err := netaddr.ParseIP(alias)
if err == nil {
return []string{ip.String()}, nil
}
- cidr, err := netaddr.ParseIPPrefix(s)
+ cidr, err := netaddr.ParseIPPrefix(alias)
if err == nil {
return []string{cidr.String()}, nil
}
@@ -254,25 +254,25 @@ func (h *Headscale) expandAlias(s string) ([]string, error) {
return nil, errorInvalidUserSection
}
-func (h *Headscale) expandPorts(s string) (*[]tailcfg.PortRange, error) {
- if s == "*" {
+func (h *Headscale) expandPorts(portsStr string) (*[]tailcfg.PortRange, error) {
+ if portsStr == "*" {
return &[]tailcfg.PortRange{
{First: PORT_RANGE_BEGIN, Last: PORT_RANGE_END},
}, nil
}
ports := []tailcfg.PortRange{}
- for _, p := range strings.Split(s, ",") {
- rang := strings.Split(p, "-")
+ for _, portStr := range strings.Split(portsStr, ",") {
+ rang := strings.Split(portStr, "-")
switch len(rang) {
case 1:
- pi, err := strconv.ParseUint(rang[0], BASE_10, BIT_SIZE_16)
+ port, err := strconv.ParseUint(rang[0], BASE_10, BIT_SIZE_16)
if err != nil {
return nil, err
}
ports = append(ports, tailcfg.PortRange{
- First: uint16(pi),
- Last: uint16(pi),
+ First: uint16(port),
+ Last: uint16(port),
})
case EXPECTED_TOKEN_ITEMS:
diff --git a/acls_types.go b/acls_types.go
index 8611d90..f6557bf 100644
--- a/acls_types.go
+++ b/acls_types.go
@@ -41,37 +41,37 @@ type ACLTest struct {
}
// UnmarshalJSON allows to parse the Hosts directly into netaddr objects.
-func (h *Hosts) UnmarshalJSON(data []byte) error {
- hosts := Hosts{}
- hs := make(map[string]string)
+func (hosts *Hosts) UnmarshalJSON(data []byte) error {
+ newHosts := Hosts{}
+ hostIpPrefixMap := make(map[string]string)
ast, err := hujson.Parse(data)
if err != nil {
return err
}
ast.Standardize()
data = ast.Pack()
- err = json.Unmarshal(data, &hs)
+ err = json.Unmarshal(data, &hostIpPrefixMap)
if err != nil {
return err
}
- for k, v := range hs {
- if !strings.Contains(v, "/") {
- v += "/32"
+ for host, prefixStr := range hostIpPrefixMap {
+ if !strings.Contains(prefixStr, "/") {
+ prefixStr += "/32"
}
- prefix, err := netaddr.ParseIPPrefix(v)
+ prefix, err := netaddr.ParseIPPrefix(prefixStr)
if err != nil {
return err
}
- hosts[k] = prefix
+ newHosts[host] = prefix
}
- *h = hosts
+ *hosts = newHosts
return nil
}
// IsZero is perhaps a bit naive here.
-func (p ACLPolicy) IsZero() bool {
- if len(p.Groups) == 0 && len(p.Hosts) == 0 && len(p.ACLs) == 0 {
+func (policy ACLPolicy) IsZero() bool {
+ if len(policy.Groups) == 0 && len(policy.Hosts) == 0 && len(policy.ACLs) == 0 {
return true
}
diff --git a/api.go b/api.go
index 211b486..fb8c49c 100644
--- a/api.go
+++ b/api.go
@@ -22,21 +22,25 @@ const RESERVED_RESPONSE_HEADER_SIZE = 4
// KeyHandler provides the Headscale pub key
// Listens in /key.
-func (h *Headscale) KeyHandler(c *gin.Context) {
- c.Data(http.StatusOK, "text/plain; charset=utf-8", []byte(h.publicKey.HexString()))
+func (h *Headscale) KeyHandler(ctx *gin.Context) {
+ ctx.Data(
+ http.StatusOK,
+ "text/plain; charset=utf-8",
+ []byte(h.publicKey.HexString()),
+ )
}
// RegisterWebAPI shows a simple message in the browser to point to the CLI
// Listens in /register.
-func (h *Headscale) RegisterWebAPI(c *gin.Context) {
- mKeyStr := c.Query("key")
- if mKeyStr == "" {
- c.String(http.StatusBadRequest, "Wrong params")
+func (h *Headscale) RegisterWebAPI(ctx *gin.Context) {
+ machineKeyStr := ctx.Query("key")
+ if machineKeyStr == "" {
+ ctx.String(http.StatusBadRequest, "Wrong params")
return
}
- c.Data(http.StatusOK, "text/html; charset=utf-8", []byte(fmt.Sprintf(`
+ ctx.Data(http.StatusOK, "text/html; charset=utf-8", []byte(fmt.Sprintf(`
headscale
@@ -53,45 +57,45 @@ func (h *Headscale) RegisterWebAPI(c *gin.Context) {
- `, mKeyStr)))
+ `, machineKeyStr)))
}
// RegistrationHandler handles the actual registration process of a machine
// Endpoint /machine/:id.
-func (h *Headscale) RegistrationHandler(c *gin.Context) {
- body, _ := io.ReadAll(c.Request.Body)
- mKeyStr := c.Param("id")
- mKey, err := wgkey.ParseHex(mKeyStr)
+func (h *Headscale) RegistrationHandler(ctx *gin.Context) {
+ body, _ := io.ReadAll(ctx.Request.Body)
+ machineKeyStr := ctx.Param("id")
+ machineKey, err := wgkey.ParseHex(machineKeyStr)
if err != nil {
log.Error().
Str("handler", "Registration").
Err(err).
Msg("Cannot parse machine key")
machineRegistrations.WithLabelValues("unknown", "web", "error", "unknown").Inc()
- c.String(http.StatusInternalServerError, "Sad!")
+ ctx.String(http.StatusInternalServerError, "Sad!")
return
}
req := tailcfg.RegisterRequest{}
- err = decode(body, &req, &mKey, h.privateKey)
+ err = decode(body, &req, &machineKey, h.privateKey)
if err != nil {
log.Error().
Str("handler", "Registration").
Err(err).
Msg("Cannot decode message")
machineRegistrations.WithLabelValues("unknown", "web", "error", "unknown").Inc()
- c.String(http.StatusInternalServerError, "Very sad!")
+ ctx.String(http.StatusInternalServerError, "Very sad!")
return
}
now := time.Now().UTC()
- m, err := h.GetMachineByMachineKey(mKey.HexString())
+ machine, err := h.GetMachineByMachineKey(machineKey.HexString())
if errors.Is(err, gorm.ErrRecordNotFound) {
log.Info().Str("machine", req.Hostinfo.Hostname).Msg("New machine")
newMachine := Machine{
Expiry: &time.Time{},
- MachineKey: mKey.HexString(),
+ MachineKey: machineKey.HexString(),
Name: req.Hostinfo.Hostname,
}
if err := h.db.Create(&newMachine).Error; err != nil {
@@ -99,16 +103,16 @@ func (h *Headscale) RegistrationHandler(c *gin.Context) {
Str("handler", "Registration").
Err(err).
Msg("Could not create row")
- machineRegistrations.WithLabelValues("unknown", "web", "error", m.Namespace.Name).
+ machineRegistrations.WithLabelValues("unknown", "web", "error", machine.Namespace.Name).
Inc()
return
}
- m = &newMachine
+ machine = &newMachine
}
- if !m.Registered && req.Auth.AuthKey != "" {
- h.handleAuthKey(c, h.db, mKey, req, *m)
+ if !machine.Registered && req.Auth.AuthKey != "" {
+ h.handleAuthKey(ctx, h.db, machineKey, req, *machine)
return
}
@@ -116,63 +120,63 @@ func (h *Headscale) RegistrationHandler(c *gin.Context) {
resp := tailcfg.RegisterResponse{}
// We have the updated key!
- if m.NodeKey == wgkey.Key(req.NodeKey).HexString() {
+ if machine.NodeKey == wgkey.Key(req.NodeKey).HexString() {
// The client sends an Expiry in the past if the client is requesting to expire the key (aka logout)
// https://github.com/tailscale/tailscale/blob/main/tailcfg/tailcfg.go#L648
if !req.Expiry.IsZero() && req.Expiry.UTC().Before(now) {
log.Info().
Str("handler", "Registration").
- Str("machine", m.Name).
+ Str("machine", machine.Name).
Msg("Client requested logout")
- m.Expiry = &req.Expiry // save the expiry so that the machine is marked as expired
- h.db.Save(&m)
+ machine.Expiry = &req.Expiry // save the expiry so that the machine is marked as expired
+ h.db.Save(&machine)
resp.AuthURL = ""
resp.MachineAuthorized = false
- resp.User = *m.Namespace.toUser()
- respBody, err := encode(resp, &mKey, h.privateKey)
+ resp.User = *machine.Namespace.toUser()
+ respBody, err := encode(resp, &machineKey, h.privateKey)
if err != nil {
log.Error().
Str("handler", "Registration").
Err(err).
Msg("Cannot encode message")
- c.String(http.StatusInternalServerError, "")
+ ctx.String(http.StatusInternalServerError, "")
return
}
- c.Data(http.StatusOK, "application/json; charset=utf-8", respBody)
+ ctx.Data(http.StatusOK, "application/json; charset=utf-8", respBody)
return
}
- if m.Registered && m.Expiry.UTC().After(now) {
+ if machine.Registered && machine.Expiry.UTC().After(now) {
// The machine registration is valid, respond with redirect to /map
log.Debug().
Str("handler", "Registration").
- Str("machine", m.Name).
+ Str("machine", machine.Name).
Msg("Client is registered and we have the current NodeKey. All clear to /map")
resp.AuthURL = ""
resp.MachineAuthorized = true
- resp.User = *m.Namespace.toUser()
- resp.Login = *m.Namespace.toLogin()
+ resp.User = *machine.Namespace.toUser()
+ resp.Login = *machine.Namespace.toLogin()
- respBody, err := encode(resp, &mKey, h.privateKey)
+ respBody, err := encode(resp, &machineKey, h.privateKey)
if err != nil {
log.Error().
Str("handler", "Registration").
Err(err).
Msg("Cannot encode message")
- machineRegistrations.WithLabelValues("update", "web", "error", m.Namespace.Name).
+ machineRegistrations.WithLabelValues("update", "web", "error", machine.Namespace.Name).
Inc()
- c.String(http.StatusInternalServerError, "")
+ ctx.String(http.StatusInternalServerError, "")
return
}
- machineRegistrations.WithLabelValues("update", "web", "success", m.Namespace.Name).
+ machineRegistrations.WithLabelValues("update", "web", "success", machine.Namespace.Name).
Inc()
- c.Data(http.StatusOK, "application/json; charset=utf-8", respBody)
+ ctx.Data(http.StatusOK, "application/json; charset=utf-8", respBody)
return
}
@@ -180,15 +184,15 @@ func (h *Headscale) RegistrationHandler(c *gin.Context) {
// The client has registered before, but has expired
log.Debug().
Str("handler", "Registration").
- Str("machine", m.Name).
+ Str("machine", machine.Name).
Msg("Machine registration has expired. Sending a authurl to register")
if h.cfg.OIDC.Issuer != "" {
resp.AuthURL = fmt.Sprintf("%s/oidc/register/%s",
- strings.TrimSuffix(h.cfg.ServerURL, "/"), mKey.HexString())
+ strings.TrimSuffix(h.cfg.ServerURL, "/"), machineKey.HexString())
} else {
resp.AuthURL = fmt.Sprintf("%s/register?key=%s",
- strings.TrimSuffix(h.cfg.ServerURL, "/"), mKey.HexString())
+ strings.TrimSuffix(h.cfg.ServerURL, "/"), machineKey.HexString())
}
// When a client connects, it may request a specific expiry time in its
@@ -197,51 +201,52 @@ func (h *Headscale) RegistrationHandler(c *gin.Context) {
// into two steps (which cant pass arbitrary data between them easily) and needs to be
// retrieved again after the user has authenticated. After the authentication flow
// completes, RequestedExpiry is copied into Expiry.
- m.RequestedExpiry = &req.Expiry
+ machine.RequestedExpiry = &req.Expiry
- h.db.Save(&m)
+ h.db.Save(&machine)
- respBody, err := encode(resp, &mKey, h.privateKey)
+ respBody, err := encode(resp, &machineKey, h.privateKey)
if err != nil {
log.Error().
Str("handler", "Registration").
Err(err).
Msg("Cannot encode message")
- machineRegistrations.WithLabelValues("new", "web", "error", m.Namespace.Name).
+ machineRegistrations.WithLabelValues("new", "web", "error", machine.Namespace.Name).
Inc()
- c.String(http.StatusInternalServerError, "")
+ ctx.String(http.StatusInternalServerError, "")
return
}
- machineRegistrations.WithLabelValues("new", "web", "success", m.Namespace.Name).
+ machineRegistrations.WithLabelValues("new", "web", "success", machine.Namespace.Name).
Inc()
- c.Data(http.StatusOK, "application/json; charset=utf-8", respBody)
+ ctx.Data(http.StatusOK, "application/json; charset=utf-8", respBody)
return
}
// The NodeKey we have matches OldNodeKey, which means this is a refresh after a key expiration
- if m.NodeKey == wgkey.Key(req.OldNodeKey).HexString() && m.Expiry.UTC().After(now) {
+ if machine.NodeKey == wgkey.Key(req.OldNodeKey).HexString() &&
+ machine.Expiry.UTC().After(now) {
log.Debug().
Str("handler", "Registration").
- Str("machine", m.Name).
+ Str("machine", machine.Name).
Msg("We have the OldNodeKey in the database. This is a key refresh")
- m.NodeKey = wgkey.Key(req.NodeKey).HexString()
- h.db.Save(&m)
+ machine.NodeKey = wgkey.Key(req.NodeKey).HexString()
+ h.db.Save(&machine)
resp.AuthURL = ""
- resp.User = *m.Namespace.toUser()
- respBody, err := encode(resp, &mKey, h.privateKey)
+ resp.User = *machine.Namespace.toUser()
+ respBody, err := encode(resp, &machineKey, h.privateKey)
if err != nil {
log.Error().
Str("handler", "Registration").
Err(err).
Msg("Cannot encode message")
- c.String(http.StatusInternalServerError, "Extremely sad!")
+ ctx.String(http.StatusInternalServerError, "Extremely sad!")
return
}
- c.Data(http.StatusOK, "application/json; charset=utf-8", respBody)
+ ctx.Data(http.StatusOK, "application/json; charset=utf-8", respBody)
return
}
@@ -249,47 +254,47 @@ func (h *Headscale) RegistrationHandler(c *gin.Context) {
// The machine registration is new, redirect the client to the registration URL
log.Debug().
Str("handler", "Registration").
- Str("machine", m.Name).
+ Str("machine", machine.Name).
Msg("The node is sending us a new NodeKey, sending auth url")
if h.cfg.OIDC.Issuer != "" {
resp.AuthURL = fmt.Sprintf(
"%s/oidc/register/%s",
strings.TrimSuffix(h.cfg.ServerURL, "/"),
- mKey.HexString(),
+ machineKey.HexString(),
)
} else {
resp.AuthURL = fmt.Sprintf("%s/register?key=%s",
- strings.TrimSuffix(h.cfg.ServerURL, "/"), mKey.HexString())
+ strings.TrimSuffix(h.cfg.ServerURL, "/"), machineKey.HexString())
}
// save the requested expiry time for retrieval later in the authentication flow
- m.RequestedExpiry = &req.Expiry
- m.NodeKey = wgkey.Key(req.NodeKey).HexString() // save the NodeKey
- h.db.Save(&m)
+ machine.RequestedExpiry = &req.Expiry
+ machine.NodeKey = wgkey.Key(req.NodeKey).HexString() // save the NodeKey
+ h.db.Save(&machine)
- respBody, err := encode(resp, &mKey, h.privateKey)
+ respBody, err := encode(resp, &machineKey, h.privateKey)
if err != nil {
log.Error().
Str("handler", "Registration").
Err(err).
Msg("Cannot encode message")
- c.String(http.StatusInternalServerError, "")
+ ctx.String(http.StatusInternalServerError, "")
return
}
- c.Data(http.StatusOK, "application/json; charset=utf-8", respBody)
+ ctx.Data(http.StatusOK, "application/json; charset=utf-8", respBody)
}
func (h *Headscale) getMapResponse(
- mKey wgkey.Key,
+ machineKey wgkey.Key,
req tailcfg.MapRequest,
- m *Machine,
+ machine *Machine,
) ([]byte, error) {
log.Trace().
Str("func", "getMapResponse").
Str("machine", req.Hostinfo.Hostname).
Msg("Creating Map response")
- node, err := m.toNode(h.cfg.BaseDomain, h.cfg.DNSConfig, true)
+ node, err := machine.toNode(h.cfg.BaseDomain, h.cfg.DNSConfig, true)
if err != nil {
log.Error().
Str("func", "getMapResponse").
@@ -299,7 +304,7 @@ func (h *Headscale) getMapResponse(
return nil, err
}
- peers, err := h.getPeers(m)
+ peers, err := h.getPeers(machine)
if err != nil {
log.Error().
Str("func", "getMapResponse").
@@ -309,7 +314,7 @@ func (h *Headscale) getMapResponse(
return nil, err
}
- profiles := getMapResponseUserProfiles(*m, peers)
+ profiles := getMapResponseUserProfiles(*machine, peers)
nodePeers, err := peers.toNodes(h.cfg.BaseDomain, h.cfg.DNSConfig, true)
if err != nil {
@@ -324,7 +329,7 @@ func (h *Headscale) getMapResponse(
dnsConfig := getMapResponseDNSConfig(
h.cfg.DNSConfig,
h.cfg.BaseDomain,
- *m,
+ *machine,
peers,
)
@@ -351,12 +356,12 @@ func (h *Headscale) getMapResponse(
encoder, _ := zstd.NewWriter(nil)
srcCompressed := encoder.EncodeAll(src, nil)
- respBody, err = encodeMsg(srcCompressed, &mKey, h.privateKey)
+ respBody, err = encodeMsg(srcCompressed, &machineKey, h.privateKey)
if err != nil {
return nil, err
}
} else {
- respBody, err = encode(resp, &mKey, h.privateKey)
+ respBody, err = encode(resp, &machineKey, h.privateKey)
if err != nil {
return nil, err
}
@@ -370,24 +375,24 @@ func (h *Headscale) getMapResponse(
}
func (h *Headscale) getMapKeepAliveResponse(
- mKey wgkey.Key,
- req tailcfg.MapRequest,
+ machineKey wgkey.Key,
+ mapRequest tailcfg.MapRequest,
) ([]byte, error) {
- resp := tailcfg.MapResponse{
+ mapResponse := tailcfg.MapResponse{
KeepAlive: true,
}
var respBody []byte
var err error
- if req.Compress == "zstd" {
- src, _ := json.Marshal(resp)
+ if mapRequest.Compress == "zstd" {
+ src, _ := json.Marshal(mapResponse)
encoder, _ := zstd.NewWriter(nil)
srcCompressed := encoder.EncodeAll(src, nil)
- respBody, err = encodeMsg(srcCompressed, &mKey, h.privateKey)
+ respBody, err = encodeMsg(srcCompressed, &machineKey, h.privateKey)
if err != nil {
return nil, err
}
} else {
- respBody, err = encode(resp, &mKey, h.privateKey)
+ respBody, err = encode(mapResponse, &machineKey, h.privateKey)
if err != nil {
return nil, err
}
@@ -400,22 +405,22 @@ func (h *Headscale) getMapKeepAliveResponse(
}
func (h *Headscale) handleAuthKey(
- c *gin.Context,
+ ctx *gin.Context,
db *gorm.DB,
idKey wgkey.Key,
- req tailcfg.RegisterRequest,
- m Machine,
+ reqisterRequest tailcfg.RegisterRequest,
+ machine Machine,
) {
log.Debug().
Str("func", "handleAuthKey").
- Str("machine", req.Hostinfo.Hostname).
- Msgf("Processing auth key for %s", req.Hostinfo.Hostname)
+ Str("machine", reqisterRequest.Hostinfo.Hostname).
+ Msgf("Processing auth key for %s", reqisterRequest.Hostinfo.Hostname)
resp := tailcfg.RegisterResponse{}
- pak, err := h.checkKeyValidity(req.Auth.AuthKey)
+ pak, err := h.checkKeyValidity(reqisterRequest.Auth.AuthKey)
if err != nil {
log.Error().
Str("func", "handleAuthKey").
- Str("machine", m.Name).
+ Str("machine", machine.Name).
Err(err).
Msg("Failed authentication via AuthKey")
resp.MachineAuthorized = false
@@ -423,21 +428,21 @@ func (h *Headscale) handleAuthKey(
if err != nil {
log.Error().
Str("func", "handleAuthKey").
- Str("machine", m.Name).
+ Str("machine", machine.Name).
Err(err).
Msg("Cannot encode message")
- c.String(http.StatusInternalServerError, "")
- machineRegistrations.WithLabelValues("new", "authkey", "error", m.Namespace.Name).
+ ctx.String(http.StatusInternalServerError, "")
+ machineRegistrations.WithLabelValues("new", "authkey", "error", machine.Namespace.Name).
Inc()
return
}
- c.Data(http.StatusUnauthorized, "application/json; charset=utf-8", respBody)
+ ctx.Data(http.StatusUnauthorized, "application/json; charset=utf-8", respBody)
log.Error().
Str("func", "handleAuthKey").
- Str("machine", m.Name).
+ Str("machine", machine.Name).
Msg("Failed authentication via AuthKey")
- machineRegistrations.WithLabelValues("new", "authkey", "error", m.Namespace.Name).
+ machineRegistrations.WithLabelValues("new", "authkey", "error", machine.Namespace.Name).
Inc()
return
@@ -445,32 +450,34 @@ func (h *Headscale) handleAuthKey(
log.Debug().
Str("func", "handleAuthKey").
- Str("machine", m.Name).
+ Str("machine", machine.Name).
Msg("Authentication key was valid, proceeding to acquire an IP address")
ip, err := h.getAvailableIP()
if err != nil {
log.Error().
Str("func", "handleAuthKey").
- Str("machine", m.Name).
+ Str("machine", machine.Name).
Msg("Failed to find an available IP")
- machineRegistrations.WithLabelValues("new", "authkey", "error", m.Namespace.Name).
+ machineRegistrations.WithLabelValues("new", "authkey", "error", machine.Namespace.Name).
Inc()
return
}
log.Info().
Str("func", "handleAuthKey").
- Str("machine", m.Name).
+ Str("machine", machine.Name).
Str("ip", ip.String()).
- Msgf("Assigning %s to %s", ip, m.Name)
+ Msgf("Assigning %s to %s", ip, machine.Name)
- m.AuthKeyID = uint(pak.ID)
- m.IPAddress = ip.String()
- m.NamespaceID = pak.NamespaceID
- m.NodeKey = wgkey.Key(req.NodeKey).HexString() // we update it just in case
- m.Registered = true
- m.RegisterMethod = "authKey"
- db.Save(&m)
+ machine.AuthKeyID = uint(pak.ID)
+ machine.IPAddress = ip.String()
+ machine.NamespaceID = pak.NamespaceID
+ machine.NodeKey = wgkey.Key(reqisterRequest.NodeKey).
+ HexString()
+ // we update it just in case
+ machine.Registered = true
+ machine.RegisterMethod = "authKey"
+ db.Save(&machine)
pak.Used = true
db.Save(&pak)
@@ -481,21 +488,21 @@ func (h *Headscale) handleAuthKey(
if err != nil {
log.Error().
Str("func", "handleAuthKey").
- Str("machine", m.Name).
+ Str("machine", machine.Name).
Err(err).
Msg("Cannot encode message")
- machineRegistrations.WithLabelValues("new", "authkey", "error", m.Namespace.Name).
+ machineRegistrations.WithLabelValues("new", "authkey", "error", machine.Namespace.Name).
Inc()
- c.String(http.StatusInternalServerError, "Extremely sad!")
+ ctx.String(http.StatusInternalServerError, "Extremely sad!")
return
}
- machineRegistrations.WithLabelValues("new", "authkey", "success", m.Namespace.Name).
+ machineRegistrations.WithLabelValues("new", "authkey", "success", machine.Namespace.Name).
Inc()
- c.Data(http.StatusOK, "application/json; charset=utf-8", respBody)
+ ctx.Data(http.StatusOK, "application/json; charset=utf-8", respBody)
log.Info().
Str("func", "handleAuthKey").
- Str("machine", m.Name).
+ Str("machine", machine.Name).
Str("ip", ip.String()).
Msg("Successfully authenticated via AuthKey")
}
diff --git a/app.go b/app.go
index af51efd..a0aa238 100644
--- a/app.go
+++ b/app.go
@@ -169,7 +169,7 @@ func NewHeadscale(cfg Config) (*Headscale, error) {
return nil, errors.New("unsupported DB")
}
- h := Headscale{
+ app := Headscale{
cfg: cfg,
dbType: cfg.DBtype,
dbString: dbString,
@@ -178,32 +178,32 @@ func NewHeadscale(cfg Config) (*Headscale, error) {
aclRules: tailcfg.FilterAllowAll, // default allowall
}
- err = h.initDB()
+ err = app.initDB()
if err != nil {
return nil, err
}
if cfg.OIDC.Issuer != "" {
- err = h.initOIDC()
+ err = app.initOIDC()
if err != nil {
return nil, err
}
}
- if h.cfg.DNSConfig != nil && h.cfg.DNSConfig.Proxied { // if MagicDNS
+ if app.cfg.DNSConfig != nil && app.cfg.DNSConfig.Proxied { // if MagicDNS
magicDNSDomains := generateMagicDNSRootDomains(
- h.cfg.IPPrefix,
+ app.cfg.IPPrefix,
)
// we might have routes already from Split DNS
- if h.cfg.DNSConfig.Routes == nil {
- h.cfg.DNSConfig.Routes = make(map[string][]dnstype.Resolver)
+ if app.cfg.DNSConfig.Routes == nil {
+ app.cfg.DNSConfig.Routes = make(map[string][]dnstype.Resolver)
}
for _, d := range magicDNSDomains {
- h.cfg.DNSConfig.Routes[d.WithoutTrailingDot()] = nil
+ app.cfg.DNSConfig.Routes[d.WithoutTrailingDot()] = nil
}
}
- return &h, nil
+ return &app, nil
}
// Redirect to our TLS url.
@@ -229,35 +229,37 @@ func (h *Headscale) expireEphemeralNodesWorker() {
return
}
- for _, ns := range namespaces {
- machines, err := h.ListMachinesInNamespace(ns.Name)
+ for _, namespace := range namespaces {
+ machines, err := h.ListMachinesInNamespace(namespace.Name)
if err != nil {
log.Error().
Err(err).
- Str("namespace", ns.Name).
+ Str("namespace", namespace.Name).
Msg("Error listing machines in namespace")
return
}
- for _, m := range machines {
- if m.AuthKey != nil && m.LastSeen != nil && m.AuthKey.Ephemeral &&
- time.Now().After(m.LastSeen.Add(h.cfg.EphemeralNodeInactivityTimeout)) {
+ for _, machine := range machines {
+ if machine.AuthKey != nil && machine.LastSeen != nil &&
+ machine.AuthKey.Ephemeral &&
+ time.Now().
+ After(machine.LastSeen.Add(h.cfg.EphemeralNodeInactivityTimeout)) {
log.Info().
- Str("machine", m.Name).
+ Str("machine", machine.Name).
Msg("Ephemeral client removed from database")
- err = h.db.Unscoped().Delete(m).Error
+ err = h.db.Unscoped().Delete(machine).Error
if err != nil {
log.Error().
Err(err).
- Str("machine", m.Name).
+ Str("machine", machine.Name).
Msg("🤮 Cannot delete ephemeral machine from the database")
}
}
}
- h.setLastStateChangeToNow(ns.Name)
+ h.setLastStateChangeToNow(namespace.Name)
}
}
@@ -284,18 +286,18 @@ func (h *Headscale) grpcAuthenticationInterceptor(ctx context.Context,
// with the "legacy" database-based client
// It is also neede for grpc-gateway to be able to connect to
// the server
- p, _ := peer.FromContext(ctx)
+ client, _ := peer.FromContext(ctx)
log.Trace().
Caller().
- Str("client_address", p.Addr.String()).
+ Str("client_address", client.Addr.String()).
Msg("Client is trying to authenticate")
- md, ok := metadata.FromIncomingContext(ctx)
+ meta, ok := metadata.FromIncomingContext(ctx)
if !ok {
log.Error().
Caller().
- Str("client_address", p.Addr.String()).
+ Str("client_address", client.Addr.String()).
Msg("Retrieving metadata is failed")
return ctx, status.Errorf(
@@ -304,11 +306,11 @@ func (h *Headscale) grpcAuthenticationInterceptor(ctx context.Context,
)
}
- authHeader, ok := md["authorization"]
+ authHeader, ok := meta["authorization"]
if !ok {
log.Error().
Caller().
- Str("client_address", p.Addr.String()).
+ Str("client_address", client.Addr.String()).
Msg("Authorization token is not supplied")
return ctx, status.Errorf(
@@ -322,7 +324,7 @@ func (h *Headscale) grpcAuthenticationInterceptor(ctx context.Context,
if !strings.HasPrefix(token, AUTH_PREFIX) {
log.Error().
Caller().
- Str("client_address", p.Addr.String()).
+ Str("client_address", client.Addr.String()).
Msg(`missing "Bearer " prefix in "Authorization" header`)
return ctx, status.Error(
@@ -353,25 +355,25 @@ func (h *Headscale) grpcAuthenticationInterceptor(ctx context.Context,
// return handler(ctx, req)
}
-func (h *Headscale) httpAuthenticationMiddleware(c *gin.Context) {
+func (h *Headscale) httpAuthenticationMiddleware(ctx *gin.Context) {
log.Trace().
Caller().
- Str("client_address", c.ClientIP()).
+ Str("client_address", ctx.ClientIP()).
Msg("HTTP authentication invoked")
- authHeader := c.GetHeader("authorization")
+ authHeader := ctx.GetHeader("authorization")
if !strings.HasPrefix(authHeader, AUTH_PREFIX) {
log.Error().
Caller().
- Str("client_address", c.ClientIP()).
+ Str("client_address", ctx.ClientIP()).
Msg(`missing "Bearer " prefix in "Authorization" header`)
- c.AbortWithStatus(http.StatusUnauthorized)
+ ctx.AbortWithStatus(http.StatusUnauthorized)
return
}
- c.AbortWithStatus(http.StatusUnauthorized)
+ ctx.AbortWithStatus(http.StatusUnauthorized)
// TODO(kradalby): Implement API key backend
// Currently all traffic is unauthorized, this is intentional to allow
@@ -438,9 +440,9 @@ func (h *Headscale) Serve() error {
// Create the cmux object that will multiplex 2 protocols on the same port.
// The two following listeners will be served on the same port below gracefully.
- m := cmux.New(networkListener)
+ networkMutex := cmux.New(networkListener)
// Match gRPC requests here
- grpcListener := m.MatchWithWriters(
+ grpcListener := networkMutex.MatchWithWriters(
cmux.HTTP2MatchHeaderFieldSendSettings("content-type", "application/grpc"),
cmux.HTTP2MatchHeaderFieldSendSettings(
"content-type",
@@ -448,7 +450,7 @@ func (h *Headscale) Serve() error {
),
)
// Otherwise match regular http requests.
- httpListener := m.Match(cmux.Any())
+ httpListener := networkMutex.Match(cmux.Any())
grpcGatewayMux := runtime.NewServeMux()
@@ -471,33 +473,33 @@ func (h *Headscale) Serve() error {
return err
}
- r := gin.Default()
+ router := gin.Default()
- p := ginprometheus.NewPrometheus("gin")
- p.Use(r)
+ prometheus := ginprometheus.NewPrometheus("gin")
+ prometheus.Use(router)
- r.GET(
+ router.GET(
"/health",
func(c *gin.Context) { c.JSON(http.StatusOK, gin.H{"healthy": "ok"}) },
)
- r.GET("/key", h.KeyHandler)
- r.GET("/register", h.RegisterWebAPI)
- r.POST("/machine/:id/map", h.PollNetMapHandler)
- r.POST("/machine/:id", h.RegistrationHandler)
- r.GET("/oidc/register/:mkey", h.RegisterOIDC)
- r.GET("/oidc/callback", h.OIDCCallback)
- r.GET("/apple", h.AppleMobileConfig)
- r.GET("/apple/:platform", h.ApplePlatformConfig)
- r.GET("/swagger", SwaggerUI)
- r.GET("/swagger/v1/openapiv2.json", SwaggerAPIv1)
+ router.GET("/key", h.KeyHandler)
+ router.GET("/register", h.RegisterWebAPI)
+ router.POST("/machine/:id/map", h.PollNetMapHandler)
+ router.POST("/machine/:id", h.RegistrationHandler)
+ router.GET("/oidc/register/:mkey", h.RegisterOIDC)
+ router.GET("/oidc/callback", h.OIDCCallback)
+ router.GET("/apple", h.AppleMobileConfig)
+ router.GET("/apple/:platform", h.ApplePlatformConfig)
+ router.GET("/swagger", SwaggerUI)
+ router.GET("/swagger/v1/openapiv2.json", SwaggerAPIv1)
- api := r.Group("/api")
+ api := router.Group("/api")
api.Use(h.httpAuthenticationMiddleware)
{
api.Any("/v1/*any", gin.WrapF(grpcGatewayMux.ServeHTTP))
}
- r.NoRoute(stdoutHandler)
+ router.NoRoute(stdoutHandler)
// Fetch an initial DERP Map before we start serving
h.DERPMap = GetDERPMap(h.cfg.DERP)
@@ -514,7 +516,7 @@ func (h *Headscale) Serve() error {
httpServer := &http.Server{
Addr: h.cfg.Addr,
- Handler: r,
+ Handler: router,
ReadTimeout: HTTP_READ_TIMEOUT,
// Go does not handle timeouts in HTTP very well, and there is
// no good way to handle streaming timeouts, therefore we need to
@@ -561,29 +563,29 @@ func (h *Headscale) Serve() error {
reflection.Register(grpcServer)
reflection.Register(grpcSocket)
- g := new(errgroup.Group)
+ errorGroup := new(errgroup.Group)
- g.Go(func() error { return grpcSocket.Serve(socketListener) })
+ errorGroup.Go(func() error { return grpcSocket.Serve(socketListener) })
// TODO(kradalby): Verify if we need the same TLS setup for gRPC as HTTP
- g.Go(func() error { return grpcServer.Serve(grpcListener) })
+ errorGroup.Go(func() error { return grpcServer.Serve(grpcListener) })
if tlsConfig != nil {
- g.Go(func() error {
+ errorGroup.Go(func() error {
tlsl := tls.NewListener(httpListener, tlsConfig)
return httpServer.Serve(tlsl)
})
} else {
- g.Go(func() error { return httpServer.Serve(httpListener) })
+ errorGroup.Go(func() error { return httpServer.Serve(httpListener) })
}
- g.Go(func() error { return m.Serve() })
+ errorGroup.Go(func() error { return networkMutex.Serve() })
log.Info().
Msgf("listening and serving (multiplexed HTTP and gRPC) on: %s", h.cfg.Addr)
- return g.Wait()
+ return errorGroup.Wait()
}
func (h *Headscale) getTLSSettings() (*tls.Config, error) {
@@ -594,7 +596,7 @@ func (h *Headscale) getTLSSettings() (*tls.Config, error) {
Msg("Listening with TLS but ServerURL does not start with https://")
}
- m := autocert.Manager{
+ certManager := autocert.Manager{
Prompt: autocert.AcceptTOS,
HostPolicy: autocert.HostWhitelist(h.cfg.TLSLetsEncryptHostname),
Cache: autocert.DirCache(h.cfg.TLSLetsEncryptCacheDir),
@@ -609,7 +611,7 @@ func (h *Headscale) getTLSSettings() (*tls.Config, error) {
// Configuration via autocert with TLS-ALPN-01 (https://tools.ietf.org/html/rfc8737)
// The RFC requires that the validation is done on port 443; in other words, headscale
// must be reachable on port 443.
- return m.TLSConfig(), nil
+ return certManager.TLSConfig(), nil
case "HTTP-01":
// Configuration via autocert with HTTP-01. This requires listening on
@@ -617,11 +619,11 @@ func (h *Headscale) getTLSSettings() (*tls.Config, error) {
// service, which can be configured to run on any other port.
go func() {
log.Fatal().
- Err(http.ListenAndServe(h.cfg.TLSLetsEncryptListen, m.HTTPHandler(http.HandlerFunc(h.redirect)))).
+ Err(http.ListenAndServe(h.cfg.TLSLetsEncryptListen, certManager.HTTPHandler(http.HandlerFunc(h.redirect)))).
Msg("failed to set up a HTTP server")
}()
- return m.TLSConfig(), nil
+ return certManager.TLSConfig(), nil
default:
return nil, errors.New("unknown value for TLSLetsEncryptChallengeType")
@@ -676,13 +678,13 @@ func (h *Headscale) getLastStateChange(namespaces ...string) time.Time {
}
}
-func stdoutHandler(c *gin.Context) {
- b, _ := io.ReadAll(c.Request.Body)
+func stdoutHandler(ctx *gin.Context) {
+ body, _ := io.ReadAll(ctx.Request.Body)
log.Trace().
- Interface("header", c.Request.Header).
- Interface("proto", c.Request.Proto).
- Interface("url", c.Request.URL).
- Bytes("body", b).
+ Interface("header", ctx.Request.Header).
+ Interface("proto", ctx.Request.Proto).
+ Interface("url", ctx.Request.URL).
+ Bytes("body", body).
Msg("Request did not match")
}
diff --git a/apple_mobileconfig.go b/apple_mobileconfig.go
index ff3ef3c..8989c27 100644
--- a/apple_mobileconfig.go
+++ b/apple_mobileconfig.go
@@ -12,8 +12,8 @@ import (
// AppleMobileConfig shows a simple message in the browser to point to the CLI
// Listens in /register.
-func (h *Headscale) AppleMobileConfig(c *gin.Context) {
- t := template.Must(template.New("apple").Parse(`
+func (h *Headscale) AppleMobileConfig(ctx *gin.Context) {
+ appleTemplate := template.Must(template.New("apple").Parse(`
Apple configuration profiles
@@ -67,12 +67,12 @@ func (h *Headscale) AppleMobileConfig(c *gin.Context) {
}
var payload bytes.Buffer
- if err := t.Execute(&payload, config); err != nil {
+ if err := appleTemplate.Execute(&payload, config); err != nil {
log.Error().
Str("handler", "AppleMobileConfig").
Err(err).
Msg("Could not render Apple index template")
- c.Data(
+ ctx.Data(
http.StatusInternalServerError,
"text/html; charset=utf-8",
[]byte("Could not render Apple index template"),
@@ -81,11 +81,11 @@ func (h *Headscale) AppleMobileConfig(c *gin.Context) {
return
}
- c.Data(http.StatusOK, "text/html; charset=utf-8", payload.Bytes())
+ ctx.Data(http.StatusOK, "text/html; charset=utf-8", payload.Bytes())
}
-func (h *Headscale) ApplePlatformConfig(c *gin.Context) {
- platform := c.Param("platform")
+func (h *Headscale) ApplePlatformConfig(ctx *gin.Context) {
+ platform := ctx.Param("platform")
id, err := uuid.NewV4()
if err != nil {
@@ -93,7 +93,7 @@ func (h *Headscale) ApplePlatformConfig(c *gin.Context) {
Str("handler", "ApplePlatformConfig").
Err(err).
Msg("Failed not create UUID")
- c.Data(
+ ctx.Data(
http.StatusInternalServerError,
"text/html; charset=utf-8",
[]byte("Failed to create UUID"),
@@ -108,7 +108,7 @@ func (h *Headscale) ApplePlatformConfig(c *gin.Context) {
Str("handler", "ApplePlatformConfig").
Err(err).
Msg("Failed not create UUID")
- c.Data(
+ ctx.Data(
http.StatusInternalServerError,
"text/html; charset=utf-8",
[]byte("Failed to create UUID"),
@@ -131,7 +131,7 @@ func (h *Headscale) ApplePlatformConfig(c *gin.Context) {
Str("handler", "ApplePlatformConfig").
Err(err).
Msg("Could not render Apple macOS template")
- c.Data(
+ ctx.Data(
http.StatusInternalServerError,
"text/html; charset=utf-8",
[]byte("Could not render Apple macOS template"),
@@ -145,7 +145,7 @@ func (h *Headscale) ApplePlatformConfig(c *gin.Context) {
Str("handler", "ApplePlatformConfig").
Err(err).
Msg("Could not render Apple iOS template")
- c.Data(
+ ctx.Data(
http.StatusInternalServerError,
"text/html; charset=utf-8",
[]byte("Could not render Apple iOS template"),
@@ -154,7 +154,7 @@ func (h *Headscale) ApplePlatformConfig(c *gin.Context) {
return
}
default:
- c.Data(
+ ctx.Data(
http.StatusOK,
"text/html; charset=utf-8",
[]byte("Invalid platform, only ios and macos is supported"),
@@ -175,7 +175,7 @@ func (h *Headscale) ApplePlatformConfig(c *gin.Context) {
Str("handler", "ApplePlatformConfig").
Err(err).
Msg("Could not render Apple platform template")
- c.Data(
+ ctx.Data(
http.StatusInternalServerError,
"text/html; charset=utf-8",
[]byte("Could not render Apple platform template"),
@@ -184,7 +184,7 @@ func (h *Headscale) ApplePlatformConfig(c *gin.Context) {
return
}
- c.Data(
+ ctx.Data(
http.StatusOK,
"application/x-apple-aspen-config; charset=utf-8",
content.Bytes(),
diff --git a/cmd/headscale/cli/namespaces.go b/cmd/headscale/cli/namespaces.go
index 8c69ac5..8c6f10a 100644
--- a/cmd/headscale/cli/namespaces.go
+++ b/cmd/headscale/cli/namespaces.go
@@ -167,10 +167,10 @@ var listNamespacesCmd = &cobra.Command{
return
}
- d := pterm.TableData{{"ID", "Name", "Created"}}
+ tableData := pterm.TableData{{"ID", "Name", "Created"}}
for _, namespace := range response.GetNamespaces() {
- d = append(
- d,
+ tableData = append(
+ tableData,
[]string{
namespace.GetId(),
namespace.GetName(),
@@ -178,7 +178,7 @@ var listNamespacesCmd = &cobra.Command{
},
)
}
- err = pterm.DefaultTable.WithHasHeader().WithData(d).Render()
+ err = pterm.DefaultTable.WithHasHeader().WithData(tableData).Render()
if err != nil {
ErrorOutput(
err,
diff --git a/cmd/headscale/cli/nodes.go b/cmd/headscale/cli/nodes.go
index 83d6b1f..c644ad9 100644
--- a/cmd/headscale/cli/nodes.go
+++ b/cmd/headscale/cli/nodes.go
@@ -157,14 +157,14 @@ var listNodesCmd = &cobra.Command{
return
}
- d, err := nodesToPtables(namespace, response.Machines)
+ tableData, err := nodesToPtables(namespace, response.Machines)
if err != nil {
ErrorOutput(err, fmt.Sprintf("Error converting to table: %s", err), output)
return
}
- err = pterm.DefaultTable.WithHasHeader().WithData(d).Render()
+ err = pterm.DefaultTable.WithHasHeader().WithData(tableData).Render()
if err != nil {
ErrorOutput(
err,
@@ -183,7 +183,7 @@ var deleteNodeCmd = &cobra.Command{
Run: func(cmd *cobra.Command, args []string) {
output, _ := cmd.Flags().GetString("output")
- id, err := cmd.Flags().GetInt("identifier")
+ identifier, err := cmd.Flags().GetInt("identifier")
if err != nil {
ErrorOutput(
err,
@@ -199,7 +199,7 @@ var deleteNodeCmd = &cobra.Command{
defer conn.Close()
getRequest := &v1.GetMachineRequest{
- MachineId: uint64(id),
+ MachineId: uint64(identifier),
}
getResponse, err := client.GetMachine(ctx, getRequest)
@@ -217,7 +217,7 @@ var deleteNodeCmd = &cobra.Command{
}
deleteRequest := &v1.DeleteMachineRequest{
- MachineId: uint64(id),
+ MachineId: uint64(identifier),
}
confirm := false
@@ -280,7 +280,7 @@ func sharingWorker(
defer cancel()
defer conn.Close()
- id, err := cmd.Flags().GetInt("identifier")
+ identifier, err := cmd.Flags().GetInt("identifier")
if err != nil {
ErrorOutput(err, fmt.Sprintf("Error converting ID to integer: %s", err), output)
@@ -288,7 +288,7 @@ func sharingWorker(
}
machineRequest := &v1.GetMachineRequest{
- MachineId: uint64(id),
+ MachineId: uint64(identifier),
}
machineResponse, err := client.GetMachine(ctx, machineRequest)
@@ -402,7 +402,7 @@ func nodesToPtables(
currentNamespace string,
machines []*v1.Machine,
) (pterm.TableData, error) {
- d := pterm.TableData{
+ tableData := pterm.TableData{
{
"ID",
"Name",
@@ -448,8 +448,8 @@ func nodesToPtables(
// Shared into this namespace
namespace = pterm.LightYellow(machine.Namespace.Name)
}
- d = append(
- d,
+ tableData = append(
+ tableData,
[]string{
strconv.FormatUint(machine.Id, headscale.BASE_10),
machine.Name,
@@ -463,5 +463,5 @@ func nodesToPtables(
)
}
- return d, nil
+ return tableData, nil
}
diff --git a/cmd/headscale/cli/preauthkeys.go b/cmd/headscale/cli/preauthkeys.go
index 13a3094..b28e600 100644
--- a/cmd/headscale/cli/preauthkeys.go
+++ b/cmd/headscale/cli/preauthkeys.go
@@ -45,7 +45,7 @@ var listPreAuthKeys = &cobra.Command{
Run: func(cmd *cobra.Command, args []string) {
output, _ := cmd.Flags().GetString("output")
- n, err := cmd.Flags().GetString("namespace")
+ namespace, err := cmd.Flags().GetString("namespace")
if err != nil {
ErrorOutput(err, fmt.Sprintf("Error getting namespace: %s", err), output)
@@ -57,7 +57,7 @@ var listPreAuthKeys = &cobra.Command{
defer conn.Close()
request := &v1.ListPreAuthKeysRequest{
- Namespace: n,
+ Namespace: namespace,
}
response, err := client.ListPreAuthKeys(ctx, request)
@@ -77,34 +77,34 @@ var listPreAuthKeys = &cobra.Command{
return
}
- d := pterm.TableData{
+ tableData := pterm.TableData{
{"ID", "Key", "Reusable", "Ephemeral", "Used", "Expiration", "Created"},
}
- for _, k := range response.PreAuthKeys {
+ for _, key := range response.PreAuthKeys {
expiration := "-"
- if k.GetExpiration() != nil {
- expiration = k.Expiration.AsTime().Format("2006-01-02 15:04:05")
+ if key.GetExpiration() != nil {
+ expiration = key.Expiration.AsTime().Format("2006-01-02 15:04:05")
}
var reusable string
- if k.GetEphemeral() {
+ if key.GetEphemeral() {
reusable = "N/A"
} else {
- reusable = fmt.Sprintf("%v", k.GetReusable())
+ reusable = fmt.Sprintf("%v", key.GetReusable())
}
- d = append(d, []string{
- k.GetId(),
- k.GetKey(),
+ tableData = append(tableData, []string{
+ key.GetId(),
+ key.GetKey(),
reusable,
- strconv.FormatBool(k.GetEphemeral()),
- strconv.FormatBool(k.GetUsed()),
+ strconv.FormatBool(key.GetEphemeral()),
+ strconv.FormatBool(key.GetUsed()),
expiration,
- k.GetCreatedAt().AsTime().Format("2006-01-02 15:04:05"),
+ key.GetCreatedAt().AsTime().Format("2006-01-02 15:04:05"),
})
}
- err = pterm.DefaultTable.WithHasHeader().WithData(d).Render()
+ err = pterm.DefaultTable.WithHasHeader().WithData(tableData).Render()
if err != nil {
ErrorOutput(
err,
diff --git a/cmd/headscale/cli/routes.go b/cmd/headscale/cli/routes.go
index 2ad64bc..7d0fcbf 100644
--- a/cmd/headscale/cli/routes.go
+++ b/cmd/headscale/cli/routes.go
@@ -81,14 +81,14 @@ var listRoutesCmd = &cobra.Command{
return
}
- d := routesToPtables(response.Routes)
+ tableData := routesToPtables(response.Routes)
if err != nil {
ErrorOutput(err, fmt.Sprintf("Error converting to table: %s", err), output)
return
}
- err = pterm.DefaultTable.WithHasHeader().WithData(d).Render()
+ err = pterm.DefaultTable.WithHasHeader().WithData(tableData).Render()
if err != nil {
ErrorOutput(
err,
@@ -162,14 +162,14 @@ omit the route you do not want to enable.
return
}
- d := routesToPtables(response.Routes)
+ tableData := routesToPtables(response.Routes)
if err != nil {
ErrorOutput(err, fmt.Sprintf("Error converting to table: %s", err), output)
return
}
- err = pterm.DefaultTable.WithHasHeader().WithData(d).Render()
+ err = pterm.DefaultTable.WithHasHeader().WithData(tableData).Render()
if err != nil {
ErrorOutput(
err,
@@ -184,15 +184,15 @@ omit the route you do not want to enable.
// routesToPtables converts the list of routes to a nice table.
func routesToPtables(routes *v1.Routes) pterm.TableData {
- d := pterm.TableData{{"Route", "Enabled"}}
+ tableData := pterm.TableData{{"Route", "Enabled"}}
for _, route := range routes.GetAdvertisedRoutes() {
enabled := isStringInSlice(routes.EnabledRoutes, route)
- d = append(d, []string{route, strconv.FormatBool(enabled)})
+ tableData = append(tableData, []string{route, strconv.FormatBool(enabled)})
}
- return d
+ return tableData
}
func isStringInSlice(strs []string, s string) bool {
diff --git a/cmd/headscale/cli/utils.go b/cmd/headscale/cli/utils.go
index 343ebe3..ad49c01 100644
--- a/cmd/headscale/cli/utils.go
+++ b/cmd/headscale/cli/utils.go
@@ -318,7 +318,7 @@ func getHeadscaleApp() (*headscale.Headscale, error) {
cfg.OIDC.MatchMap = loadOIDCMatchMap()
- h, err := headscale.NewHeadscale(cfg)
+ app, err := headscale.NewHeadscale(cfg)
if err != nil {
return nil, err
}
@@ -327,7 +327,7 @@ func getHeadscaleApp() (*headscale.Headscale, error) {
if viper.GetString("acl_policy_path") != "" {
aclPath := absPath(viper.GetString("acl_policy_path"))
- err = h.LoadACLPolicy(aclPath)
+ err = app.LoadACLPolicy(aclPath)
if err != nil {
log.Error().
Str("path", aclPath).
@@ -336,7 +336,7 @@ func getHeadscaleApp() (*headscale.Headscale, error) {
}
}
- return h, nil
+ return app, nil
}
func getHeadscaleCLIClient() (context.Context, v1.HeadscaleServiceClient, *grpc.ClientConn, context.CancelFunc) {
diff --git a/dns.go b/dns.go
index d7480b1..cb986cf 100644
--- a/dns.go
+++ b/dns.go
@@ -79,7 +79,7 @@ func generateMagicDNSRootDomains(
func getMapResponseDNSConfig(
dnsConfigOrig *tailcfg.DNSConfig,
baseDomain string,
- m Machine,
+ machine Machine,
peers Machines,
) *tailcfg.DNSConfig {
var dnsConfig *tailcfg.DNSConfig
@@ -88,11 +88,11 @@ func getMapResponseDNSConfig(
dnsConfig = dnsConfigOrig.Clone()
dnsConfig.Domains = append(
dnsConfig.Domains,
- fmt.Sprintf("%s.%s", m.Namespace.Name, baseDomain),
+ fmt.Sprintf("%s.%s", machine.Namespace.Name, baseDomain),
)
namespaceSet := set.New(set.ThreadSafe)
- namespaceSet.Add(m.Namespace)
+ namespaceSet.Add(machine.Namespace)
for _, p := range peers {
namespaceSet.Add(p.Namespace)
}
diff --git a/machine.go b/machine.go
index e8d2f72..d1f27cc 100644
--- a/machine.go
+++ b/machine.go
@@ -56,21 +56,21 @@ type (
)
// For the time being this method is rather naive.
-func (m Machine) isAlreadyRegistered() bool {
- return m.Registered
+func (machine Machine) isAlreadyRegistered() bool {
+ return machine.Registered
}
// isExpired returns whether the machine registration has expired.
-func (m Machine) isExpired() bool {
- return time.Now().UTC().After(*m.Expiry)
+func (machine Machine) isExpired() bool {
+ return time.Now().UTC().After(*machine.Expiry)
}
// If the Machine is expired, updateMachineExpiry updates the Machine Expiry time to the maximum allowed duration,
// or the default duration if no Expiry time was requested by the client. The expiry time here does not (yet) cause
// a client to be disconnected, however they will have to re-auth the machine if they attempt to reconnect after the
// expiry time.
-func (h *Headscale) updateMachineExpiry(m *Machine) {
- if m.isExpired() {
+func (h *Headscale) updateMachineExpiry(machine *Machine) {
+ if machine.isExpired() {
now := time.Now().UTC()
maxExpiry := now.Add(
h.cfg.MaxMachineRegistrationDuration,
@@ -80,31 +80,31 @@ func (h *Headscale) updateMachineExpiry(m *Machine) {
) // calculate the default expiry
// clamp the expiry time of the machine registration to the maximum allowed, or use the default if none supplied
- if maxExpiry.Before(*m.RequestedExpiry) {
+ if maxExpiry.Before(*machine.RequestedExpiry) {
log.Debug().
Msgf("Clamping registration expiry time to maximum: %v (%v)", maxExpiry, h.cfg.MaxMachineRegistrationDuration)
- m.Expiry = &maxExpiry
- } else if m.RequestedExpiry.IsZero() {
+ machine.Expiry = &maxExpiry
+ } else if machine.RequestedExpiry.IsZero() {
log.Debug().Msgf("Using default machine registration expiry time: %v (%v)", defaultExpiry, h.cfg.DefaultMachineRegistrationDuration)
- m.Expiry = &defaultExpiry
+ machine.Expiry = &defaultExpiry
} else {
- log.Debug().Msgf("Using requested machine registration expiry time: %v", m.RequestedExpiry)
- m.Expiry = m.RequestedExpiry
+ log.Debug().Msgf("Using requested machine registration expiry time: %v", machine.RequestedExpiry)
+ machine.Expiry = machine.RequestedExpiry
}
- h.db.Save(&m)
+ h.db.Save(&machine)
}
}
-func (h *Headscale) getDirectPeers(m *Machine) (Machines, error) {
+func (h *Headscale) getDirectPeers(machine *Machine) (Machines, error) {
log.Trace().
Caller().
- Str("machine", m.Name).
+ Str("machine", machine.Name).
Msg("Finding direct peers")
machines := Machines{}
if err := h.db.Preload("Namespace").Where("namespace_id = ? AND machine_key <> ? AND registered",
- m.NamespaceID, m.MachineKey).Find(&machines).Error; err != nil {
+ machine.NamespaceID, machine.MachineKey).Find(&machines).Error; err != nil {
log.Error().Err(err).Msg("Error accessing db")
return Machines{}, err
@@ -114,22 +114,22 @@ func (h *Headscale) getDirectPeers(m *Machine) (Machines, error) {
log.Trace().
Caller().
- Str("machine", m.Name).
+ Str("machine", machine.Name).
Msgf("Found direct machines: %s", machines.String())
return machines, nil
}
// getShared fetches machines that are shared to the `Namespace` of the machine we are getting peers for.
-func (h *Headscale) getShared(m *Machine) (Machines, error) {
+func (h *Headscale) getShared(machine *Machine) (Machines, error) {
log.Trace().
Caller().
- Str("machine", m.Name).
+ Str("machine", machine.Name).
Msg("Finding shared peers")
sharedMachines := []SharedMachine{}
if err := h.db.Preload("Namespace").Preload("Machine").Preload("Machine.Namespace").Where("namespace_id = ?",
- m.NamespaceID).Find(&sharedMachines).Error; err != nil {
+ machine.NamespaceID).Find(&sharedMachines).Error; err != nil {
return Machines{}, err
}
@@ -142,22 +142,22 @@ func (h *Headscale) getShared(m *Machine) (Machines, error) {
log.Trace().
Caller().
- Str("machine", m.Name).
+ Str("machine", machine.Name).
Msgf("Found shared peers: %s", peers.String())
return peers, nil
}
// getSharedTo fetches the machines of the namespaces this machine is shared in.
-func (h *Headscale) getSharedTo(m *Machine) (Machines, error) {
+func (h *Headscale) getSharedTo(machine *Machine) (Machines, error) {
log.Trace().
Caller().
- Str("machine", m.Name).
+ Str("machine", machine.Name).
Msg("Finding peers in namespaces this machine is shared with")
sharedMachines := []SharedMachine{}
if err := h.db.Preload("Namespace").Preload("Machine").Preload("Machine.Namespace").Where("machine_id = ?",
- m.ID).Find(&sharedMachines).Error; err != nil {
+ machine.ID).Find(&sharedMachines).Error; err != nil {
return Machines{}, err
}
@@ -176,14 +176,14 @@ func (h *Headscale) getSharedTo(m *Machine) (Machines, error) {
log.Trace().
Caller().
- Str("machine", m.Name).
+ Str("machine", machine.Name).
Msgf("Found peers we are shared with: %s", peers.String())
return peers, nil
}
-func (h *Headscale) getPeers(m *Machine) (Machines, error) {
- direct, err := h.getDirectPeers(m)
+func (h *Headscale) getPeers(machine *Machine) (Machines, error) {
+ direct, err := h.getDirectPeers(machine)
if err != nil {
log.Error().
Caller().
@@ -193,7 +193,7 @@ func (h *Headscale) getPeers(m *Machine) (Machines, error) {
return Machines{}, err
}
- shared, err := h.getShared(m)
+ shared, err := h.getShared(machine)
if err != nil {
log.Error().
Caller().
@@ -203,7 +203,7 @@ func (h *Headscale) getPeers(m *Machine) (Machines, error) {
return Machines{}, err
}
- sharedTo, err := h.getSharedTo(m)
+ sharedTo, err := h.getSharedTo(machine)
if err != nil {
log.Error().
Caller().
@@ -220,7 +220,7 @@ func (h *Headscale) getPeers(m *Machine) (Machines, error) {
log.Trace().
Caller().
- Str("machine", m.Name).
+ Str("machine", machine.Name).
Msgf("Found total peers: %s", peers.String())
return peers, nil
@@ -262,9 +262,9 @@ func (h *Headscale) GetMachineByID(id uint64) (*Machine, error) {
}
// GetMachineByMachineKey finds a Machine by ID and returns the Machine struct.
-func (h *Headscale) GetMachineByMachineKey(mKey string) (*Machine, error) {
+func (h *Headscale) GetMachineByMachineKey(machineKey string) (*Machine, error) {
m := Machine{}
- if result := h.db.Preload("Namespace").First(&m, "machine_key = ?", mKey); result.Error != nil {
+ if result := h.db.Preload("Namespace").First(&m, "machine_key = ?", machineKey); result.Error != nil {
return nil, result.Error
}
@@ -273,8 +273,8 @@ func (h *Headscale) GetMachineByMachineKey(mKey string) (*Machine, error) {
// UpdateMachine takes a Machine struct pointer (typically already loaded from database
// and updates it with the latest data from the database.
-func (h *Headscale) UpdateMachine(m *Machine) error {
- if result := h.db.Find(m).First(&m); result.Error != nil {
+func (h *Headscale) UpdateMachine(machine *Machine) error {
+ if result := h.db.Find(machine).First(&machine); result.Error != nil {
return result.Error
}
@@ -282,16 +282,16 @@ func (h *Headscale) UpdateMachine(m *Machine) error {
}
// DeleteMachine softs deletes a Machine from the database.
-func (h *Headscale) DeleteMachine(m *Machine) error {
- err := h.RemoveSharedMachineFromAllNamespaces(m)
+func (h *Headscale) DeleteMachine(machine *Machine) error {
+ err := h.RemoveSharedMachineFromAllNamespaces(machine)
if err != nil && err != errorMachineNotShared {
return err
}
- m.Registered = false
- namespaceID := m.NamespaceID
- h.db.Save(&m) // we mark it as unregistered, just in case
- if err := h.db.Delete(&m).Error; err != nil {
+ machine.Registered = false
+ namespaceID := machine.NamespaceID
+ h.db.Save(&machine) // we mark it as unregistered, just in case
+ if err := h.db.Delete(&machine).Error; err != nil {
return err
}
@@ -299,14 +299,14 @@ func (h *Headscale) DeleteMachine(m *Machine) error {
}
// HardDeleteMachine hard deletes a Machine from the database.
-func (h *Headscale) HardDeleteMachine(m *Machine) error {
- err := h.RemoveSharedMachineFromAllNamespaces(m)
+func (h *Headscale) HardDeleteMachine(machine *Machine) error {
+ err := h.RemoveSharedMachineFromAllNamespaces(machine)
if err != nil && err != errorMachineNotShared {
return err
}
- namespaceID := m.NamespaceID
- if err := h.db.Unscoped().Delete(&m).Error; err != nil {
+ namespaceID := machine.NamespaceID
+ if err := h.db.Unscoped().Delete(&machine).Error; err != nil {
return err
}
@@ -314,10 +314,10 @@ func (h *Headscale) HardDeleteMachine(m *Machine) error {
}
// GetHostInfo returns a Hostinfo struct for the machine.
-func (m *Machine) GetHostInfo() (*tailcfg.Hostinfo, error) {
+func (machine *Machine) GetHostInfo() (*tailcfg.Hostinfo, error) {
hostinfo := tailcfg.Hostinfo{}
- if len(m.HostInfo) != 0 {
- hi, err := m.HostInfo.MarshalJSON()
+ if len(machine.HostInfo) != 0 {
+ hi, err := machine.HostInfo.MarshalJSON()
if err != nil {
return nil, err
}
@@ -330,17 +330,17 @@ func (m *Machine) GetHostInfo() (*tailcfg.Hostinfo, error) {
return &hostinfo, nil
}
-func (h *Headscale) isOutdated(m *Machine) bool {
- if err := h.UpdateMachine(m); err != nil {
+func (h *Headscale) isOutdated(machine *Machine) bool {
+ if err := h.UpdateMachine(machine); err != nil {
// It does not seem meaningful to propagate this error as the end result
// will have to be that the machine has to be considered outdated.
return true
}
- sharedMachines, _ := h.getShared(m)
+ sharedMachines, _ := h.getShared(machine)
namespaceSet := set.New(set.ThreadSafe)
- namespaceSet.Add(m.Namespace.Name)
+ namespaceSet.Add(machine.Namespace.Name)
// Check if any of our shared namespaces has updates that we have
// not propagated.
@@ -356,22 +356,22 @@ func (h *Headscale) isOutdated(m *Machine) bool {
lastChange := h.getLastStateChange(namespaces...)
log.Trace().
Caller().
- Str("machine", m.Name).
- Time("last_successful_update", *m.LastSuccessfulUpdate).
+ Str("machine", machine.Name).
+ Time("last_successful_update", *machine.LastSuccessfulUpdate).
Time("last_state_change", lastChange).
- Msgf("Checking if %s is missing updates", m.Name)
+ Msgf("Checking if %s is missing updates", machine.Name)
- return m.LastSuccessfulUpdate.Before(lastChange)
+ return machine.LastSuccessfulUpdate.Before(lastChange)
}
-func (m Machine) String() string {
- return m.Name
+func (machine Machine) String() string {
+ return machine.Name
}
-func (ms Machines) String() string {
- temp := make([]string, len(ms))
+func (machines Machines) String() string {
+ temp := make([]string, len(machines))
- for index, machine := range ms {
+ for index, machine := range machines {
temp[index] = machine.Name
}
@@ -379,24 +379,24 @@ func (ms Machines) String() string {
}
// TODO(kradalby): Remove when we have generics...
-func (ms MachinesP) String() string {
- temp := make([]string, len(ms))
+func (machines MachinesP) String() string {
+ temp := make([]string, len(machines))
- for index, machine := range ms {
+ for index, machine := range machines {
temp[index] = machine.Name
}
return fmt.Sprintf("[ %s ](%d)", strings.Join(temp, ", "), len(temp))
}
-func (ms Machines) toNodes(
+func (machines Machines) toNodes(
baseDomain string,
dnsConfig *tailcfg.DNSConfig,
includeRoutes bool,
) ([]*tailcfg.Node, error) {
- nodes := make([]*tailcfg.Node, len(ms))
+ nodes := make([]*tailcfg.Node, len(machines))
- for index, machine := range ms {
+ for index, machine := range machines {
node, err := machine.toNode(baseDomain, dnsConfig, includeRoutes)
if err != nil {
return nil, err
@@ -410,23 +410,24 @@ func (ms Machines) toNodes(
// toNode converts a Machine into a Tailscale Node. includeRoutes is false for shared nodes
// as per the expected behaviour in the official SaaS.
-func (m Machine) toNode(
+func (machine Machine) toNode(
baseDomain string,
dnsConfig *tailcfg.DNSConfig,
includeRoutes bool,
) (*tailcfg.Node, error) {
- nKey, err := wgkey.ParseHex(m.NodeKey)
+ nodeKey, err := wgkey.ParseHex(machine.NodeKey)
if err != nil {
return nil, err
}
- mKey, err := wgkey.ParseHex(m.MachineKey)
+
+ machineKey, err := wgkey.ParseHex(machine.MachineKey)
if err != nil {
return nil, err
}
var discoKey tailcfg.DiscoKey
- if m.DiscoKey != "" {
- dKey, err := wgkey.ParseHex(m.DiscoKey)
+ if machine.DiscoKey != "" {
+ dKey, err := wgkey.ParseHex(machine.DiscoKey)
if err != nil {
return nil, err
}
@@ -436,12 +437,12 @@ func (m Machine) toNode(
}
addrs := []netaddr.IPPrefix{}
- ip, err := netaddr.ParseIPPrefix(fmt.Sprintf("%s/32", m.IPAddress))
+ ip, err := netaddr.ParseIPPrefix(fmt.Sprintf("%s/32", machine.IPAddress))
if err != nil {
log.Trace().
Caller().
- Str("ip", m.IPAddress).
- Msgf("Failed to parse IP Prefix from IP: %s", m.IPAddress)
+ Str("ip", machine.IPAddress).
+ Msgf("Failed to parse IP Prefix from IP: %s", machine.IPAddress)
return nil, err
}
@@ -455,8 +456,8 @@ func (m Machine) toNode(
if includeRoutes {
routesStr := []string{}
- if len(m.EnabledRoutes) != 0 {
- allwIps, err := m.EnabledRoutes.MarshalJSON()
+ if len(machine.EnabledRoutes) != 0 {
+ allwIps, err := machine.EnabledRoutes.MarshalJSON()
if err != nil {
return nil, err
}
@@ -476,8 +477,8 @@ func (m Machine) toNode(
}
endpoints := []string{}
- if len(m.Endpoints) != 0 {
- be, err := m.Endpoints.MarshalJSON()
+ if len(machine.Endpoints) != 0 {
+ be, err := machine.Endpoints.MarshalJSON()
if err != nil {
return nil, err
}
@@ -488,8 +489,8 @@ func (m Machine) toNode(
}
hostinfo := tailcfg.Hostinfo{}
- if len(m.HostInfo) != 0 {
- hi, err := m.HostInfo.MarshalJSON()
+ if len(machine.HostInfo) != 0 {
+ hi, err := machine.HostInfo.MarshalJSON()
if err != nil {
return nil, err
}
@@ -507,29 +508,34 @@ func (m Machine) toNode(
}
var keyExpiry time.Time
- if m.Expiry != nil {
- keyExpiry = *m.Expiry
+ if machine.Expiry != nil {
+ keyExpiry = *machine.Expiry
} else {
keyExpiry = time.Time{}
}
var hostname string
if dnsConfig != nil && dnsConfig.Proxied { // MagicDNS
- hostname = fmt.Sprintf("%s.%s.%s", m.Name, m.Namespace.Name, baseDomain)
+ hostname = fmt.Sprintf(
+ "%s.%s.%s",
+ machine.Name,
+ machine.Namespace.Name,
+ baseDomain,
+ )
} else {
- hostname = m.Name
+ hostname = machine.Name
}
n := tailcfg.Node{
- ID: tailcfg.NodeID(m.ID), // this is the actual ID
+ ID: tailcfg.NodeID(machine.ID), // this is the actual ID
StableID: tailcfg.StableNodeID(
- strconv.FormatUint(m.ID, BASE_10),
+ strconv.FormatUint(machine.ID, BASE_10),
), // in headscale, unlike tailcontrol server, IDs are permanent
Name: hostname,
- User: tailcfg.UserID(m.NamespaceID),
- Key: tailcfg.NodeKey(nKey),
+ User: tailcfg.UserID(machine.NamespaceID),
+ Key: tailcfg.NodeKey(nodeKey),
KeyExpiry: keyExpiry,
- Machine: tailcfg.MachineKey(mKey),
+ Machine: tailcfg.MachineKey(machineKey),
DiscoKey: discoKey,
Addresses: addrs,
AllowedIPs: allowedIPs,
@@ -537,68 +543,73 @@ func (m Machine) toNode(
DERP: derp,
Hostinfo: hostinfo,
- Created: m.CreatedAt,
- LastSeen: m.LastSeen,
+ Created: machine.CreatedAt,
+ LastSeen: machine.LastSeen,
KeepAlive: true,
- MachineAuthorized: m.Registered,
+ MachineAuthorized: machine.Registered,
Capabilities: []string{tailcfg.CapabilityFileSharing},
}
return &n, nil
}
-func (m *Machine) toProto() *v1.Machine {
- machine := &v1.Machine{
- Id: m.ID,
- MachineKey: m.MachineKey,
+func (machine *Machine) toProto() *v1.Machine {
+ machineProto := &v1.Machine{
+ Id: machine.ID,
+ MachineKey: machine.MachineKey,
- NodeKey: m.NodeKey,
- DiscoKey: m.DiscoKey,
- IpAddress: m.IPAddress,
- Name: m.Name,
- Namespace: m.Namespace.toProto(),
+ NodeKey: machine.NodeKey,
+ DiscoKey: machine.DiscoKey,
+ IpAddress: machine.IPAddress,
+ Name: machine.Name,
+ Namespace: machine.Namespace.toProto(),
- Registered: m.Registered,
+ Registered: machine.Registered,
// TODO(kradalby): Implement register method enum converter
// RegisterMethod: ,
- CreatedAt: timestamppb.New(m.CreatedAt),
+ CreatedAt: timestamppb.New(machine.CreatedAt),
}
- if m.AuthKey != nil {
- machine.PreAuthKey = m.AuthKey.toProto()
+ if machine.AuthKey != nil {
+ machineProto.PreAuthKey = machine.AuthKey.toProto()
}
- if m.LastSeen != nil {
- machine.LastSeen = timestamppb.New(*m.LastSeen)
+ if machine.LastSeen != nil {
+ machineProto.LastSeen = timestamppb.New(*machine.LastSeen)
}
- if m.LastSuccessfulUpdate != nil {
- machine.LastSuccessfulUpdate = timestamppb.New(*m.LastSuccessfulUpdate)
+ if machine.LastSuccessfulUpdate != nil {
+ machineProto.LastSuccessfulUpdate = timestamppb.New(
+ *machine.LastSuccessfulUpdate,
+ )
}
- if m.Expiry != nil {
- machine.Expiry = timestamppb.New(*m.Expiry)
+ if machine.Expiry != nil {
+ machineProto.Expiry = timestamppb.New(*machine.Expiry)
}
- return machine
+ return machineProto
}
// RegisterMachine is executed from the CLI to register a new Machine using its MachineKey.
-func (h *Headscale) RegisterMachine(key string, namespace string) (*Machine, error) {
- ns, err := h.GetNamespace(namespace)
+func (h *Headscale) RegisterMachine(
+ key string,
+ namespaceName string,
+) (*Machine, error) {
+ namespace, err := h.GetNamespace(namespaceName)
if err != nil {
return nil, err
}
- mKey, err := wgkey.ParseHex(key)
+ machineKey, err := wgkey.ParseHex(key)
if err != nil {
return nil, err
}
- m := Machine{}
- if result := h.db.First(&m, "machine_key = ?", mKey.HexString()); errors.Is(
+ machine := Machine{}
+ if result := h.db.First(&machine, "machine_key = ?", machineKey.HexString()); errors.Is(
result.Error,
gorm.ErrRecordNotFound,
) {
@@ -607,15 +618,15 @@ func (h *Headscale) RegisterMachine(key string, namespace string) (*Machine, err
log.Trace().
Caller().
- Str("machine", m.Name).
+ Str("machine", machine.Name).
Msg("Attempting to register machine")
- if m.isAlreadyRegistered() {
+ if machine.isAlreadyRegistered() {
err := errors.New("Machine already registered")
log.Error().
Caller().
Err(err).
- Str("machine", m.Name).
+ Str("machine", machine.Name).
Msg("Attempting to register machine")
return nil, err
@@ -626,7 +637,7 @@ func (h *Headscale) RegisterMachine(key string, namespace string) (*Machine, err
log.Error().
Caller().
Err(err).
- Str("machine", m.Name).
+ Str("machine", machine.Name).
Msg("Could not find IP for the new machine")
return nil, err
@@ -634,27 +645,27 @@ func (h *Headscale) RegisterMachine(key string, namespace string) (*Machine, err
log.Trace().
Caller().
- Str("machine", m.Name).
+ Str("machine", machine.Name).
Str("ip", ip.String()).
Msg("Found IP for host")
- m.IPAddress = ip.String()
- m.NamespaceID = ns.ID
- m.Registered = true
- m.RegisterMethod = "cli"
- h.db.Save(&m)
+ machine.IPAddress = ip.String()
+ machine.NamespaceID = namespace.ID
+ machine.Registered = true
+ machine.RegisterMethod = "cli"
+ h.db.Save(&machine)
log.Trace().
Caller().
- Str("machine", m.Name).
+ Str("machine", machine.Name).
Str("ip", ip.String()).
Msg("Machine registered with the database")
- return &m, nil
+ return &machine, nil
}
-func (m *Machine) GetAdvertisedRoutes() ([]netaddr.IPPrefix, error) {
- hostInfo, err := m.GetHostInfo()
+func (machine *Machine) GetAdvertisedRoutes() ([]netaddr.IPPrefix, error) {
+ hostInfo, err := machine.GetHostInfo()
if err != nil {
return nil, err
}
@@ -662,8 +673,8 @@ func (m *Machine) GetAdvertisedRoutes() ([]netaddr.IPPrefix, error) {
return hostInfo.RoutableIPs, nil
}
-func (m *Machine) GetEnabledRoutes() ([]netaddr.IPPrefix, error) {
- data, err := m.EnabledRoutes.MarshalJSON()
+func (machine *Machine) GetEnabledRoutes() ([]netaddr.IPPrefix, error) {
+ data, err := machine.EnabledRoutes.MarshalJSON()
if err != nil {
return nil, err
}
@@ -686,13 +697,13 @@ func (m *Machine) GetEnabledRoutes() ([]netaddr.IPPrefix, error) {
return routes, nil
}
-func (m *Machine) IsRoutesEnabled(routeStr string) bool {
+func (machine *Machine) IsRoutesEnabled(routeStr string) bool {
route, err := netaddr.ParseIPPrefix(routeStr)
if err != nil {
return false
}
- enabledRoutes, err := m.GetEnabledRoutes()
+ enabledRoutes, err := machine.GetEnabledRoutes()
if err != nil {
return false
}
@@ -708,7 +719,7 @@ func (m *Machine) IsRoutesEnabled(routeStr string) bool {
// EnableNodeRoute enables new routes based on a list of new routes. It will _replace_ the
// previous list of routes.
-func (h *Headscale) EnableRoutes(m *Machine, routeStrs ...string) error {
+func (h *Headscale) EnableRoutes(machine *Machine, routeStrs ...string) error {
newRoutes := make([]netaddr.IPPrefix, len(routeStrs))
for index, routeStr := range routeStrs {
route, err := netaddr.ParseIPPrefix(routeStr)
@@ -719,7 +730,7 @@ func (h *Headscale) EnableRoutes(m *Machine, routeStrs ...string) error {
newRoutes[index] = route
}
- availableRoutes, err := m.GetAdvertisedRoutes()
+ availableRoutes, err := machine.GetAdvertisedRoutes()
if err != nil {
return err
}
@@ -728,7 +739,7 @@ func (h *Headscale) EnableRoutes(m *Machine, routeStrs ...string) error {
if !containsIpPrefix(availableRoutes, newRoute) {
return fmt.Errorf(
"route (%s) is not available on node %s",
- m.Name,
+ machine.Name,
newRoute,
)
}
@@ -739,10 +750,10 @@ func (h *Headscale) EnableRoutes(m *Machine, routeStrs ...string) error {
return err
}
- m.EnabledRoutes = datatypes.JSON(routes)
- h.db.Save(&m)
+ machine.EnabledRoutes = datatypes.JSON(routes)
+ h.db.Save(&machine)
- err = h.RequestMapUpdates(m.NamespaceID)
+ err = h.RequestMapUpdates(machine.NamespaceID)
if err != nil {
return err
}
@@ -750,13 +761,13 @@ func (h *Headscale) EnableRoutes(m *Machine, routeStrs ...string) error {
return nil
}
-func (m *Machine) RoutesToProto() (*v1.Routes, error) {
- availableRoutes, err := m.GetAdvertisedRoutes()
+func (machine *Machine) RoutesToProto() (*v1.Routes, error) {
+ availableRoutes, err := machine.GetAdvertisedRoutes()
if err != nil {
return nil, err
}
- enabledRoutes, err := m.GetEnabledRoutes()
+ enabledRoutes, err := machine.GetEnabledRoutes()
if err != nil {
return nil, err
}
diff --git a/namespaces.go b/namespaces.go
index bea922d..858a7aa 100644
--- a/namespaces.go
+++ b/namespaces.go
@@ -32,12 +32,12 @@ type Namespace struct {
// CreateNamespace creates a new Namespace. Returns error if could not be created
// or another namespace already exists.
func (h *Headscale) CreateNamespace(name string) (*Namespace, error) {
- n := Namespace{}
- if err := h.db.Where("name = ?", name).First(&n).Error; err == nil {
+ namespace := Namespace{}
+ if err := h.db.Where("name = ?", name).First(&namespace).Error; err == nil {
return nil, errorNamespaceExists
}
- n.Name = name
- if err := h.db.Create(&n).Error; err != nil {
+ namespace.Name = name
+ if err := h.db.Create(&namespace).Error; err != nil {
log.Error().
Str("func", "CreateNamespace").
Err(err).
@@ -46,22 +46,22 @@ func (h *Headscale) CreateNamespace(name string) (*Namespace, error) {
return nil, err
}
- return &n, nil
+ return &namespace, nil
}
// DestroyNamespace destroys a Namespace. Returns error if the Namespace does
// not exist or if there are machines associated with it.
func (h *Headscale) DestroyNamespace(name string) error {
- n, err := h.GetNamespace(name)
+ namespace, err := h.GetNamespace(name)
if err != nil {
return errorNamespaceNotFound
}
- m, err := h.ListMachinesInNamespace(name)
+ machines, err := h.ListMachinesInNamespace(name)
if err != nil {
return err
}
- if len(m) > 0 {
+ if len(machines) > 0 {
return errorNamespaceNotEmptyOfNodes
}
@@ -69,14 +69,14 @@ func (h *Headscale) DestroyNamespace(name string) error {
if err != nil {
return err
}
- for _, p := range keys {
- err = h.DestroyPreAuthKey(&p)
+ for _, key := range keys {
+ err = h.DestroyPreAuthKey(&key)
if err != nil {
return err
}
}
- if result := h.db.Unscoped().Delete(&n); result.Error != nil {
+ if result := h.db.Unscoped().Delete(&namespace); result.Error != nil {
return result.Error
}
@@ -86,7 +86,7 @@ func (h *Headscale) DestroyNamespace(name string) error {
// RenameNamespace renames a Namespace. Returns error if the Namespace does
// not exist or if another Namespace exists with the new name.
func (h *Headscale) RenameNamespace(oldName, newName string) error {
- n, err := h.GetNamespace(oldName)
+ oldNamespace, err := h.GetNamespace(oldName)
if err != nil {
return err
}
@@ -98,13 +98,13 @@ func (h *Headscale) RenameNamespace(oldName, newName string) error {
return err
}
- n.Name = newName
+ oldNamespace.Name = newName
- if result := h.db.Save(&n); result.Error != nil {
+ if result := h.db.Save(&oldNamespace); result.Error != nil {
return result.Error
}
- err = h.RequestMapUpdates(n.ID)
+ err = h.RequestMapUpdates(oldNamespace.ID)
if err != nil {
return err
}
@@ -114,15 +114,15 @@ func (h *Headscale) RenameNamespace(oldName, newName string) error {
// GetNamespace fetches a namespace by name.
func (h *Headscale) GetNamespace(name string) (*Namespace, error) {
- n := Namespace{}
- if result := h.db.First(&n, "name = ?", name); errors.Is(
+ namespace := Namespace{}
+ if result := h.db.First(&namespace, "name = ?", name); errors.Is(
result.Error,
gorm.ErrRecordNotFound,
) {
return nil, errorNamespaceNotFound
}
- return &n, nil
+ return &namespace, nil
}
// ListNamespaces gets all the existing namespaces.
@@ -137,13 +137,13 @@ func (h *Headscale) ListNamespaces() ([]Namespace, error) {
// ListMachinesInNamespace gets all the nodes in a given namespace.
func (h *Headscale) ListMachinesInNamespace(name string) ([]Machine, error) {
- n, err := h.GetNamespace(name)
+ namespace, err := h.GetNamespace(name)
if err != nil {
return nil, err
}
machines := []Machine{}
- if err := h.db.Preload("AuthKey").Preload("AuthKey.Namespace").Preload("Namespace").Where(&Machine{NamespaceID: n.ID}).Find(&machines).Error; err != nil {
+ if err := h.db.Preload("AuthKey").Preload("AuthKey.Namespace").Preload("Namespace").Where(&Machine{NamespaceID: namespace.ID}).Find(&machines).Error; err != nil {
return nil, err
}
@@ -176,17 +176,18 @@ func (h *Headscale) ListSharedMachinesInNamespace(name string) ([]Machine, error
}
// SetMachineNamespace assigns a Machine to a namespace.
-func (h *Headscale) SetMachineNamespace(m *Machine, namespaceName string) error {
- n, err := h.GetNamespace(namespaceName)
+func (h *Headscale) SetMachineNamespace(machine *Machine, namespaceName string) error {
+ namespace, err := h.GetNamespace(namespaceName)
if err != nil {
return err
}
- m.NamespaceID = n.ID
- h.db.Save(&m)
+ machine.NamespaceID = namespace.ID
+ h.db.Save(&machine)
return nil
}
+// TODO(kradalby): Remove the need for this.
// RequestMapUpdates signals the KV worker to update the maps for this namespace.
func (h *Headscale) RequestMapUpdates(namespaceID uint) error {
namespace := Namespace{}
@@ -194,8 +195,8 @@ func (h *Headscale) RequestMapUpdates(namespaceID uint) error {
return err
}
- v, err := h.getValue("namespaces_pending_updates")
- if err != nil || v == "" {
+ namespacesPendingUpdates, err := h.getValue("namespaces_pending_updates")
+ if err != nil || namespacesPendingUpdates == "" {
err = h.setValue(
"namespaces_pending_updates",
fmt.Sprintf(`["%s"]`, namespace.Name),
@@ -207,7 +208,7 @@ func (h *Headscale) RequestMapUpdates(namespaceID uint) error {
return nil
}
names := []string{}
- err = json.Unmarshal([]byte(v), &names)
+ err = json.Unmarshal([]byte(namespacesPendingUpdates), &names)
if err != nil {
err = h.setValue(
"namespaces_pending_updates",
@@ -235,16 +236,16 @@ func (h *Headscale) RequestMapUpdates(namespaceID uint) error {
}
func (h *Headscale) checkForNamespacesPendingUpdates() {
- v, err := h.getValue("namespaces_pending_updates")
+ namespacesPendingUpdates, err := h.getValue("namespaces_pending_updates")
if err != nil {
return
}
- if v == "" {
+ if namespacesPendingUpdates == "" {
return
}
namespaces := []string{}
- err = json.Unmarshal([]byte(v), &namespaces)
+ err = json.Unmarshal([]byte(namespacesPendingUpdates), &namespaces)
if err != nil {
return
}
@@ -255,11 +256,11 @@ func (h *Headscale) checkForNamespacesPendingUpdates() {
Msg("Sending updates to nodes in namespacespace")
h.setLastStateChangeToNow(namespace)
}
- newV, err := h.getValue("namespaces_pending_updates")
+ newPendingUpdateValue, err := h.getValue("namespaces_pending_updates")
if err != nil {
return
}
- if v == newV { // only clear when no changes, so we notified everybody
+ if namespacesPendingUpdates == newPendingUpdateValue { // only clear when no changes, so we notified everybody
err = h.setValue("namespaces_pending_updates", "")
if err != nil {
log.Error().
@@ -273,7 +274,7 @@ func (h *Headscale) checkForNamespacesPendingUpdates() {
}
func (n *Namespace) toUser() *tailcfg.User {
- u := tailcfg.User{
+ user := tailcfg.User{
ID: tailcfg.UserID(n.ID),
LoginName: n.Name,
DisplayName: n.Name,
@@ -283,11 +284,11 @@ func (n *Namespace) toUser() *tailcfg.User {
Created: time.Time{},
}
- return &u
+ return &user
}
func (n *Namespace) toLogin() *tailcfg.Login {
- l := tailcfg.Login{
+ login := tailcfg.Login{
ID: tailcfg.LoginID(n.ID),
LoginName: n.Name,
DisplayName: n.Name,
@@ -295,14 +296,14 @@ func (n *Namespace) toLogin() *tailcfg.Login {
Domain: "headscale.net",
}
- return &l
+ return &login
}
-func getMapResponseUserProfiles(m Machine, peers Machines) []tailcfg.UserProfile {
+func getMapResponseUserProfiles(machine Machine, peers Machines) []tailcfg.UserProfile {
namespaceMap := make(map[string]Namespace)
- namespaceMap[m.Namespace.Name] = m.Namespace
- for _, p := range peers {
- namespaceMap[p.Namespace.Name] = p.Namespace // not worth checking if already is there
+ namespaceMap[machine.Namespace.Name] = machine.Namespace
+ for _, peer := range peers {
+ namespaceMap[peer.Namespace.Name] = peer.Namespace // not worth checking if already is there
}
profiles := []tailcfg.UserProfile{}
diff --git a/oidc.go b/oidc.go
index c77a249..e68e112 100644
--- a/oidc.go
+++ b/oidc.go
@@ -68,10 +68,10 @@ func (h *Headscale) initOIDC() error {
// RegisterOIDC redirects to the OIDC provider for authentication
// Puts machine key in cache so the callback can retrieve it using the oidc state param
// Listens in /oidc/register/:mKey.
-func (h *Headscale) RegisterOIDC(c *gin.Context) {
- mKeyStr := c.Param("mkey")
+func (h *Headscale) RegisterOIDC(ctx *gin.Context) {
+ mKeyStr := ctx.Param("mkey")
if mKeyStr == "" {
- c.String(http.StatusBadRequest, "Wrong params")
+ ctx.String(http.StatusBadRequest, "Wrong params")
return
}
@@ -79,7 +79,7 @@ func (h *Headscale) RegisterOIDC(c *gin.Context) {
b := make([]byte, RANDOM_BYTE_SIZE)
if _, err := rand.Read(b); err != nil {
log.Error().Msg("could not read 16 bytes from rand")
- c.String(http.StatusInternalServerError, "could not read 16 bytes from rand")
+ ctx.String(http.StatusInternalServerError, "could not read 16 bytes from rand")
return
}
@@ -92,7 +92,7 @@ func (h *Headscale) RegisterOIDC(c *gin.Context) {
authUrl := h.oauth2Config.AuthCodeURL(stateStr)
log.Debug().Msgf("Redirecting to %s for authentication", authUrl)
- c.Redirect(http.StatusFound, authUrl)
+ ctx.Redirect(http.StatusFound, authUrl)
}
// OIDCCallback handles the callback from the OIDC endpoint
@@ -100,19 +100,19 @@ func (h *Headscale) RegisterOIDC(c *gin.Context) {
// TODO: A confirmation page for new machines should be added to avoid phishing vulnerabilities
// TODO: Add groups information from OIDC tokens into machine HostInfo
// Listens in /oidc/callback.
-func (h *Headscale) OIDCCallback(c *gin.Context) {
- code := c.Query("code")
- state := c.Query("state")
+func (h *Headscale) OIDCCallback(ctx *gin.Context) {
+ code := ctx.Query("code")
+ state := ctx.Query("state")
if code == "" || state == "" {
- c.String(http.StatusBadRequest, "Wrong params")
+ ctx.String(http.StatusBadRequest, "Wrong params")
return
}
oauth2Token, err := h.oauth2Config.Exchange(context.Background(), code)
if err != nil {
- c.String(http.StatusBadRequest, "Could not exchange code for token")
+ ctx.String(http.StatusBadRequest, "Could not exchange code for token")
return
}
@@ -121,7 +121,7 @@ func (h *Headscale) OIDCCallback(c *gin.Context) {
rawIDToken, rawIDTokenOK := oauth2Token.Extra("id_token").(string)
if !rawIDTokenOK {
- c.String(http.StatusBadRequest, "Could not extract ID Token")
+ ctx.String(http.StatusBadRequest, "Could not extract ID Token")
return
}
@@ -130,7 +130,7 @@ func (h *Headscale) OIDCCallback(c *gin.Context) {
idToken, err := verifier.Verify(context.Background(), rawIDToken)
if err != nil {
- c.String(http.StatusBadRequest, "Failed to verify id token: %s", err.Error())
+ ctx.String(http.StatusBadRequest, "Failed to verify id token: %s", err.Error())
return
}
@@ -145,7 +145,7 @@ func (h *Headscale) OIDCCallback(c *gin.Context) {
// Extract custom claims
var claims IDTokenClaims
if err = idToken.Claims(&claims); err != nil {
- c.String(
+ ctx.String(
http.StatusBadRequest,
fmt.Sprintf("Failed to decode id token claims: %s", err),
)
@@ -159,7 +159,7 @@ func (h *Headscale) OIDCCallback(c *gin.Context) {
if !mKeyFound {
log.Error().
Msg("requested machine state key expired before authorisation completed")
- c.String(http.StatusBadRequest, "state has expired")
+ ctx.String(http.StatusBadRequest, "state has expired")
return
}
@@ -167,16 +167,19 @@ func (h *Headscale) OIDCCallback(c *gin.Context) {
if !mKeyOK {
log.Error().Msg("could not get machine key from cache")
- c.String(http.StatusInternalServerError, "could not get machine key from cache")
+ ctx.String(
+ http.StatusInternalServerError,
+ "could not get machine key from cache",
+ )
return
}
// retrieve machine information
- m, err := h.GetMachineByMachineKey(mKeyStr)
+ machine, err := h.GetMachineByMachineKey(mKeyStr)
if err != nil {
log.Error().Msg("machine key not found in database")
- c.String(
+ ctx.String(
http.StatusInternalServerError,
"could not get machine info from database",
)
@@ -186,19 +189,19 @@ func (h *Headscale) OIDCCallback(c *gin.Context) {
now := time.Now().UTC()
- if nsName, ok := h.getNamespaceFromEmail(claims.Email); ok {
+ if namespaceName, ok := h.getNamespaceFromEmail(claims.Email); ok {
// register the machine if it's new
- if !m.Registered {
+ if !machine.Registered {
log.Debug().Msg("Registering new machine after successful callback")
- ns, err := h.GetNamespace(nsName)
+ namespace, err := h.GetNamespace(namespaceName)
if err != nil {
- ns, err = h.CreateNamespace(nsName)
+ namespace, err = h.CreateNamespace(namespaceName)
if err != nil {
log.Error().
Msgf("could not create new namespace '%s'", claims.Email)
- c.String(
+ ctx.String(
http.StatusInternalServerError,
"could not create new namespace",
)
@@ -209,7 +212,7 @@ func (h *Headscale) OIDCCallback(c *gin.Context) {
ip, err := h.getAvailableIP()
if err != nil {
- c.String(
+ ctx.String(
http.StatusInternalServerError,
"could not get an IP from the pool",
)
@@ -217,17 +220,17 @@ func (h *Headscale) OIDCCallback(c *gin.Context) {
return
}
- m.IPAddress = ip.String()
- m.NamespaceID = ns.ID
- m.Registered = true
- m.RegisterMethod = "oidc"
- m.LastSuccessfulUpdate = &now
- h.db.Save(&m)
+ machine.IPAddress = ip.String()
+ machine.NamespaceID = namespace.ID
+ machine.Registered = true
+ machine.RegisterMethod = "oidc"
+ machine.LastSuccessfulUpdate = &now
+ h.db.Save(&machine)
}
- h.updateMachineExpiry(m)
+ h.updateMachineExpiry(machine)
- c.Data(http.StatusOK, "text/html; charset=utf-8", []byte(fmt.Sprintf(`
+ ctx.Data(http.StatusOK, "text/html; charset=utf-8", []byte(fmt.Sprintf(`
headscale
@@ -243,9 +246,9 @@ func (h *Headscale) OIDCCallback(c *gin.Context) {
log.Error().
Str("email", claims.Email).
Str("username", claims.Username).
- Str("machine", m.Name).
+ Str("machine", machine.Name).
Msg("Email could not be mapped to a namespace")
- c.String(
+ ctx.String(
http.StatusBadRequest,
"email from claim could not be mapped to a namespace",
)
diff --git a/poll.go b/poll.go
index 3f7b293..1927853 100644
--- a/poll.go
+++ b/poll.go
@@ -233,7 +233,7 @@ func (h *Headscale) PollNetMapStream(
) {
go h.scheduledPollWorker(cancelKeepAlive, updateChan, keepAliveChan, mKey, req, m)
- c.Stream(func(w io.Writer) bool {
+ c.Stream(func(writer io.Writer) bool {
log.Trace().
Str("handler", "PollNetMapStream").
Str("machine", m.Name).
@@ -252,7 +252,7 @@ func (h *Headscale) PollNetMapStream(
Str("channel", "pollData").
Int("bytes", len(data)).
Msg("Sending data received via pollData channel")
- _, err := w.Write(data)
+ _, err := writer.Write(data)
if err != nil {
log.Error().
Str("handler", "PollNetMapStream").
@@ -305,7 +305,7 @@ func (h *Headscale) PollNetMapStream(
Str("channel", "keepAlive").
Int("bytes", len(data)).
Msg("Sending keep alive message")
- _, err := w.Write(data)
+ _, err := writer.Write(data)
if err != nil {
log.Error().
Str("handler", "PollNetMapStream").
@@ -370,7 +370,7 @@ func (h *Headscale) PollNetMapStream(
Err(err).
Msg("Could not get the map update")
}
- _, err = w.Write(data)
+ _, err = writer.Write(data)
if err != nil {
log.Error().
Str("handler", "PollNetMapStream").
diff --git a/sharing.go b/sharing.go
index 8f65414..741deb6 100644
--- a/sharing.go
+++ b/sharing.go
@@ -18,13 +18,16 @@ type SharedMachine struct {
}
// AddSharedMachineToNamespace adds a machine as a shared node to a namespace.
-func (h *Headscale) AddSharedMachineToNamespace(m *Machine, ns *Namespace) error {
- if m.NamespaceID == ns.ID {
+func (h *Headscale) AddSharedMachineToNamespace(
+ machine *Machine,
+ namespace *Namespace,
+) error {
+ if machine.NamespaceID == namespace.ID {
return errorSameNamespace
}
sharedMachines := []SharedMachine{}
- if err := h.db.Where("machine_id = ? AND namespace_id = ?", m.ID, ns.ID).Find(&sharedMachines).Error; err != nil {
+ if err := h.db.Where("machine_id = ? AND namespace_id = ?", machine.ID, namespace.ID).Find(&sharedMachines).Error; err != nil {
return err
}
if len(sharedMachines) > 0 {
@@ -32,10 +35,10 @@ func (h *Headscale) AddSharedMachineToNamespace(m *Machine, ns *Namespace) error
}
sharedMachine := SharedMachine{
- MachineID: m.ID,
- Machine: *m,
- NamespaceID: ns.ID,
- Namespace: *ns,
+ MachineID: machine.ID,
+ Machine: *machine,
+ NamespaceID: namespace.ID,
+ Namespace: *namespace,
}
h.db.Save(&sharedMachine)
@@ -43,14 +46,17 @@ func (h *Headscale) AddSharedMachineToNamespace(m *Machine, ns *Namespace) error
}
// RemoveSharedMachineFromNamespace removes a shared machine from a namespace.
-func (h *Headscale) RemoveSharedMachineFromNamespace(m *Machine, ns *Namespace) error {
- if m.NamespaceID == ns.ID {
+func (h *Headscale) RemoveSharedMachineFromNamespace(
+ machine *Machine,
+ namespace *Namespace,
+) error {
+ if machine.NamespaceID == namespace.ID {
// Can't unshare from primary namespace
return errorMachineNotShared
}
sharedMachine := SharedMachine{}
- result := h.db.Where("machine_id = ? AND namespace_id = ?", m.ID, ns.ID).
+ result := h.db.Where("machine_id = ? AND namespace_id = ?", machine.ID, namespace.ID).
Unscoped().
Delete(&sharedMachine)
if result.Error != nil {
@@ -61,7 +67,7 @@ func (h *Headscale) RemoveSharedMachineFromNamespace(m *Machine, ns *Namespace)
return errorMachineNotShared
}
- err := h.RequestMapUpdates(ns.ID)
+ err := h.RequestMapUpdates(namespace.ID)
if err != nil {
return err
}
@@ -70,9 +76,9 @@ func (h *Headscale) RemoveSharedMachineFromNamespace(m *Machine, ns *Namespace)
}
// RemoveSharedMachineFromAllNamespaces removes a machine as a shared node from all namespaces.
-func (h *Headscale) RemoveSharedMachineFromAllNamespaces(m *Machine) error {
+func (h *Headscale) RemoveSharedMachineFromAllNamespaces(machine *Machine) error {
sharedMachine := SharedMachine{}
- if result := h.db.Where("machine_id = ?", m.ID).Unscoped().Delete(&sharedMachine); result.Error != nil {
+ if result := h.db.Where("machine_id = ?", machine.ID).Unscoped().Delete(&sharedMachine); result.Error != nil {
return result.Error
}
diff --git a/swagger.go b/swagger.go
index 01b2eb5..9e62d39 100644
--- a/swagger.go
+++ b/swagger.go
@@ -13,8 +13,8 @@ import (
//go:embed gen/openapiv2/headscale/v1/headscale.swagger.json
var apiV1JSON []byte
-func SwaggerUI(c *gin.Context) {
- t := template.Must(template.New("swagger").Parse(`
+func SwaggerUI(ctx *gin.Context) {
+ swaggerTemplate := template.Must(template.New("swagger").Parse(`
@@ -47,12 +47,12 @@ func SwaggerUI(c *gin.Context) {
`))
var payload bytes.Buffer
- if err := t.Execute(&payload, struct{}{}); err != nil {
+ if err := swaggerTemplate.Execute(&payload, struct{}{}); err != nil {
log.Error().
Caller().
Err(err).
Msg("Could not render Swagger")
- c.Data(
+ ctx.Data(
http.StatusInternalServerError,
"text/html; charset=utf-8",
[]byte("Could not render Swagger"),
@@ -61,9 +61,9 @@ func SwaggerUI(c *gin.Context) {
return
}
- c.Data(http.StatusOK, "text/html; charset=utf-8", payload.Bytes())
+ ctx.Data(http.StatusOK, "text/html; charset=utf-8", payload.Bytes())
}
-func SwaggerAPIv1(c *gin.Context) {
- c.Data(http.StatusOK, "application/json; charset=utf-8", apiV1JSON)
+func SwaggerAPIv1(ctx *gin.Context) {
+ ctx.Data(http.StatusOK, "application/json; charset=utf-8", apiV1JSON)
}
diff --git a/utils.go b/utils.go
index 85e3ba2..803cfc5 100644
--- a/utils.go
+++ b/utils.go
@@ -36,7 +36,7 @@ func decode(
func decodeMsg(
msg []byte,
- v interface{},
+ output interface{},
pubKey *wgkey.Key,
privKey *wgkey.Private,
) error {
@@ -45,7 +45,7 @@ func decodeMsg(
return err
}
// fmt.Println(string(decrypted))
- if err := json.Unmarshal(decrypted, v); err != nil {
+ if err := json.Unmarshal(decrypted, output); err != nil {
return fmt.Errorf("response: %v", err)
}
@@ -78,13 +78,17 @@ func encode(v interface{}, pubKey *wgkey.Key, privKey *wgkey.Private) ([]byte, e
return encodeMsg(b, pubKey, privKey)
}
-func encodeMsg(b []byte, pubKey *wgkey.Key, privKey *wgkey.Private) ([]byte, error) {
+func encodeMsg(
+ payload []byte,
+ pubKey *wgkey.Key,
+ privKey *wgkey.Private,
+) ([]byte, error) {
var nonce [24]byte
if _, err := io.ReadFull(rand.Reader, nonce[:]); err != nil {
panic(err)
}
pub, pri := (*[32]byte)(pubKey), (*[32]byte)(privKey)
- msg := box.Seal(nonce[:], b, &nonce, pub, pri)
+ msg := box.Seal(nonce[:], payload, &nonce, pub, pri)
return msg, nil
}