diff --git a/.golangci.yaml b/.golangci.yaml index b4ad089..f7fc0c6 100644 --- a/.golangci.yaml +++ b/.golangci.yaml @@ -28,6 +28,9 @@ linters: # In progress - gocritic + # TODO: approve: ok, db, id + - varnamelen + # We should strive to enable these: - testpackage - stylecheck @@ -39,7 +42,6 @@ linters: - gosec - forbidigo - dupl - - varnamelen - makezero - paralleltest diff --git a/acls.go b/acls.go index fdcb098..01de114 100644 --- a/acls.go +++ b/acls.go @@ -41,18 +41,18 @@ func (h *Headscale) LoadACLPolicy(path string) error { defer policyFile.Close() var policy ACLPolicy - b, err := io.ReadAll(policyFile) + policyBytes, err := io.ReadAll(policyFile) if err != nil { return err } - ast, err := hujson.Parse(b) + ast, err := hujson.Parse(policyBytes) if err != nil { return err } ast.Standardize() - b = ast.Pack() - err = json.Unmarshal(b, &policy) + policyBytes = ast.Pack() + err = json.Unmarshal(policyBytes, &policy) if err != nil { return err } @@ -73,32 +73,32 @@ func (h *Headscale) LoadACLPolicy(path string) error { func (h *Headscale) generateACLRules() ([]tailcfg.FilterRule, error) { rules := []tailcfg.FilterRule{} - for i, a := range h.aclPolicy.ACLs { - if a.Action != "accept" { + for index, acl := range h.aclPolicy.ACLs { + if acl.Action != "accept" { return nil, errorInvalidAction } - r := tailcfg.FilterRule{} + filterRule := tailcfg.FilterRule{} srcIPs := []string{} - for j, u := range a.Users { - srcs, err := h.generateACLPolicySrcIP(u) + for innerIndex, user := range acl.Users { + srcs, err := h.generateACLPolicySrcIP(user) if err != nil { log.Error(). - Msgf("Error parsing ACL %d, User %d", i, j) + Msgf("Error parsing ACL %d, User %d", index, innerIndex) return nil, err } srcIPs = append(srcIPs, srcs...) } - r.SrcIPs = srcIPs + filterRule.SrcIPs = srcIPs destPorts := []tailcfg.NetPortRange{} - for j, d := range a.Ports { - dests, err := h.generateACLPolicyDestPorts(d) + for innerIndex, ports := range acl.Ports { + dests, err := h.generateACLPolicyDestPorts(ports) if err != nil { log.Error(). - Msgf("Error parsing ACL %d, Port %d", i, j) + Msgf("Error parsing ACL %d, Port %d", index, innerIndex) return nil, err } @@ -162,17 +162,17 @@ func (h *Headscale) generateACLPolicyDestPorts( return dests, nil } -func (h *Headscale) expandAlias(s string) ([]string, error) { - if s == "*" { +func (h *Headscale) expandAlias(alias string) ([]string, error) { + if alias == "*" { return []string{"*"}, nil } - if strings.HasPrefix(s, "group:") { - if _, ok := h.aclPolicy.Groups[s]; !ok { + if strings.HasPrefix(alias, "group:") { + if _, ok := h.aclPolicy.Groups[alias]; !ok { return nil, errorInvalidGroup } ips := []string{} - for _, n := range h.aclPolicy.Groups[s] { + for _, n := range h.aclPolicy.Groups[alias] { nodes, err := h.ListMachinesInNamespace(n) if err != nil { return nil, errorInvalidNamespace @@ -185,8 +185,8 @@ func (h *Headscale) expandAlias(s string) ([]string, error) { return ips, nil } - if strings.HasPrefix(s, "tag:") { - if _, ok := h.aclPolicy.TagOwners[s]; !ok { + if strings.HasPrefix(alias, "tag:") { + if _, ok := h.aclPolicy.TagOwners[alias]; !ok { return nil, errorInvalidTag } @@ -197,10 +197,10 @@ func (h *Headscale) expandAlias(s string) ([]string, error) { return nil, err } ips := []string{} - for _, m := range machines { + for _, machine := range machines { hostinfo := tailcfg.Hostinfo{} - if len(m.HostInfo) != 0 { - hi, err := m.HostInfo.MarshalJSON() + if len(machine.HostInfo) != 0 { + hi, err := machine.HostInfo.MarshalJSON() if err != nil { return nil, err } @@ -211,8 +211,8 @@ func (h *Headscale) expandAlias(s string) ([]string, error) { // FIXME: Check TagOwners allows this for _, t := range hostinfo.RequestTags { - if s[4:] == t { - ips = append(ips, m.IPAddress) + if alias[4:] == t { + ips = append(ips, machine.IPAddress) break } @@ -223,7 +223,7 @@ func (h *Headscale) expandAlias(s string) ([]string, error) { return ips, nil } - n, err := h.GetNamespace(s) + n, err := h.GetNamespace(alias) if err == nil { nodes, err := h.ListMachinesInNamespace(n.Name) if err != nil { @@ -237,16 +237,16 @@ func (h *Headscale) expandAlias(s string) ([]string, error) { return ips, nil } - if h, ok := h.aclPolicy.Hosts[s]; ok { + if h, ok := h.aclPolicy.Hosts[alias]; ok { return []string{h.String()}, nil } - ip, err := netaddr.ParseIP(s) + ip, err := netaddr.ParseIP(alias) if err == nil { return []string{ip.String()}, nil } - cidr, err := netaddr.ParseIPPrefix(s) + cidr, err := netaddr.ParseIPPrefix(alias) if err == nil { return []string{cidr.String()}, nil } @@ -254,25 +254,25 @@ func (h *Headscale) expandAlias(s string) ([]string, error) { return nil, errorInvalidUserSection } -func (h *Headscale) expandPorts(s string) (*[]tailcfg.PortRange, error) { - if s == "*" { +func (h *Headscale) expandPorts(portsStr string) (*[]tailcfg.PortRange, error) { + if portsStr == "*" { return &[]tailcfg.PortRange{ {First: PORT_RANGE_BEGIN, Last: PORT_RANGE_END}, }, nil } ports := []tailcfg.PortRange{} - for _, p := range strings.Split(s, ",") { - rang := strings.Split(p, "-") + for _, portStr := range strings.Split(portsStr, ",") { + rang := strings.Split(portStr, "-") switch len(rang) { case 1: - pi, err := strconv.ParseUint(rang[0], BASE_10, BIT_SIZE_16) + port, err := strconv.ParseUint(rang[0], BASE_10, BIT_SIZE_16) if err != nil { return nil, err } ports = append(ports, tailcfg.PortRange{ - First: uint16(pi), - Last: uint16(pi), + First: uint16(port), + Last: uint16(port), }) case EXPECTED_TOKEN_ITEMS: diff --git a/acls_types.go b/acls_types.go index 8611d90..f6557bf 100644 --- a/acls_types.go +++ b/acls_types.go @@ -41,37 +41,37 @@ type ACLTest struct { } // UnmarshalJSON allows to parse the Hosts directly into netaddr objects. -func (h *Hosts) UnmarshalJSON(data []byte) error { - hosts := Hosts{} - hs := make(map[string]string) +func (hosts *Hosts) UnmarshalJSON(data []byte) error { + newHosts := Hosts{} + hostIpPrefixMap := make(map[string]string) ast, err := hujson.Parse(data) if err != nil { return err } ast.Standardize() data = ast.Pack() - err = json.Unmarshal(data, &hs) + err = json.Unmarshal(data, &hostIpPrefixMap) if err != nil { return err } - for k, v := range hs { - if !strings.Contains(v, "/") { - v += "/32" + for host, prefixStr := range hostIpPrefixMap { + if !strings.Contains(prefixStr, "/") { + prefixStr += "/32" } - prefix, err := netaddr.ParseIPPrefix(v) + prefix, err := netaddr.ParseIPPrefix(prefixStr) if err != nil { return err } - hosts[k] = prefix + newHosts[host] = prefix } - *h = hosts + *hosts = newHosts return nil } // IsZero is perhaps a bit naive here. -func (p ACLPolicy) IsZero() bool { - if len(p.Groups) == 0 && len(p.Hosts) == 0 && len(p.ACLs) == 0 { +func (policy ACLPolicy) IsZero() bool { + if len(policy.Groups) == 0 && len(policy.Hosts) == 0 && len(policy.ACLs) == 0 { return true } diff --git a/api.go b/api.go index 211b486..fb8c49c 100644 --- a/api.go +++ b/api.go @@ -22,21 +22,25 @@ const RESERVED_RESPONSE_HEADER_SIZE = 4 // KeyHandler provides the Headscale pub key // Listens in /key. -func (h *Headscale) KeyHandler(c *gin.Context) { - c.Data(http.StatusOK, "text/plain; charset=utf-8", []byte(h.publicKey.HexString())) +func (h *Headscale) KeyHandler(ctx *gin.Context) { + ctx.Data( + http.StatusOK, + "text/plain; charset=utf-8", + []byte(h.publicKey.HexString()), + ) } // RegisterWebAPI shows a simple message in the browser to point to the CLI // Listens in /register. -func (h *Headscale) RegisterWebAPI(c *gin.Context) { - mKeyStr := c.Query("key") - if mKeyStr == "" { - c.String(http.StatusBadRequest, "Wrong params") +func (h *Headscale) RegisterWebAPI(ctx *gin.Context) { + machineKeyStr := ctx.Query("key") + if machineKeyStr == "" { + ctx.String(http.StatusBadRequest, "Wrong params") return } - c.Data(http.StatusOK, "text/html; charset=utf-8", []byte(fmt.Sprintf(` + ctx.Data(http.StatusOK, "text/html; charset=utf-8", []byte(fmt.Sprintf(`

headscale

@@ -53,45 +57,45 @@ func (h *Headscale) RegisterWebAPI(c *gin.Context) { - `, mKeyStr))) + `, machineKeyStr))) } // RegistrationHandler handles the actual registration process of a machine // Endpoint /machine/:id. -func (h *Headscale) RegistrationHandler(c *gin.Context) { - body, _ := io.ReadAll(c.Request.Body) - mKeyStr := c.Param("id") - mKey, err := wgkey.ParseHex(mKeyStr) +func (h *Headscale) RegistrationHandler(ctx *gin.Context) { + body, _ := io.ReadAll(ctx.Request.Body) + machineKeyStr := ctx.Param("id") + machineKey, err := wgkey.ParseHex(machineKeyStr) if err != nil { log.Error(). Str("handler", "Registration"). Err(err). Msg("Cannot parse machine key") machineRegistrations.WithLabelValues("unknown", "web", "error", "unknown").Inc() - c.String(http.StatusInternalServerError, "Sad!") + ctx.String(http.StatusInternalServerError, "Sad!") return } req := tailcfg.RegisterRequest{} - err = decode(body, &req, &mKey, h.privateKey) + err = decode(body, &req, &machineKey, h.privateKey) if err != nil { log.Error(). Str("handler", "Registration"). Err(err). Msg("Cannot decode message") machineRegistrations.WithLabelValues("unknown", "web", "error", "unknown").Inc() - c.String(http.StatusInternalServerError, "Very sad!") + ctx.String(http.StatusInternalServerError, "Very sad!") return } now := time.Now().UTC() - m, err := h.GetMachineByMachineKey(mKey.HexString()) + machine, err := h.GetMachineByMachineKey(machineKey.HexString()) if errors.Is(err, gorm.ErrRecordNotFound) { log.Info().Str("machine", req.Hostinfo.Hostname).Msg("New machine") newMachine := Machine{ Expiry: &time.Time{}, - MachineKey: mKey.HexString(), + MachineKey: machineKey.HexString(), Name: req.Hostinfo.Hostname, } if err := h.db.Create(&newMachine).Error; err != nil { @@ -99,16 +103,16 @@ func (h *Headscale) RegistrationHandler(c *gin.Context) { Str("handler", "Registration"). Err(err). Msg("Could not create row") - machineRegistrations.WithLabelValues("unknown", "web", "error", m.Namespace.Name). + machineRegistrations.WithLabelValues("unknown", "web", "error", machine.Namespace.Name). Inc() return } - m = &newMachine + machine = &newMachine } - if !m.Registered && req.Auth.AuthKey != "" { - h.handleAuthKey(c, h.db, mKey, req, *m) + if !machine.Registered && req.Auth.AuthKey != "" { + h.handleAuthKey(ctx, h.db, machineKey, req, *machine) return } @@ -116,63 +120,63 @@ func (h *Headscale) RegistrationHandler(c *gin.Context) { resp := tailcfg.RegisterResponse{} // We have the updated key! - if m.NodeKey == wgkey.Key(req.NodeKey).HexString() { + if machine.NodeKey == wgkey.Key(req.NodeKey).HexString() { // The client sends an Expiry in the past if the client is requesting to expire the key (aka logout) // https://github.com/tailscale/tailscale/blob/main/tailcfg/tailcfg.go#L648 if !req.Expiry.IsZero() && req.Expiry.UTC().Before(now) { log.Info(). Str("handler", "Registration"). - Str("machine", m.Name). + Str("machine", machine.Name). Msg("Client requested logout") - m.Expiry = &req.Expiry // save the expiry so that the machine is marked as expired - h.db.Save(&m) + machine.Expiry = &req.Expiry // save the expiry so that the machine is marked as expired + h.db.Save(&machine) resp.AuthURL = "" resp.MachineAuthorized = false - resp.User = *m.Namespace.toUser() - respBody, err := encode(resp, &mKey, h.privateKey) + resp.User = *machine.Namespace.toUser() + respBody, err := encode(resp, &machineKey, h.privateKey) if err != nil { log.Error(). Str("handler", "Registration"). Err(err). Msg("Cannot encode message") - c.String(http.StatusInternalServerError, "") + ctx.String(http.StatusInternalServerError, "") return } - c.Data(http.StatusOK, "application/json; charset=utf-8", respBody) + ctx.Data(http.StatusOK, "application/json; charset=utf-8", respBody) return } - if m.Registered && m.Expiry.UTC().After(now) { + if machine.Registered && machine.Expiry.UTC().After(now) { // The machine registration is valid, respond with redirect to /map log.Debug(). Str("handler", "Registration"). - Str("machine", m.Name). + Str("machine", machine.Name). Msg("Client is registered and we have the current NodeKey. All clear to /map") resp.AuthURL = "" resp.MachineAuthorized = true - resp.User = *m.Namespace.toUser() - resp.Login = *m.Namespace.toLogin() + resp.User = *machine.Namespace.toUser() + resp.Login = *machine.Namespace.toLogin() - respBody, err := encode(resp, &mKey, h.privateKey) + respBody, err := encode(resp, &machineKey, h.privateKey) if err != nil { log.Error(). Str("handler", "Registration"). Err(err). Msg("Cannot encode message") - machineRegistrations.WithLabelValues("update", "web", "error", m.Namespace.Name). + machineRegistrations.WithLabelValues("update", "web", "error", machine.Namespace.Name). Inc() - c.String(http.StatusInternalServerError, "") + ctx.String(http.StatusInternalServerError, "") return } - machineRegistrations.WithLabelValues("update", "web", "success", m.Namespace.Name). + machineRegistrations.WithLabelValues("update", "web", "success", machine.Namespace.Name). Inc() - c.Data(http.StatusOK, "application/json; charset=utf-8", respBody) + ctx.Data(http.StatusOK, "application/json; charset=utf-8", respBody) return } @@ -180,15 +184,15 @@ func (h *Headscale) RegistrationHandler(c *gin.Context) { // The client has registered before, but has expired log.Debug(). Str("handler", "Registration"). - Str("machine", m.Name). + Str("machine", machine.Name). Msg("Machine registration has expired. Sending a authurl to register") if h.cfg.OIDC.Issuer != "" { resp.AuthURL = fmt.Sprintf("%s/oidc/register/%s", - strings.TrimSuffix(h.cfg.ServerURL, "/"), mKey.HexString()) + strings.TrimSuffix(h.cfg.ServerURL, "/"), machineKey.HexString()) } else { resp.AuthURL = fmt.Sprintf("%s/register?key=%s", - strings.TrimSuffix(h.cfg.ServerURL, "/"), mKey.HexString()) + strings.TrimSuffix(h.cfg.ServerURL, "/"), machineKey.HexString()) } // When a client connects, it may request a specific expiry time in its @@ -197,51 +201,52 @@ func (h *Headscale) RegistrationHandler(c *gin.Context) { // into two steps (which cant pass arbitrary data between them easily) and needs to be // retrieved again after the user has authenticated. After the authentication flow // completes, RequestedExpiry is copied into Expiry. - m.RequestedExpiry = &req.Expiry + machine.RequestedExpiry = &req.Expiry - h.db.Save(&m) + h.db.Save(&machine) - respBody, err := encode(resp, &mKey, h.privateKey) + respBody, err := encode(resp, &machineKey, h.privateKey) if err != nil { log.Error(). Str("handler", "Registration"). Err(err). Msg("Cannot encode message") - machineRegistrations.WithLabelValues("new", "web", "error", m.Namespace.Name). + machineRegistrations.WithLabelValues("new", "web", "error", machine.Namespace.Name). Inc() - c.String(http.StatusInternalServerError, "") + ctx.String(http.StatusInternalServerError, "") return } - machineRegistrations.WithLabelValues("new", "web", "success", m.Namespace.Name). + machineRegistrations.WithLabelValues("new", "web", "success", machine.Namespace.Name). Inc() - c.Data(http.StatusOK, "application/json; charset=utf-8", respBody) + ctx.Data(http.StatusOK, "application/json; charset=utf-8", respBody) return } // The NodeKey we have matches OldNodeKey, which means this is a refresh after a key expiration - if m.NodeKey == wgkey.Key(req.OldNodeKey).HexString() && m.Expiry.UTC().After(now) { + if machine.NodeKey == wgkey.Key(req.OldNodeKey).HexString() && + machine.Expiry.UTC().After(now) { log.Debug(). Str("handler", "Registration"). - Str("machine", m.Name). + Str("machine", machine.Name). Msg("We have the OldNodeKey in the database. This is a key refresh") - m.NodeKey = wgkey.Key(req.NodeKey).HexString() - h.db.Save(&m) + machine.NodeKey = wgkey.Key(req.NodeKey).HexString() + h.db.Save(&machine) resp.AuthURL = "" - resp.User = *m.Namespace.toUser() - respBody, err := encode(resp, &mKey, h.privateKey) + resp.User = *machine.Namespace.toUser() + respBody, err := encode(resp, &machineKey, h.privateKey) if err != nil { log.Error(). Str("handler", "Registration"). Err(err). Msg("Cannot encode message") - c.String(http.StatusInternalServerError, "Extremely sad!") + ctx.String(http.StatusInternalServerError, "Extremely sad!") return } - c.Data(http.StatusOK, "application/json; charset=utf-8", respBody) + ctx.Data(http.StatusOK, "application/json; charset=utf-8", respBody) return } @@ -249,47 +254,47 @@ func (h *Headscale) RegistrationHandler(c *gin.Context) { // The machine registration is new, redirect the client to the registration URL log.Debug(). Str("handler", "Registration"). - Str("machine", m.Name). + Str("machine", machine.Name). Msg("The node is sending us a new NodeKey, sending auth url") if h.cfg.OIDC.Issuer != "" { resp.AuthURL = fmt.Sprintf( "%s/oidc/register/%s", strings.TrimSuffix(h.cfg.ServerURL, "/"), - mKey.HexString(), + machineKey.HexString(), ) } else { resp.AuthURL = fmt.Sprintf("%s/register?key=%s", - strings.TrimSuffix(h.cfg.ServerURL, "/"), mKey.HexString()) + strings.TrimSuffix(h.cfg.ServerURL, "/"), machineKey.HexString()) } // save the requested expiry time for retrieval later in the authentication flow - m.RequestedExpiry = &req.Expiry - m.NodeKey = wgkey.Key(req.NodeKey).HexString() // save the NodeKey - h.db.Save(&m) + machine.RequestedExpiry = &req.Expiry + machine.NodeKey = wgkey.Key(req.NodeKey).HexString() // save the NodeKey + h.db.Save(&machine) - respBody, err := encode(resp, &mKey, h.privateKey) + respBody, err := encode(resp, &machineKey, h.privateKey) if err != nil { log.Error(). Str("handler", "Registration"). Err(err). Msg("Cannot encode message") - c.String(http.StatusInternalServerError, "") + ctx.String(http.StatusInternalServerError, "") return } - c.Data(http.StatusOK, "application/json; charset=utf-8", respBody) + ctx.Data(http.StatusOK, "application/json; charset=utf-8", respBody) } func (h *Headscale) getMapResponse( - mKey wgkey.Key, + machineKey wgkey.Key, req tailcfg.MapRequest, - m *Machine, + machine *Machine, ) ([]byte, error) { log.Trace(). Str("func", "getMapResponse"). Str("machine", req.Hostinfo.Hostname). Msg("Creating Map response") - node, err := m.toNode(h.cfg.BaseDomain, h.cfg.DNSConfig, true) + node, err := machine.toNode(h.cfg.BaseDomain, h.cfg.DNSConfig, true) if err != nil { log.Error(). Str("func", "getMapResponse"). @@ -299,7 +304,7 @@ func (h *Headscale) getMapResponse( return nil, err } - peers, err := h.getPeers(m) + peers, err := h.getPeers(machine) if err != nil { log.Error(). Str("func", "getMapResponse"). @@ -309,7 +314,7 @@ func (h *Headscale) getMapResponse( return nil, err } - profiles := getMapResponseUserProfiles(*m, peers) + profiles := getMapResponseUserProfiles(*machine, peers) nodePeers, err := peers.toNodes(h.cfg.BaseDomain, h.cfg.DNSConfig, true) if err != nil { @@ -324,7 +329,7 @@ func (h *Headscale) getMapResponse( dnsConfig := getMapResponseDNSConfig( h.cfg.DNSConfig, h.cfg.BaseDomain, - *m, + *machine, peers, ) @@ -351,12 +356,12 @@ func (h *Headscale) getMapResponse( encoder, _ := zstd.NewWriter(nil) srcCompressed := encoder.EncodeAll(src, nil) - respBody, err = encodeMsg(srcCompressed, &mKey, h.privateKey) + respBody, err = encodeMsg(srcCompressed, &machineKey, h.privateKey) if err != nil { return nil, err } } else { - respBody, err = encode(resp, &mKey, h.privateKey) + respBody, err = encode(resp, &machineKey, h.privateKey) if err != nil { return nil, err } @@ -370,24 +375,24 @@ func (h *Headscale) getMapResponse( } func (h *Headscale) getMapKeepAliveResponse( - mKey wgkey.Key, - req tailcfg.MapRequest, + machineKey wgkey.Key, + mapRequest tailcfg.MapRequest, ) ([]byte, error) { - resp := tailcfg.MapResponse{ + mapResponse := tailcfg.MapResponse{ KeepAlive: true, } var respBody []byte var err error - if req.Compress == "zstd" { - src, _ := json.Marshal(resp) + if mapRequest.Compress == "zstd" { + src, _ := json.Marshal(mapResponse) encoder, _ := zstd.NewWriter(nil) srcCompressed := encoder.EncodeAll(src, nil) - respBody, err = encodeMsg(srcCompressed, &mKey, h.privateKey) + respBody, err = encodeMsg(srcCompressed, &machineKey, h.privateKey) if err != nil { return nil, err } } else { - respBody, err = encode(resp, &mKey, h.privateKey) + respBody, err = encode(mapResponse, &machineKey, h.privateKey) if err != nil { return nil, err } @@ -400,22 +405,22 @@ func (h *Headscale) getMapKeepAliveResponse( } func (h *Headscale) handleAuthKey( - c *gin.Context, + ctx *gin.Context, db *gorm.DB, idKey wgkey.Key, - req tailcfg.RegisterRequest, - m Machine, + reqisterRequest tailcfg.RegisterRequest, + machine Machine, ) { log.Debug(). Str("func", "handleAuthKey"). - Str("machine", req.Hostinfo.Hostname). - Msgf("Processing auth key for %s", req.Hostinfo.Hostname) + Str("machine", reqisterRequest.Hostinfo.Hostname). + Msgf("Processing auth key for %s", reqisterRequest.Hostinfo.Hostname) resp := tailcfg.RegisterResponse{} - pak, err := h.checkKeyValidity(req.Auth.AuthKey) + pak, err := h.checkKeyValidity(reqisterRequest.Auth.AuthKey) if err != nil { log.Error(). Str("func", "handleAuthKey"). - Str("machine", m.Name). + Str("machine", machine.Name). Err(err). Msg("Failed authentication via AuthKey") resp.MachineAuthorized = false @@ -423,21 +428,21 @@ func (h *Headscale) handleAuthKey( if err != nil { log.Error(). Str("func", "handleAuthKey"). - Str("machine", m.Name). + Str("machine", machine.Name). Err(err). Msg("Cannot encode message") - c.String(http.StatusInternalServerError, "") - machineRegistrations.WithLabelValues("new", "authkey", "error", m.Namespace.Name). + ctx.String(http.StatusInternalServerError, "") + machineRegistrations.WithLabelValues("new", "authkey", "error", machine.Namespace.Name). Inc() return } - c.Data(http.StatusUnauthorized, "application/json; charset=utf-8", respBody) + ctx.Data(http.StatusUnauthorized, "application/json; charset=utf-8", respBody) log.Error(). Str("func", "handleAuthKey"). - Str("machine", m.Name). + Str("machine", machine.Name). Msg("Failed authentication via AuthKey") - machineRegistrations.WithLabelValues("new", "authkey", "error", m.Namespace.Name). + machineRegistrations.WithLabelValues("new", "authkey", "error", machine.Namespace.Name). Inc() return @@ -445,32 +450,34 @@ func (h *Headscale) handleAuthKey( log.Debug(). Str("func", "handleAuthKey"). - Str("machine", m.Name). + Str("machine", machine.Name). Msg("Authentication key was valid, proceeding to acquire an IP address") ip, err := h.getAvailableIP() if err != nil { log.Error(). Str("func", "handleAuthKey"). - Str("machine", m.Name). + Str("machine", machine.Name). Msg("Failed to find an available IP") - machineRegistrations.WithLabelValues("new", "authkey", "error", m.Namespace.Name). + machineRegistrations.WithLabelValues("new", "authkey", "error", machine.Namespace.Name). Inc() return } log.Info(). Str("func", "handleAuthKey"). - Str("machine", m.Name). + Str("machine", machine.Name). Str("ip", ip.String()). - Msgf("Assigning %s to %s", ip, m.Name) + Msgf("Assigning %s to %s", ip, machine.Name) - m.AuthKeyID = uint(pak.ID) - m.IPAddress = ip.String() - m.NamespaceID = pak.NamespaceID - m.NodeKey = wgkey.Key(req.NodeKey).HexString() // we update it just in case - m.Registered = true - m.RegisterMethod = "authKey" - db.Save(&m) + machine.AuthKeyID = uint(pak.ID) + machine.IPAddress = ip.String() + machine.NamespaceID = pak.NamespaceID + machine.NodeKey = wgkey.Key(reqisterRequest.NodeKey). + HexString() + // we update it just in case + machine.Registered = true + machine.RegisterMethod = "authKey" + db.Save(&machine) pak.Used = true db.Save(&pak) @@ -481,21 +488,21 @@ func (h *Headscale) handleAuthKey( if err != nil { log.Error(). Str("func", "handleAuthKey"). - Str("machine", m.Name). + Str("machine", machine.Name). Err(err). Msg("Cannot encode message") - machineRegistrations.WithLabelValues("new", "authkey", "error", m.Namespace.Name). + machineRegistrations.WithLabelValues("new", "authkey", "error", machine.Namespace.Name). Inc() - c.String(http.StatusInternalServerError, "Extremely sad!") + ctx.String(http.StatusInternalServerError, "Extremely sad!") return } - machineRegistrations.WithLabelValues("new", "authkey", "success", m.Namespace.Name). + machineRegistrations.WithLabelValues("new", "authkey", "success", machine.Namespace.Name). Inc() - c.Data(http.StatusOK, "application/json; charset=utf-8", respBody) + ctx.Data(http.StatusOK, "application/json; charset=utf-8", respBody) log.Info(). Str("func", "handleAuthKey"). - Str("machine", m.Name). + Str("machine", machine.Name). Str("ip", ip.String()). Msg("Successfully authenticated via AuthKey") } diff --git a/app.go b/app.go index af51efd..a0aa238 100644 --- a/app.go +++ b/app.go @@ -169,7 +169,7 @@ func NewHeadscale(cfg Config) (*Headscale, error) { return nil, errors.New("unsupported DB") } - h := Headscale{ + app := Headscale{ cfg: cfg, dbType: cfg.DBtype, dbString: dbString, @@ -178,32 +178,32 @@ func NewHeadscale(cfg Config) (*Headscale, error) { aclRules: tailcfg.FilterAllowAll, // default allowall } - err = h.initDB() + err = app.initDB() if err != nil { return nil, err } if cfg.OIDC.Issuer != "" { - err = h.initOIDC() + err = app.initOIDC() if err != nil { return nil, err } } - if h.cfg.DNSConfig != nil && h.cfg.DNSConfig.Proxied { // if MagicDNS + if app.cfg.DNSConfig != nil && app.cfg.DNSConfig.Proxied { // if MagicDNS magicDNSDomains := generateMagicDNSRootDomains( - h.cfg.IPPrefix, + app.cfg.IPPrefix, ) // we might have routes already from Split DNS - if h.cfg.DNSConfig.Routes == nil { - h.cfg.DNSConfig.Routes = make(map[string][]dnstype.Resolver) + if app.cfg.DNSConfig.Routes == nil { + app.cfg.DNSConfig.Routes = make(map[string][]dnstype.Resolver) } for _, d := range magicDNSDomains { - h.cfg.DNSConfig.Routes[d.WithoutTrailingDot()] = nil + app.cfg.DNSConfig.Routes[d.WithoutTrailingDot()] = nil } } - return &h, nil + return &app, nil } // Redirect to our TLS url. @@ -229,35 +229,37 @@ func (h *Headscale) expireEphemeralNodesWorker() { return } - for _, ns := range namespaces { - machines, err := h.ListMachinesInNamespace(ns.Name) + for _, namespace := range namespaces { + machines, err := h.ListMachinesInNamespace(namespace.Name) if err != nil { log.Error(). Err(err). - Str("namespace", ns.Name). + Str("namespace", namespace.Name). Msg("Error listing machines in namespace") return } - for _, m := range machines { - if m.AuthKey != nil && m.LastSeen != nil && m.AuthKey.Ephemeral && - time.Now().After(m.LastSeen.Add(h.cfg.EphemeralNodeInactivityTimeout)) { + for _, machine := range machines { + if machine.AuthKey != nil && machine.LastSeen != nil && + machine.AuthKey.Ephemeral && + time.Now(). + After(machine.LastSeen.Add(h.cfg.EphemeralNodeInactivityTimeout)) { log.Info(). - Str("machine", m.Name). + Str("machine", machine.Name). Msg("Ephemeral client removed from database") - err = h.db.Unscoped().Delete(m).Error + err = h.db.Unscoped().Delete(machine).Error if err != nil { log.Error(). Err(err). - Str("machine", m.Name). + Str("machine", machine.Name). Msg("🤮 Cannot delete ephemeral machine from the database") } } } - h.setLastStateChangeToNow(ns.Name) + h.setLastStateChangeToNow(namespace.Name) } } @@ -284,18 +286,18 @@ func (h *Headscale) grpcAuthenticationInterceptor(ctx context.Context, // with the "legacy" database-based client // It is also neede for grpc-gateway to be able to connect to // the server - p, _ := peer.FromContext(ctx) + client, _ := peer.FromContext(ctx) log.Trace(). Caller(). - Str("client_address", p.Addr.String()). + Str("client_address", client.Addr.String()). Msg("Client is trying to authenticate") - md, ok := metadata.FromIncomingContext(ctx) + meta, ok := metadata.FromIncomingContext(ctx) if !ok { log.Error(). Caller(). - Str("client_address", p.Addr.String()). + Str("client_address", client.Addr.String()). Msg("Retrieving metadata is failed") return ctx, status.Errorf( @@ -304,11 +306,11 @@ func (h *Headscale) grpcAuthenticationInterceptor(ctx context.Context, ) } - authHeader, ok := md["authorization"] + authHeader, ok := meta["authorization"] if !ok { log.Error(). Caller(). - Str("client_address", p.Addr.String()). + Str("client_address", client.Addr.String()). Msg("Authorization token is not supplied") return ctx, status.Errorf( @@ -322,7 +324,7 @@ func (h *Headscale) grpcAuthenticationInterceptor(ctx context.Context, if !strings.HasPrefix(token, AUTH_PREFIX) { log.Error(). Caller(). - Str("client_address", p.Addr.String()). + Str("client_address", client.Addr.String()). Msg(`missing "Bearer " prefix in "Authorization" header`) return ctx, status.Error( @@ -353,25 +355,25 @@ func (h *Headscale) grpcAuthenticationInterceptor(ctx context.Context, // return handler(ctx, req) } -func (h *Headscale) httpAuthenticationMiddleware(c *gin.Context) { +func (h *Headscale) httpAuthenticationMiddleware(ctx *gin.Context) { log.Trace(). Caller(). - Str("client_address", c.ClientIP()). + Str("client_address", ctx.ClientIP()). Msg("HTTP authentication invoked") - authHeader := c.GetHeader("authorization") + authHeader := ctx.GetHeader("authorization") if !strings.HasPrefix(authHeader, AUTH_PREFIX) { log.Error(). Caller(). - Str("client_address", c.ClientIP()). + Str("client_address", ctx.ClientIP()). Msg(`missing "Bearer " prefix in "Authorization" header`) - c.AbortWithStatus(http.StatusUnauthorized) + ctx.AbortWithStatus(http.StatusUnauthorized) return } - c.AbortWithStatus(http.StatusUnauthorized) + ctx.AbortWithStatus(http.StatusUnauthorized) // TODO(kradalby): Implement API key backend // Currently all traffic is unauthorized, this is intentional to allow @@ -438,9 +440,9 @@ func (h *Headscale) Serve() error { // Create the cmux object that will multiplex 2 protocols on the same port. // The two following listeners will be served on the same port below gracefully. - m := cmux.New(networkListener) + networkMutex := cmux.New(networkListener) // Match gRPC requests here - grpcListener := m.MatchWithWriters( + grpcListener := networkMutex.MatchWithWriters( cmux.HTTP2MatchHeaderFieldSendSettings("content-type", "application/grpc"), cmux.HTTP2MatchHeaderFieldSendSettings( "content-type", @@ -448,7 +450,7 @@ func (h *Headscale) Serve() error { ), ) // Otherwise match regular http requests. - httpListener := m.Match(cmux.Any()) + httpListener := networkMutex.Match(cmux.Any()) grpcGatewayMux := runtime.NewServeMux() @@ -471,33 +473,33 @@ func (h *Headscale) Serve() error { return err } - r := gin.Default() + router := gin.Default() - p := ginprometheus.NewPrometheus("gin") - p.Use(r) + prometheus := ginprometheus.NewPrometheus("gin") + prometheus.Use(router) - r.GET( + router.GET( "/health", func(c *gin.Context) { c.JSON(http.StatusOK, gin.H{"healthy": "ok"}) }, ) - r.GET("/key", h.KeyHandler) - r.GET("/register", h.RegisterWebAPI) - r.POST("/machine/:id/map", h.PollNetMapHandler) - r.POST("/machine/:id", h.RegistrationHandler) - r.GET("/oidc/register/:mkey", h.RegisterOIDC) - r.GET("/oidc/callback", h.OIDCCallback) - r.GET("/apple", h.AppleMobileConfig) - r.GET("/apple/:platform", h.ApplePlatformConfig) - r.GET("/swagger", SwaggerUI) - r.GET("/swagger/v1/openapiv2.json", SwaggerAPIv1) + router.GET("/key", h.KeyHandler) + router.GET("/register", h.RegisterWebAPI) + router.POST("/machine/:id/map", h.PollNetMapHandler) + router.POST("/machine/:id", h.RegistrationHandler) + router.GET("/oidc/register/:mkey", h.RegisterOIDC) + router.GET("/oidc/callback", h.OIDCCallback) + router.GET("/apple", h.AppleMobileConfig) + router.GET("/apple/:platform", h.ApplePlatformConfig) + router.GET("/swagger", SwaggerUI) + router.GET("/swagger/v1/openapiv2.json", SwaggerAPIv1) - api := r.Group("/api") + api := router.Group("/api") api.Use(h.httpAuthenticationMiddleware) { api.Any("/v1/*any", gin.WrapF(grpcGatewayMux.ServeHTTP)) } - r.NoRoute(stdoutHandler) + router.NoRoute(stdoutHandler) // Fetch an initial DERP Map before we start serving h.DERPMap = GetDERPMap(h.cfg.DERP) @@ -514,7 +516,7 @@ func (h *Headscale) Serve() error { httpServer := &http.Server{ Addr: h.cfg.Addr, - Handler: r, + Handler: router, ReadTimeout: HTTP_READ_TIMEOUT, // Go does not handle timeouts in HTTP very well, and there is // no good way to handle streaming timeouts, therefore we need to @@ -561,29 +563,29 @@ func (h *Headscale) Serve() error { reflection.Register(grpcServer) reflection.Register(grpcSocket) - g := new(errgroup.Group) + errorGroup := new(errgroup.Group) - g.Go(func() error { return grpcSocket.Serve(socketListener) }) + errorGroup.Go(func() error { return grpcSocket.Serve(socketListener) }) // TODO(kradalby): Verify if we need the same TLS setup for gRPC as HTTP - g.Go(func() error { return grpcServer.Serve(grpcListener) }) + errorGroup.Go(func() error { return grpcServer.Serve(grpcListener) }) if tlsConfig != nil { - g.Go(func() error { + errorGroup.Go(func() error { tlsl := tls.NewListener(httpListener, tlsConfig) return httpServer.Serve(tlsl) }) } else { - g.Go(func() error { return httpServer.Serve(httpListener) }) + errorGroup.Go(func() error { return httpServer.Serve(httpListener) }) } - g.Go(func() error { return m.Serve() }) + errorGroup.Go(func() error { return networkMutex.Serve() }) log.Info(). Msgf("listening and serving (multiplexed HTTP and gRPC) on: %s", h.cfg.Addr) - return g.Wait() + return errorGroup.Wait() } func (h *Headscale) getTLSSettings() (*tls.Config, error) { @@ -594,7 +596,7 @@ func (h *Headscale) getTLSSettings() (*tls.Config, error) { Msg("Listening with TLS but ServerURL does not start with https://") } - m := autocert.Manager{ + certManager := autocert.Manager{ Prompt: autocert.AcceptTOS, HostPolicy: autocert.HostWhitelist(h.cfg.TLSLetsEncryptHostname), Cache: autocert.DirCache(h.cfg.TLSLetsEncryptCacheDir), @@ -609,7 +611,7 @@ func (h *Headscale) getTLSSettings() (*tls.Config, error) { // Configuration via autocert with TLS-ALPN-01 (https://tools.ietf.org/html/rfc8737) // The RFC requires that the validation is done on port 443; in other words, headscale // must be reachable on port 443. - return m.TLSConfig(), nil + return certManager.TLSConfig(), nil case "HTTP-01": // Configuration via autocert with HTTP-01. This requires listening on @@ -617,11 +619,11 @@ func (h *Headscale) getTLSSettings() (*tls.Config, error) { // service, which can be configured to run on any other port. go func() { log.Fatal(). - Err(http.ListenAndServe(h.cfg.TLSLetsEncryptListen, m.HTTPHandler(http.HandlerFunc(h.redirect)))). + Err(http.ListenAndServe(h.cfg.TLSLetsEncryptListen, certManager.HTTPHandler(http.HandlerFunc(h.redirect)))). Msg("failed to set up a HTTP server") }() - return m.TLSConfig(), nil + return certManager.TLSConfig(), nil default: return nil, errors.New("unknown value for TLSLetsEncryptChallengeType") @@ -676,13 +678,13 @@ func (h *Headscale) getLastStateChange(namespaces ...string) time.Time { } } -func stdoutHandler(c *gin.Context) { - b, _ := io.ReadAll(c.Request.Body) +func stdoutHandler(ctx *gin.Context) { + body, _ := io.ReadAll(ctx.Request.Body) log.Trace(). - Interface("header", c.Request.Header). - Interface("proto", c.Request.Proto). - Interface("url", c.Request.URL). - Bytes("body", b). + Interface("header", ctx.Request.Header). + Interface("proto", ctx.Request.Proto). + Interface("url", ctx.Request.URL). + Bytes("body", body). Msg("Request did not match") } diff --git a/apple_mobileconfig.go b/apple_mobileconfig.go index ff3ef3c..8989c27 100644 --- a/apple_mobileconfig.go +++ b/apple_mobileconfig.go @@ -12,8 +12,8 @@ import ( // AppleMobileConfig shows a simple message in the browser to point to the CLI // Listens in /register. -func (h *Headscale) AppleMobileConfig(c *gin.Context) { - t := template.Must(template.New("apple").Parse(` +func (h *Headscale) AppleMobileConfig(ctx *gin.Context) { + appleTemplate := template.Must(template.New("apple").Parse(`

Apple configuration profiles

@@ -67,12 +67,12 @@ func (h *Headscale) AppleMobileConfig(c *gin.Context) { } var payload bytes.Buffer - if err := t.Execute(&payload, config); err != nil { + if err := appleTemplate.Execute(&payload, config); err != nil { log.Error(). Str("handler", "AppleMobileConfig"). Err(err). Msg("Could not render Apple index template") - c.Data( + ctx.Data( http.StatusInternalServerError, "text/html; charset=utf-8", []byte("Could not render Apple index template"), @@ -81,11 +81,11 @@ func (h *Headscale) AppleMobileConfig(c *gin.Context) { return } - c.Data(http.StatusOK, "text/html; charset=utf-8", payload.Bytes()) + ctx.Data(http.StatusOK, "text/html; charset=utf-8", payload.Bytes()) } -func (h *Headscale) ApplePlatformConfig(c *gin.Context) { - platform := c.Param("platform") +func (h *Headscale) ApplePlatformConfig(ctx *gin.Context) { + platform := ctx.Param("platform") id, err := uuid.NewV4() if err != nil { @@ -93,7 +93,7 @@ func (h *Headscale) ApplePlatformConfig(c *gin.Context) { Str("handler", "ApplePlatformConfig"). Err(err). Msg("Failed not create UUID") - c.Data( + ctx.Data( http.StatusInternalServerError, "text/html; charset=utf-8", []byte("Failed to create UUID"), @@ -108,7 +108,7 @@ func (h *Headscale) ApplePlatformConfig(c *gin.Context) { Str("handler", "ApplePlatformConfig"). Err(err). Msg("Failed not create UUID") - c.Data( + ctx.Data( http.StatusInternalServerError, "text/html; charset=utf-8", []byte("Failed to create UUID"), @@ -131,7 +131,7 @@ func (h *Headscale) ApplePlatformConfig(c *gin.Context) { Str("handler", "ApplePlatformConfig"). Err(err). Msg("Could not render Apple macOS template") - c.Data( + ctx.Data( http.StatusInternalServerError, "text/html; charset=utf-8", []byte("Could not render Apple macOS template"), @@ -145,7 +145,7 @@ func (h *Headscale) ApplePlatformConfig(c *gin.Context) { Str("handler", "ApplePlatformConfig"). Err(err). Msg("Could not render Apple iOS template") - c.Data( + ctx.Data( http.StatusInternalServerError, "text/html; charset=utf-8", []byte("Could not render Apple iOS template"), @@ -154,7 +154,7 @@ func (h *Headscale) ApplePlatformConfig(c *gin.Context) { return } default: - c.Data( + ctx.Data( http.StatusOK, "text/html; charset=utf-8", []byte("Invalid platform, only ios and macos is supported"), @@ -175,7 +175,7 @@ func (h *Headscale) ApplePlatformConfig(c *gin.Context) { Str("handler", "ApplePlatformConfig"). Err(err). Msg("Could not render Apple platform template") - c.Data( + ctx.Data( http.StatusInternalServerError, "text/html; charset=utf-8", []byte("Could not render Apple platform template"), @@ -184,7 +184,7 @@ func (h *Headscale) ApplePlatformConfig(c *gin.Context) { return } - c.Data( + ctx.Data( http.StatusOK, "application/x-apple-aspen-config; charset=utf-8", content.Bytes(), diff --git a/cmd/headscale/cli/namespaces.go b/cmd/headscale/cli/namespaces.go index 8c69ac5..8c6f10a 100644 --- a/cmd/headscale/cli/namespaces.go +++ b/cmd/headscale/cli/namespaces.go @@ -167,10 +167,10 @@ var listNamespacesCmd = &cobra.Command{ return } - d := pterm.TableData{{"ID", "Name", "Created"}} + tableData := pterm.TableData{{"ID", "Name", "Created"}} for _, namespace := range response.GetNamespaces() { - d = append( - d, + tableData = append( + tableData, []string{ namespace.GetId(), namespace.GetName(), @@ -178,7 +178,7 @@ var listNamespacesCmd = &cobra.Command{ }, ) } - err = pterm.DefaultTable.WithHasHeader().WithData(d).Render() + err = pterm.DefaultTable.WithHasHeader().WithData(tableData).Render() if err != nil { ErrorOutput( err, diff --git a/cmd/headscale/cli/nodes.go b/cmd/headscale/cli/nodes.go index 83d6b1f..c644ad9 100644 --- a/cmd/headscale/cli/nodes.go +++ b/cmd/headscale/cli/nodes.go @@ -157,14 +157,14 @@ var listNodesCmd = &cobra.Command{ return } - d, err := nodesToPtables(namespace, response.Machines) + tableData, err := nodesToPtables(namespace, response.Machines) if err != nil { ErrorOutput(err, fmt.Sprintf("Error converting to table: %s", err), output) return } - err = pterm.DefaultTable.WithHasHeader().WithData(d).Render() + err = pterm.DefaultTable.WithHasHeader().WithData(tableData).Render() if err != nil { ErrorOutput( err, @@ -183,7 +183,7 @@ var deleteNodeCmd = &cobra.Command{ Run: func(cmd *cobra.Command, args []string) { output, _ := cmd.Flags().GetString("output") - id, err := cmd.Flags().GetInt("identifier") + identifier, err := cmd.Flags().GetInt("identifier") if err != nil { ErrorOutput( err, @@ -199,7 +199,7 @@ var deleteNodeCmd = &cobra.Command{ defer conn.Close() getRequest := &v1.GetMachineRequest{ - MachineId: uint64(id), + MachineId: uint64(identifier), } getResponse, err := client.GetMachine(ctx, getRequest) @@ -217,7 +217,7 @@ var deleteNodeCmd = &cobra.Command{ } deleteRequest := &v1.DeleteMachineRequest{ - MachineId: uint64(id), + MachineId: uint64(identifier), } confirm := false @@ -280,7 +280,7 @@ func sharingWorker( defer cancel() defer conn.Close() - id, err := cmd.Flags().GetInt("identifier") + identifier, err := cmd.Flags().GetInt("identifier") if err != nil { ErrorOutput(err, fmt.Sprintf("Error converting ID to integer: %s", err), output) @@ -288,7 +288,7 @@ func sharingWorker( } machineRequest := &v1.GetMachineRequest{ - MachineId: uint64(id), + MachineId: uint64(identifier), } machineResponse, err := client.GetMachine(ctx, machineRequest) @@ -402,7 +402,7 @@ func nodesToPtables( currentNamespace string, machines []*v1.Machine, ) (pterm.TableData, error) { - d := pterm.TableData{ + tableData := pterm.TableData{ { "ID", "Name", @@ -448,8 +448,8 @@ func nodesToPtables( // Shared into this namespace namespace = pterm.LightYellow(machine.Namespace.Name) } - d = append( - d, + tableData = append( + tableData, []string{ strconv.FormatUint(machine.Id, headscale.BASE_10), machine.Name, @@ -463,5 +463,5 @@ func nodesToPtables( ) } - return d, nil + return tableData, nil } diff --git a/cmd/headscale/cli/preauthkeys.go b/cmd/headscale/cli/preauthkeys.go index 13a3094..b28e600 100644 --- a/cmd/headscale/cli/preauthkeys.go +++ b/cmd/headscale/cli/preauthkeys.go @@ -45,7 +45,7 @@ var listPreAuthKeys = &cobra.Command{ Run: func(cmd *cobra.Command, args []string) { output, _ := cmd.Flags().GetString("output") - n, err := cmd.Flags().GetString("namespace") + namespace, err := cmd.Flags().GetString("namespace") if err != nil { ErrorOutput(err, fmt.Sprintf("Error getting namespace: %s", err), output) @@ -57,7 +57,7 @@ var listPreAuthKeys = &cobra.Command{ defer conn.Close() request := &v1.ListPreAuthKeysRequest{ - Namespace: n, + Namespace: namespace, } response, err := client.ListPreAuthKeys(ctx, request) @@ -77,34 +77,34 @@ var listPreAuthKeys = &cobra.Command{ return } - d := pterm.TableData{ + tableData := pterm.TableData{ {"ID", "Key", "Reusable", "Ephemeral", "Used", "Expiration", "Created"}, } - for _, k := range response.PreAuthKeys { + for _, key := range response.PreAuthKeys { expiration := "-" - if k.GetExpiration() != nil { - expiration = k.Expiration.AsTime().Format("2006-01-02 15:04:05") + if key.GetExpiration() != nil { + expiration = key.Expiration.AsTime().Format("2006-01-02 15:04:05") } var reusable string - if k.GetEphemeral() { + if key.GetEphemeral() { reusable = "N/A" } else { - reusable = fmt.Sprintf("%v", k.GetReusable()) + reusable = fmt.Sprintf("%v", key.GetReusable()) } - d = append(d, []string{ - k.GetId(), - k.GetKey(), + tableData = append(tableData, []string{ + key.GetId(), + key.GetKey(), reusable, - strconv.FormatBool(k.GetEphemeral()), - strconv.FormatBool(k.GetUsed()), + strconv.FormatBool(key.GetEphemeral()), + strconv.FormatBool(key.GetUsed()), expiration, - k.GetCreatedAt().AsTime().Format("2006-01-02 15:04:05"), + key.GetCreatedAt().AsTime().Format("2006-01-02 15:04:05"), }) } - err = pterm.DefaultTable.WithHasHeader().WithData(d).Render() + err = pterm.DefaultTable.WithHasHeader().WithData(tableData).Render() if err != nil { ErrorOutput( err, diff --git a/cmd/headscale/cli/routes.go b/cmd/headscale/cli/routes.go index 2ad64bc..7d0fcbf 100644 --- a/cmd/headscale/cli/routes.go +++ b/cmd/headscale/cli/routes.go @@ -81,14 +81,14 @@ var listRoutesCmd = &cobra.Command{ return } - d := routesToPtables(response.Routes) + tableData := routesToPtables(response.Routes) if err != nil { ErrorOutput(err, fmt.Sprintf("Error converting to table: %s", err), output) return } - err = pterm.DefaultTable.WithHasHeader().WithData(d).Render() + err = pterm.DefaultTable.WithHasHeader().WithData(tableData).Render() if err != nil { ErrorOutput( err, @@ -162,14 +162,14 @@ omit the route you do not want to enable. return } - d := routesToPtables(response.Routes) + tableData := routesToPtables(response.Routes) if err != nil { ErrorOutput(err, fmt.Sprintf("Error converting to table: %s", err), output) return } - err = pterm.DefaultTable.WithHasHeader().WithData(d).Render() + err = pterm.DefaultTable.WithHasHeader().WithData(tableData).Render() if err != nil { ErrorOutput( err, @@ -184,15 +184,15 @@ omit the route you do not want to enable. // routesToPtables converts the list of routes to a nice table. func routesToPtables(routes *v1.Routes) pterm.TableData { - d := pterm.TableData{{"Route", "Enabled"}} + tableData := pterm.TableData{{"Route", "Enabled"}} for _, route := range routes.GetAdvertisedRoutes() { enabled := isStringInSlice(routes.EnabledRoutes, route) - d = append(d, []string{route, strconv.FormatBool(enabled)}) + tableData = append(tableData, []string{route, strconv.FormatBool(enabled)}) } - return d + return tableData } func isStringInSlice(strs []string, s string) bool { diff --git a/cmd/headscale/cli/utils.go b/cmd/headscale/cli/utils.go index 343ebe3..ad49c01 100644 --- a/cmd/headscale/cli/utils.go +++ b/cmd/headscale/cli/utils.go @@ -318,7 +318,7 @@ func getHeadscaleApp() (*headscale.Headscale, error) { cfg.OIDC.MatchMap = loadOIDCMatchMap() - h, err := headscale.NewHeadscale(cfg) + app, err := headscale.NewHeadscale(cfg) if err != nil { return nil, err } @@ -327,7 +327,7 @@ func getHeadscaleApp() (*headscale.Headscale, error) { if viper.GetString("acl_policy_path") != "" { aclPath := absPath(viper.GetString("acl_policy_path")) - err = h.LoadACLPolicy(aclPath) + err = app.LoadACLPolicy(aclPath) if err != nil { log.Error(). Str("path", aclPath). @@ -336,7 +336,7 @@ func getHeadscaleApp() (*headscale.Headscale, error) { } } - return h, nil + return app, nil } func getHeadscaleCLIClient() (context.Context, v1.HeadscaleServiceClient, *grpc.ClientConn, context.CancelFunc) { diff --git a/dns.go b/dns.go index d7480b1..cb986cf 100644 --- a/dns.go +++ b/dns.go @@ -79,7 +79,7 @@ func generateMagicDNSRootDomains( func getMapResponseDNSConfig( dnsConfigOrig *tailcfg.DNSConfig, baseDomain string, - m Machine, + machine Machine, peers Machines, ) *tailcfg.DNSConfig { var dnsConfig *tailcfg.DNSConfig @@ -88,11 +88,11 @@ func getMapResponseDNSConfig( dnsConfig = dnsConfigOrig.Clone() dnsConfig.Domains = append( dnsConfig.Domains, - fmt.Sprintf("%s.%s", m.Namespace.Name, baseDomain), + fmt.Sprintf("%s.%s", machine.Namespace.Name, baseDomain), ) namespaceSet := set.New(set.ThreadSafe) - namespaceSet.Add(m.Namespace) + namespaceSet.Add(machine.Namespace) for _, p := range peers { namespaceSet.Add(p.Namespace) } diff --git a/machine.go b/machine.go index e8d2f72..d1f27cc 100644 --- a/machine.go +++ b/machine.go @@ -56,21 +56,21 @@ type ( ) // For the time being this method is rather naive. -func (m Machine) isAlreadyRegistered() bool { - return m.Registered +func (machine Machine) isAlreadyRegistered() bool { + return machine.Registered } // isExpired returns whether the machine registration has expired. -func (m Machine) isExpired() bool { - return time.Now().UTC().After(*m.Expiry) +func (machine Machine) isExpired() bool { + return time.Now().UTC().After(*machine.Expiry) } // If the Machine is expired, updateMachineExpiry updates the Machine Expiry time to the maximum allowed duration, // or the default duration if no Expiry time was requested by the client. The expiry time here does not (yet) cause // a client to be disconnected, however they will have to re-auth the machine if they attempt to reconnect after the // expiry time. -func (h *Headscale) updateMachineExpiry(m *Machine) { - if m.isExpired() { +func (h *Headscale) updateMachineExpiry(machine *Machine) { + if machine.isExpired() { now := time.Now().UTC() maxExpiry := now.Add( h.cfg.MaxMachineRegistrationDuration, @@ -80,31 +80,31 @@ func (h *Headscale) updateMachineExpiry(m *Machine) { ) // calculate the default expiry // clamp the expiry time of the machine registration to the maximum allowed, or use the default if none supplied - if maxExpiry.Before(*m.RequestedExpiry) { + if maxExpiry.Before(*machine.RequestedExpiry) { log.Debug(). Msgf("Clamping registration expiry time to maximum: %v (%v)", maxExpiry, h.cfg.MaxMachineRegistrationDuration) - m.Expiry = &maxExpiry - } else if m.RequestedExpiry.IsZero() { + machine.Expiry = &maxExpiry + } else if machine.RequestedExpiry.IsZero() { log.Debug().Msgf("Using default machine registration expiry time: %v (%v)", defaultExpiry, h.cfg.DefaultMachineRegistrationDuration) - m.Expiry = &defaultExpiry + machine.Expiry = &defaultExpiry } else { - log.Debug().Msgf("Using requested machine registration expiry time: %v", m.RequestedExpiry) - m.Expiry = m.RequestedExpiry + log.Debug().Msgf("Using requested machine registration expiry time: %v", machine.RequestedExpiry) + machine.Expiry = machine.RequestedExpiry } - h.db.Save(&m) + h.db.Save(&machine) } } -func (h *Headscale) getDirectPeers(m *Machine) (Machines, error) { +func (h *Headscale) getDirectPeers(machine *Machine) (Machines, error) { log.Trace(). Caller(). - Str("machine", m.Name). + Str("machine", machine.Name). Msg("Finding direct peers") machines := Machines{} if err := h.db.Preload("Namespace").Where("namespace_id = ? AND machine_key <> ? AND registered", - m.NamespaceID, m.MachineKey).Find(&machines).Error; err != nil { + machine.NamespaceID, machine.MachineKey).Find(&machines).Error; err != nil { log.Error().Err(err).Msg("Error accessing db") return Machines{}, err @@ -114,22 +114,22 @@ func (h *Headscale) getDirectPeers(m *Machine) (Machines, error) { log.Trace(). Caller(). - Str("machine", m.Name). + Str("machine", machine.Name). Msgf("Found direct machines: %s", machines.String()) return machines, nil } // getShared fetches machines that are shared to the `Namespace` of the machine we are getting peers for. -func (h *Headscale) getShared(m *Machine) (Machines, error) { +func (h *Headscale) getShared(machine *Machine) (Machines, error) { log.Trace(). Caller(). - Str("machine", m.Name). + Str("machine", machine.Name). Msg("Finding shared peers") sharedMachines := []SharedMachine{} if err := h.db.Preload("Namespace").Preload("Machine").Preload("Machine.Namespace").Where("namespace_id = ?", - m.NamespaceID).Find(&sharedMachines).Error; err != nil { + machine.NamespaceID).Find(&sharedMachines).Error; err != nil { return Machines{}, err } @@ -142,22 +142,22 @@ func (h *Headscale) getShared(m *Machine) (Machines, error) { log.Trace(). Caller(). - Str("machine", m.Name). + Str("machine", machine.Name). Msgf("Found shared peers: %s", peers.String()) return peers, nil } // getSharedTo fetches the machines of the namespaces this machine is shared in. -func (h *Headscale) getSharedTo(m *Machine) (Machines, error) { +func (h *Headscale) getSharedTo(machine *Machine) (Machines, error) { log.Trace(). Caller(). - Str("machine", m.Name). + Str("machine", machine.Name). Msg("Finding peers in namespaces this machine is shared with") sharedMachines := []SharedMachine{} if err := h.db.Preload("Namespace").Preload("Machine").Preload("Machine.Namespace").Where("machine_id = ?", - m.ID).Find(&sharedMachines).Error; err != nil { + machine.ID).Find(&sharedMachines).Error; err != nil { return Machines{}, err } @@ -176,14 +176,14 @@ func (h *Headscale) getSharedTo(m *Machine) (Machines, error) { log.Trace(). Caller(). - Str("machine", m.Name). + Str("machine", machine.Name). Msgf("Found peers we are shared with: %s", peers.String()) return peers, nil } -func (h *Headscale) getPeers(m *Machine) (Machines, error) { - direct, err := h.getDirectPeers(m) +func (h *Headscale) getPeers(machine *Machine) (Machines, error) { + direct, err := h.getDirectPeers(machine) if err != nil { log.Error(). Caller(). @@ -193,7 +193,7 @@ func (h *Headscale) getPeers(m *Machine) (Machines, error) { return Machines{}, err } - shared, err := h.getShared(m) + shared, err := h.getShared(machine) if err != nil { log.Error(). Caller(). @@ -203,7 +203,7 @@ func (h *Headscale) getPeers(m *Machine) (Machines, error) { return Machines{}, err } - sharedTo, err := h.getSharedTo(m) + sharedTo, err := h.getSharedTo(machine) if err != nil { log.Error(). Caller(). @@ -220,7 +220,7 @@ func (h *Headscale) getPeers(m *Machine) (Machines, error) { log.Trace(). Caller(). - Str("machine", m.Name). + Str("machine", machine.Name). Msgf("Found total peers: %s", peers.String()) return peers, nil @@ -262,9 +262,9 @@ func (h *Headscale) GetMachineByID(id uint64) (*Machine, error) { } // GetMachineByMachineKey finds a Machine by ID and returns the Machine struct. -func (h *Headscale) GetMachineByMachineKey(mKey string) (*Machine, error) { +func (h *Headscale) GetMachineByMachineKey(machineKey string) (*Machine, error) { m := Machine{} - if result := h.db.Preload("Namespace").First(&m, "machine_key = ?", mKey); result.Error != nil { + if result := h.db.Preload("Namespace").First(&m, "machine_key = ?", machineKey); result.Error != nil { return nil, result.Error } @@ -273,8 +273,8 @@ func (h *Headscale) GetMachineByMachineKey(mKey string) (*Machine, error) { // UpdateMachine takes a Machine struct pointer (typically already loaded from database // and updates it with the latest data from the database. -func (h *Headscale) UpdateMachine(m *Machine) error { - if result := h.db.Find(m).First(&m); result.Error != nil { +func (h *Headscale) UpdateMachine(machine *Machine) error { + if result := h.db.Find(machine).First(&machine); result.Error != nil { return result.Error } @@ -282,16 +282,16 @@ func (h *Headscale) UpdateMachine(m *Machine) error { } // DeleteMachine softs deletes a Machine from the database. -func (h *Headscale) DeleteMachine(m *Machine) error { - err := h.RemoveSharedMachineFromAllNamespaces(m) +func (h *Headscale) DeleteMachine(machine *Machine) error { + err := h.RemoveSharedMachineFromAllNamespaces(machine) if err != nil && err != errorMachineNotShared { return err } - m.Registered = false - namespaceID := m.NamespaceID - h.db.Save(&m) // we mark it as unregistered, just in case - if err := h.db.Delete(&m).Error; err != nil { + machine.Registered = false + namespaceID := machine.NamespaceID + h.db.Save(&machine) // we mark it as unregistered, just in case + if err := h.db.Delete(&machine).Error; err != nil { return err } @@ -299,14 +299,14 @@ func (h *Headscale) DeleteMachine(m *Machine) error { } // HardDeleteMachine hard deletes a Machine from the database. -func (h *Headscale) HardDeleteMachine(m *Machine) error { - err := h.RemoveSharedMachineFromAllNamespaces(m) +func (h *Headscale) HardDeleteMachine(machine *Machine) error { + err := h.RemoveSharedMachineFromAllNamespaces(machine) if err != nil && err != errorMachineNotShared { return err } - namespaceID := m.NamespaceID - if err := h.db.Unscoped().Delete(&m).Error; err != nil { + namespaceID := machine.NamespaceID + if err := h.db.Unscoped().Delete(&machine).Error; err != nil { return err } @@ -314,10 +314,10 @@ func (h *Headscale) HardDeleteMachine(m *Machine) error { } // GetHostInfo returns a Hostinfo struct for the machine. -func (m *Machine) GetHostInfo() (*tailcfg.Hostinfo, error) { +func (machine *Machine) GetHostInfo() (*tailcfg.Hostinfo, error) { hostinfo := tailcfg.Hostinfo{} - if len(m.HostInfo) != 0 { - hi, err := m.HostInfo.MarshalJSON() + if len(machine.HostInfo) != 0 { + hi, err := machine.HostInfo.MarshalJSON() if err != nil { return nil, err } @@ -330,17 +330,17 @@ func (m *Machine) GetHostInfo() (*tailcfg.Hostinfo, error) { return &hostinfo, nil } -func (h *Headscale) isOutdated(m *Machine) bool { - if err := h.UpdateMachine(m); err != nil { +func (h *Headscale) isOutdated(machine *Machine) bool { + if err := h.UpdateMachine(machine); err != nil { // It does not seem meaningful to propagate this error as the end result // will have to be that the machine has to be considered outdated. return true } - sharedMachines, _ := h.getShared(m) + sharedMachines, _ := h.getShared(machine) namespaceSet := set.New(set.ThreadSafe) - namespaceSet.Add(m.Namespace.Name) + namespaceSet.Add(machine.Namespace.Name) // Check if any of our shared namespaces has updates that we have // not propagated. @@ -356,22 +356,22 @@ func (h *Headscale) isOutdated(m *Machine) bool { lastChange := h.getLastStateChange(namespaces...) log.Trace(). Caller(). - Str("machine", m.Name). - Time("last_successful_update", *m.LastSuccessfulUpdate). + Str("machine", machine.Name). + Time("last_successful_update", *machine.LastSuccessfulUpdate). Time("last_state_change", lastChange). - Msgf("Checking if %s is missing updates", m.Name) + Msgf("Checking if %s is missing updates", machine.Name) - return m.LastSuccessfulUpdate.Before(lastChange) + return machine.LastSuccessfulUpdate.Before(lastChange) } -func (m Machine) String() string { - return m.Name +func (machine Machine) String() string { + return machine.Name } -func (ms Machines) String() string { - temp := make([]string, len(ms)) +func (machines Machines) String() string { + temp := make([]string, len(machines)) - for index, machine := range ms { + for index, machine := range machines { temp[index] = machine.Name } @@ -379,24 +379,24 @@ func (ms Machines) String() string { } // TODO(kradalby): Remove when we have generics... -func (ms MachinesP) String() string { - temp := make([]string, len(ms)) +func (machines MachinesP) String() string { + temp := make([]string, len(machines)) - for index, machine := range ms { + for index, machine := range machines { temp[index] = machine.Name } return fmt.Sprintf("[ %s ](%d)", strings.Join(temp, ", "), len(temp)) } -func (ms Machines) toNodes( +func (machines Machines) toNodes( baseDomain string, dnsConfig *tailcfg.DNSConfig, includeRoutes bool, ) ([]*tailcfg.Node, error) { - nodes := make([]*tailcfg.Node, len(ms)) + nodes := make([]*tailcfg.Node, len(machines)) - for index, machine := range ms { + for index, machine := range machines { node, err := machine.toNode(baseDomain, dnsConfig, includeRoutes) if err != nil { return nil, err @@ -410,23 +410,24 @@ func (ms Machines) toNodes( // toNode converts a Machine into a Tailscale Node. includeRoutes is false for shared nodes // as per the expected behaviour in the official SaaS. -func (m Machine) toNode( +func (machine Machine) toNode( baseDomain string, dnsConfig *tailcfg.DNSConfig, includeRoutes bool, ) (*tailcfg.Node, error) { - nKey, err := wgkey.ParseHex(m.NodeKey) + nodeKey, err := wgkey.ParseHex(machine.NodeKey) if err != nil { return nil, err } - mKey, err := wgkey.ParseHex(m.MachineKey) + + machineKey, err := wgkey.ParseHex(machine.MachineKey) if err != nil { return nil, err } var discoKey tailcfg.DiscoKey - if m.DiscoKey != "" { - dKey, err := wgkey.ParseHex(m.DiscoKey) + if machine.DiscoKey != "" { + dKey, err := wgkey.ParseHex(machine.DiscoKey) if err != nil { return nil, err } @@ -436,12 +437,12 @@ func (m Machine) toNode( } addrs := []netaddr.IPPrefix{} - ip, err := netaddr.ParseIPPrefix(fmt.Sprintf("%s/32", m.IPAddress)) + ip, err := netaddr.ParseIPPrefix(fmt.Sprintf("%s/32", machine.IPAddress)) if err != nil { log.Trace(). Caller(). - Str("ip", m.IPAddress). - Msgf("Failed to parse IP Prefix from IP: %s", m.IPAddress) + Str("ip", machine.IPAddress). + Msgf("Failed to parse IP Prefix from IP: %s", machine.IPAddress) return nil, err } @@ -455,8 +456,8 @@ func (m Machine) toNode( if includeRoutes { routesStr := []string{} - if len(m.EnabledRoutes) != 0 { - allwIps, err := m.EnabledRoutes.MarshalJSON() + if len(machine.EnabledRoutes) != 0 { + allwIps, err := machine.EnabledRoutes.MarshalJSON() if err != nil { return nil, err } @@ -476,8 +477,8 @@ func (m Machine) toNode( } endpoints := []string{} - if len(m.Endpoints) != 0 { - be, err := m.Endpoints.MarshalJSON() + if len(machine.Endpoints) != 0 { + be, err := machine.Endpoints.MarshalJSON() if err != nil { return nil, err } @@ -488,8 +489,8 @@ func (m Machine) toNode( } hostinfo := tailcfg.Hostinfo{} - if len(m.HostInfo) != 0 { - hi, err := m.HostInfo.MarshalJSON() + if len(machine.HostInfo) != 0 { + hi, err := machine.HostInfo.MarshalJSON() if err != nil { return nil, err } @@ -507,29 +508,34 @@ func (m Machine) toNode( } var keyExpiry time.Time - if m.Expiry != nil { - keyExpiry = *m.Expiry + if machine.Expiry != nil { + keyExpiry = *machine.Expiry } else { keyExpiry = time.Time{} } var hostname string if dnsConfig != nil && dnsConfig.Proxied { // MagicDNS - hostname = fmt.Sprintf("%s.%s.%s", m.Name, m.Namespace.Name, baseDomain) + hostname = fmt.Sprintf( + "%s.%s.%s", + machine.Name, + machine.Namespace.Name, + baseDomain, + ) } else { - hostname = m.Name + hostname = machine.Name } n := tailcfg.Node{ - ID: tailcfg.NodeID(m.ID), // this is the actual ID + ID: tailcfg.NodeID(machine.ID), // this is the actual ID StableID: tailcfg.StableNodeID( - strconv.FormatUint(m.ID, BASE_10), + strconv.FormatUint(machine.ID, BASE_10), ), // in headscale, unlike tailcontrol server, IDs are permanent Name: hostname, - User: tailcfg.UserID(m.NamespaceID), - Key: tailcfg.NodeKey(nKey), + User: tailcfg.UserID(machine.NamespaceID), + Key: tailcfg.NodeKey(nodeKey), KeyExpiry: keyExpiry, - Machine: tailcfg.MachineKey(mKey), + Machine: tailcfg.MachineKey(machineKey), DiscoKey: discoKey, Addresses: addrs, AllowedIPs: allowedIPs, @@ -537,68 +543,73 @@ func (m Machine) toNode( DERP: derp, Hostinfo: hostinfo, - Created: m.CreatedAt, - LastSeen: m.LastSeen, + Created: machine.CreatedAt, + LastSeen: machine.LastSeen, KeepAlive: true, - MachineAuthorized: m.Registered, + MachineAuthorized: machine.Registered, Capabilities: []string{tailcfg.CapabilityFileSharing}, } return &n, nil } -func (m *Machine) toProto() *v1.Machine { - machine := &v1.Machine{ - Id: m.ID, - MachineKey: m.MachineKey, +func (machine *Machine) toProto() *v1.Machine { + machineProto := &v1.Machine{ + Id: machine.ID, + MachineKey: machine.MachineKey, - NodeKey: m.NodeKey, - DiscoKey: m.DiscoKey, - IpAddress: m.IPAddress, - Name: m.Name, - Namespace: m.Namespace.toProto(), + NodeKey: machine.NodeKey, + DiscoKey: machine.DiscoKey, + IpAddress: machine.IPAddress, + Name: machine.Name, + Namespace: machine.Namespace.toProto(), - Registered: m.Registered, + Registered: machine.Registered, // TODO(kradalby): Implement register method enum converter // RegisterMethod: , - CreatedAt: timestamppb.New(m.CreatedAt), + CreatedAt: timestamppb.New(machine.CreatedAt), } - if m.AuthKey != nil { - machine.PreAuthKey = m.AuthKey.toProto() + if machine.AuthKey != nil { + machineProto.PreAuthKey = machine.AuthKey.toProto() } - if m.LastSeen != nil { - machine.LastSeen = timestamppb.New(*m.LastSeen) + if machine.LastSeen != nil { + machineProto.LastSeen = timestamppb.New(*machine.LastSeen) } - if m.LastSuccessfulUpdate != nil { - machine.LastSuccessfulUpdate = timestamppb.New(*m.LastSuccessfulUpdate) + if machine.LastSuccessfulUpdate != nil { + machineProto.LastSuccessfulUpdate = timestamppb.New( + *machine.LastSuccessfulUpdate, + ) } - if m.Expiry != nil { - machine.Expiry = timestamppb.New(*m.Expiry) + if machine.Expiry != nil { + machineProto.Expiry = timestamppb.New(*machine.Expiry) } - return machine + return machineProto } // RegisterMachine is executed from the CLI to register a new Machine using its MachineKey. -func (h *Headscale) RegisterMachine(key string, namespace string) (*Machine, error) { - ns, err := h.GetNamespace(namespace) +func (h *Headscale) RegisterMachine( + key string, + namespaceName string, +) (*Machine, error) { + namespace, err := h.GetNamespace(namespaceName) if err != nil { return nil, err } - mKey, err := wgkey.ParseHex(key) + machineKey, err := wgkey.ParseHex(key) if err != nil { return nil, err } - m := Machine{} - if result := h.db.First(&m, "machine_key = ?", mKey.HexString()); errors.Is( + machine := Machine{} + if result := h.db.First(&machine, "machine_key = ?", machineKey.HexString()); errors.Is( result.Error, gorm.ErrRecordNotFound, ) { @@ -607,15 +618,15 @@ func (h *Headscale) RegisterMachine(key string, namespace string) (*Machine, err log.Trace(). Caller(). - Str("machine", m.Name). + Str("machine", machine.Name). Msg("Attempting to register machine") - if m.isAlreadyRegistered() { + if machine.isAlreadyRegistered() { err := errors.New("Machine already registered") log.Error(). Caller(). Err(err). - Str("machine", m.Name). + Str("machine", machine.Name). Msg("Attempting to register machine") return nil, err @@ -626,7 +637,7 @@ func (h *Headscale) RegisterMachine(key string, namespace string) (*Machine, err log.Error(). Caller(). Err(err). - Str("machine", m.Name). + Str("machine", machine.Name). Msg("Could not find IP for the new machine") return nil, err @@ -634,27 +645,27 @@ func (h *Headscale) RegisterMachine(key string, namespace string) (*Machine, err log.Trace(). Caller(). - Str("machine", m.Name). + Str("machine", machine.Name). Str("ip", ip.String()). Msg("Found IP for host") - m.IPAddress = ip.String() - m.NamespaceID = ns.ID - m.Registered = true - m.RegisterMethod = "cli" - h.db.Save(&m) + machine.IPAddress = ip.String() + machine.NamespaceID = namespace.ID + machine.Registered = true + machine.RegisterMethod = "cli" + h.db.Save(&machine) log.Trace(). Caller(). - Str("machine", m.Name). + Str("machine", machine.Name). Str("ip", ip.String()). Msg("Machine registered with the database") - return &m, nil + return &machine, nil } -func (m *Machine) GetAdvertisedRoutes() ([]netaddr.IPPrefix, error) { - hostInfo, err := m.GetHostInfo() +func (machine *Machine) GetAdvertisedRoutes() ([]netaddr.IPPrefix, error) { + hostInfo, err := machine.GetHostInfo() if err != nil { return nil, err } @@ -662,8 +673,8 @@ func (m *Machine) GetAdvertisedRoutes() ([]netaddr.IPPrefix, error) { return hostInfo.RoutableIPs, nil } -func (m *Machine) GetEnabledRoutes() ([]netaddr.IPPrefix, error) { - data, err := m.EnabledRoutes.MarshalJSON() +func (machine *Machine) GetEnabledRoutes() ([]netaddr.IPPrefix, error) { + data, err := machine.EnabledRoutes.MarshalJSON() if err != nil { return nil, err } @@ -686,13 +697,13 @@ func (m *Machine) GetEnabledRoutes() ([]netaddr.IPPrefix, error) { return routes, nil } -func (m *Machine) IsRoutesEnabled(routeStr string) bool { +func (machine *Machine) IsRoutesEnabled(routeStr string) bool { route, err := netaddr.ParseIPPrefix(routeStr) if err != nil { return false } - enabledRoutes, err := m.GetEnabledRoutes() + enabledRoutes, err := machine.GetEnabledRoutes() if err != nil { return false } @@ -708,7 +719,7 @@ func (m *Machine) IsRoutesEnabled(routeStr string) bool { // EnableNodeRoute enables new routes based on a list of new routes. It will _replace_ the // previous list of routes. -func (h *Headscale) EnableRoutes(m *Machine, routeStrs ...string) error { +func (h *Headscale) EnableRoutes(machine *Machine, routeStrs ...string) error { newRoutes := make([]netaddr.IPPrefix, len(routeStrs)) for index, routeStr := range routeStrs { route, err := netaddr.ParseIPPrefix(routeStr) @@ -719,7 +730,7 @@ func (h *Headscale) EnableRoutes(m *Machine, routeStrs ...string) error { newRoutes[index] = route } - availableRoutes, err := m.GetAdvertisedRoutes() + availableRoutes, err := machine.GetAdvertisedRoutes() if err != nil { return err } @@ -728,7 +739,7 @@ func (h *Headscale) EnableRoutes(m *Machine, routeStrs ...string) error { if !containsIpPrefix(availableRoutes, newRoute) { return fmt.Errorf( "route (%s) is not available on node %s", - m.Name, + machine.Name, newRoute, ) } @@ -739,10 +750,10 @@ func (h *Headscale) EnableRoutes(m *Machine, routeStrs ...string) error { return err } - m.EnabledRoutes = datatypes.JSON(routes) - h.db.Save(&m) + machine.EnabledRoutes = datatypes.JSON(routes) + h.db.Save(&machine) - err = h.RequestMapUpdates(m.NamespaceID) + err = h.RequestMapUpdates(machine.NamespaceID) if err != nil { return err } @@ -750,13 +761,13 @@ func (h *Headscale) EnableRoutes(m *Machine, routeStrs ...string) error { return nil } -func (m *Machine) RoutesToProto() (*v1.Routes, error) { - availableRoutes, err := m.GetAdvertisedRoutes() +func (machine *Machine) RoutesToProto() (*v1.Routes, error) { + availableRoutes, err := machine.GetAdvertisedRoutes() if err != nil { return nil, err } - enabledRoutes, err := m.GetEnabledRoutes() + enabledRoutes, err := machine.GetEnabledRoutes() if err != nil { return nil, err } diff --git a/namespaces.go b/namespaces.go index bea922d..858a7aa 100644 --- a/namespaces.go +++ b/namespaces.go @@ -32,12 +32,12 @@ type Namespace struct { // CreateNamespace creates a new Namespace. Returns error if could not be created // or another namespace already exists. func (h *Headscale) CreateNamespace(name string) (*Namespace, error) { - n := Namespace{} - if err := h.db.Where("name = ?", name).First(&n).Error; err == nil { + namespace := Namespace{} + if err := h.db.Where("name = ?", name).First(&namespace).Error; err == nil { return nil, errorNamespaceExists } - n.Name = name - if err := h.db.Create(&n).Error; err != nil { + namespace.Name = name + if err := h.db.Create(&namespace).Error; err != nil { log.Error(). Str("func", "CreateNamespace"). Err(err). @@ -46,22 +46,22 @@ func (h *Headscale) CreateNamespace(name string) (*Namespace, error) { return nil, err } - return &n, nil + return &namespace, nil } // DestroyNamespace destroys a Namespace. Returns error if the Namespace does // not exist or if there are machines associated with it. func (h *Headscale) DestroyNamespace(name string) error { - n, err := h.GetNamespace(name) + namespace, err := h.GetNamespace(name) if err != nil { return errorNamespaceNotFound } - m, err := h.ListMachinesInNamespace(name) + machines, err := h.ListMachinesInNamespace(name) if err != nil { return err } - if len(m) > 0 { + if len(machines) > 0 { return errorNamespaceNotEmptyOfNodes } @@ -69,14 +69,14 @@ func (h *Headscale) DestroyNamespace(name string) error { if err != nil { return err } - for _, p := range keys { - err = h.DestroyPreAuthKey(&p) + for _, key := range keys { + err = h.DestroyPreAuthKey(&key) if err != nil { return err } } - if result := h.db.Unscoped().Delete(&n); result.Error != nil { + if result := h.db.Unscoped().Delete(&namespace); result.Error != nil { return result.Error } @@ -86,7 +86,7 @@ func (h *Headscale) DestroyNamespace(name string) error { // RenameNamespace renames a Namespace. Returns error if the Namespace does // not exist or if another Namespace exists with the new name. func (h *Headscale) RenameNamespace(oldName, newName string) error { - n, err := h.GetNamespace(oldName) + oldNamespace, err := h.GetNamespace(oldName) if err != nil { return err } @@ -98,13 +98,13 @@ func (h *Headscale) RenameNamespace(oldName, newName string) error { return err } - n.Name = newName + oldNamespace.Name = newName - if result := h.db.Save(&n); result.Error != nil { + if result := h.db.Save(&oldNamespace); result.Error != nil { return result.Error } - err = h.RequestMapUpdates(n.ID) + err = h.RequestMapUpdates(oldNamespace.ID) if err != nil { return err } @@ -114,15 +114,15 @@ func (h *Headscale) RenameNamespace(oldName, newName string) error { // GetNamespace fetches a namespace by name. func (h *Headscale) GetNamespace(name string) (*Namespace, error) { - n := Namespace{} - if result := h.db.First(&n, "name = ?", name); errors.Is( + namespace := Namespace{} + if result := h.db.First(&namespace, "name = ?", name); errors.Is( result.Error, gorm.ErrRecordNotFound, ) { return nil, errorNamespaceNotFound } - return &n, nil + return &namespace, nil } // ListNamespaces gets all the existing namespaces. @@ -137,13 +137,13 @@ func (h *Headscale) ListNamespaces() ([]Namespace, error) { // ListMachinesInNamespace gets all the nodes in a given namespace. func (h *Headscale) ListMachinesInNamespace(name string) ([]Machine, error) { - n, err := h.GetNamespace(name) + namespace, err := h.GetNamespace(name) if err != nil { return nil, err } machines := []Machine{} - if err := h.db.Preload("AuthKey").Preload("AuthKey.Namespace").Preload("Namespace").Where(&Machine{NamespaceID: n.ID}).Find(&machines).Error; err != nil { + if err := h.db.Preload("AuthKey").Preload("AuthKey.Namespace").Preload("Namespace").Where(&Machine{NamespaceID: namespace.ID}).Find(&machines).Error; err != nil { return nil, err } @@ -176,17 +176,18 @@ func (h *Headscale) ListSharedMachinesInNamespace(name string) ([]Machine, error } // SetMachineNamespace assigns a Machine to a namespace. -func (h *Headscale) SetMachineNamespace(m *Machine, namespaceName string) error { - n, err := h.GetNamespace(namespaceName) +func (h *Headscale) SetMachineNamespace(machine *Machine, namespaceName string) error { + namespace, err := h.GetNamespace(namespaceName) if err != nil { return err } - m.NamespaceID = n.ID - h.db.Save(&m) + machine.NamespaceID = namespace.ID + h.db.Save(&machine) return nil } +// TODO(kradalby): Remove the need for this. // RequestMapUpdates signals the KV worker to update the maps for this namespace. func (h *Headscale) RequestMapUpdates(namespaceID uint) error { namespace := Namespace{} @@ -194,8 +195,8 @@ func (h *Headscale) RequestMapUpdates(namespaceID uint) error { return err } - v, err := h.getValue("namespaces_pending_updates") - if err != nil || v == "" { + namespacesPendingUpdates, err := h.getValue("namespaces_pending_updates") + if err != nil || namespacesPendingUpdates == "" { err = h.setValue( "namespaces_pending_updates", fmt.Sprintf(`["%s"]`, namespace.Name), @@ -207,7 +208,7 @@ func (h *Headscale) RequestMapUpdates(namespaceID uint) error { return nil } names := []string{} - err = json.Unmarshal([]byte(v), &names) + err = json.Unmarshal([]byte(namespacesPendingUpdates), &names) if err != nil { err = h.setValue( "namespaces_pending_updates", @@ -235,16 +236,16 @@ func (h *Headscale) RequestMapUpdates(namespaceID uint) error { } func (h *Headscale) checkForNamespacesPendingUpdates() { - v, err := h.getValue("namespaces_pending_updates") + namespacesPendingUpdates, err := h.getValue("namespaces_pending_updates") if err != nil { return } - if v == "" { + if namespacesPendingUpdates == "" { return } namespaces := []string{} - err = json.Unmarshal([]byte(v), &namespaces) + err = json.Unmarshal([]byte(namespacesPendingUpdates), &namespaces) if err != nil { return } @@ -255,11 +256,11 @@ func (h *Headscale) checkForNamespacesPendingUpdates() { Msg("Sending updates to nodes in namespacespace") h.setLastStateChangeToNow(namespace) } - newV, err := h.getValue("namespaces_pending_updates") + newPendingUpdateValue, err := h.getValue("namespaces_pending_updates") if err != nil { return } - if v == newV { // only clear when no changes, so we notified everybody + if namespacesPendingUpdates == newPendingUpdateValue { // only clear when no changes, so we notified everybody err = h.setValue("namespaces_pending_updates", "") if err != nil { log.Error(). @@ -273,7 +274,7 @@ func (h *Headscale) checkForNamespacesPendingUpdates() { } func (n *Namespace) toUser() *tailcfg.User { - u := tailcfg.User{ + user := tailcfg.User{ ID: tailcfg.UserID(n.ID), LoginName: n.Name, DisplayName: n.Name, @@ -283,11 +284,11 @@ func (n *Namespace) toUser() *tailcfg.User { Created: time.Time{}, } - return &u + return &user } func (n *Namespace) toLogin() *tailcfg.Login { - l := tailcfg.Login{ + login := tailcfg.Login{ ID: tailcfg.LoginID(n.ID), LoginName: n.Name, DisplayName: n.Name, @@ -295,14 +296,14 @@ func (n *Namespace) toLogin() *tailcfg.Login { Domain: "headscale.net", } - return &l + return &login } -func getMapResponseUserProfiles(m Machine, peers Machines) []tailcfg.UserProfile { +func getMapResponseUserProfiles(machine Machine, peers Machines) []tailcfg.UserProfile { namespaceMap := make(map[string]Namespace) - namespaceMap[m.Namespace.Name] = m.Namespace - for _, p := range peers { - namespaceMap[p.Namespace.Name] = p.Namespace // not worth checking if already is there + namespaceMap[machine.Namespace.Name] = machine.Namespace + for _, peer := range peers { + namespaceMap[peer.Namespace.Name] = peer.Namespace // not worth checking if already is there } profiles := []tailcfg.UserProfile{} diff --git a/oidc.go b/oidc.go index c77a249..e68e112 100644 --- a/oidc.go +++ b/oidc.go @@ -68,10 +68,10 @@ func (h *Headscale) initOIDC() error { // RegisterOIDC redirects to the OIDC provider for authentication // Puts machine key in cache so the callback can retrieve it using the oidc state param // Listens in /oidc/register/:mKey. -func (h *Headscale) RegisterOIDC(c *gin.Context) { - mKeyStr := c.Param("mkey") +func (h *Headscale) RegisterOIDC(ctx *gin.Context) { + mKeyStr := ctx.Param("mkey") if mKeyStr == "" { - c.String(http.StatusBadRequest, "Wrong params") + ctx.String(http.StatusBadRequest, "Wrong params") return } @@ -79,7 +79,7 @@ func (h *Headscale) RegisterOIDC(c *gin.Context) { b := make([]byte, RANDOM_BYTE_SIZE) if _, err := rand.Read(b); err != nil { log.Error().Msg("could not read 16 bytes from rand") - c.String(http.StatusInternalServerError, "could not read 16 bytes from rand") + ctx.String(http.StatusInternalServerError, "could not read 16 bytes from rand") return } @@ -92,7 +92,7 @@ func (h *Headscale) RegisterOIDC(c *gin.Context) { authUrl := h.oauth2Config.AuthCodeURL(stateStr) log.Debug().Msgf("Redirecting to %s for authentication", authUrl) - c.Redirect(http.StatusFound, authUrl) + ctx.Redirect(http.StatusFound, authUrl) } // OIDCCallback handles the callback from the OIDC endpoint @@ -100,19 +100,19 @@ func (h *Headscale) RegisterOIDC(c *gin.Context) { // TODO: A confirmation page for new machines should be added to avoid phishing vulnerabilities // TODO: Add groups information from OIDC tokens into machine HostInfo // Listens in /oidc/callback. -func (h *Headscale) OIDCCallback(c *gin.Context) { - code := c.Query("code") - state := c.Query("state") +func (h *Headscale) OIDCCallback(ctx *gin.Context) { + code := ctx.Query("code") + state := ctx.Query("state") if code == "" || state == "" { - c.String(http.StatusBadRequest, "Wrong params") + ctx.String(http.StatusBadRequest, "Wrong params") return } oauth2Token, err := h.oauth2Config.Exchange(context.Background(), code) if err != nil { - c.String(http.StatusBadRequest, "Could not exchange code for token") + ctx.String(http.StatusBadRequest, "Could not exchange code for token") return } @@ -121,7 +121,7 @@ func (h *Headscale) OIDCCallback(c *gin.Context) { rawIDToken, rawIDTokenOK := oauth2Token.Extra("id_token").(string) if !rawIDTokenOK { - c.String(http.StatusBadRequest, "Could not extract ID Token") + ctx.String(http.StatusBadRequest, "Could not extract ID Token") return } @@ -130,7 +130,7 @@ func (h *Headscale) OIDCCallback(c *gin.Context) { idToken, err := verifier.Verify(context.Background(), rawIDToken) if err != nil { - c.String(http.StatusBadRequest, "Failed to verify id token: %s", err.Error()) + ctx.String(http.StatusBadRequest, "Failed to verify id token: %s", err.Error()) return } @@ -145,7 +145,7 @@ func (h *Headscale) OIDCCallback(c *gin.Context) { // Extract custom claims var claims IDTokenClaims if err = idToken.Claims(&claims); err != nil { - c.String( + ctx.String( http.StatusBadRequest, fmt.Sprintf("Failed to decode id token claims: %s", err), ) @@ -159,7 +159,7 @@ func (h *Headscale) OIDCCallback(c *gin.Context) { if !mKeyFound { log.Error(). Msg("requested machine state key expired before authorisation completed") - c.String(http.StatusBadRequest, "state has expired") + ctx.String(http.StatusBadRequest, "state has expired") return } @@ -167,16 +167,19 @@ func (h *Headscale) OIDCCallback(c *gin.Context) { if !mKeyOK { log.Error().Msg("could not get machine key from cache") - c.String(http.StatusInternalServerError, "could not get machine key from cache") + ctx.String( + http.StatusInternalServerError, + "could not get machine key from cache", + ) return } // retrieve machine information - m, err := h.GetMachineByMachineKey(mKeyStr) + machine, err := h.GetMachineByMachineKey(mKeyStr) if err != nil { log.Error().Msg("machine key not found in database") - c.String( + ctx.String( http.StatusInternalServerError, "could not get machine info from database", ) @@ -186,19 +189,19 @@ func (h *Headscale) OIDCCallback(c *gin.Context) { now := time.Now().UTC() - if nsName, ok := h.getNamespaceFromEmail(claims.Email); ok { + if namespaceName, ok := h.getNamespaceFromEmail(claims.Email); ok { // register the machine if it's new - if !m.Registered { + if !machine.Registered { log.Debug().Msg("Registering new machine after successful callback") - ns, err := h.GetNamespace(nsName) + namespace, err := h.GetNamespace(namespaceName) if err != nil { - ns, err = h.CreateNamespace(nsName) + namespace, err = h.CreateNamespace(namespaceName) if err != nil { log.Error(). Msgf("could not create new namespace '%s'", claims.Email) - c.String( + ctx.String( http.StatusInternalServerError, "could not create new namespace", ) @@ -209,7 +212,7 @@ func (h *Headscale) OIDCCallback(c *gin.Context) { ip, err := h.getAvailableIP() if err != nil { - c.String( + ctx.String( http.StatusInternalServerError, "could not get an IP from the pool", ) @@ -217,17 +220,17 @@ func (h *Headscale) OIDCCallback(c *gin.Context) { return } - m.IPAddress = ip.String() - m.NamespaceID = ns.ID - m.Registered = true - m.RegisterMethod = "oidc" - m.LastSuccessfulUpdate = &now - h.db.Save(&m) + machine.IPAddress = ip.String() + machine.NamespaceID = namespace.ID + machine.Registered = true + machine.RegisterMethod = "oidc" + machine.LastSuccessfulUpdate = &now + h.db.Save(&machine) } - h.updateMachineExpiry(m) + h.updateMachineExpiry(machine) - c.Data(http.StatusOK, "text/html; charset=utf-8", []byte(fmt.Sprintf(` + ctx.Data(http.StatusOK, "text/html; charset=utf-8", []byte(fmt.Sprintf(`

headscale

@@ -243,9 +246,9 @@ func (h *Headscale) OIDCCallback(c *gin.Context) { log.Error(). Str("email", claims.Email). Str("username", claims.Username). - Str("machine", m.Name). + Str("machine", machine.Name). Msg("Email could not be mapped to a namespace") - c.String( + ctx.String( http.StatusBadRequest, "email from claim could not be mapped to a namespace", ) diff --git a/poll.go b/poll.go index 3f7b293..1927853 100644 --- a/poll.go +++ b/poll.go @@ -233,7 +233,7 @@ func (h *Headscale) PollNetMapStream( ) { go h.scheduledPollWorker(cancelKeepAlive, updateChan, keepAliveChan, mKey, req, m) - c.Stream(func(w io.Writer) bool { + c.Stream(func(writer io.Writer) bool { log.Trace(). Str("handler", "PollNetMapStream"). Str("machine", m.Name). @@ -252,7 +252,7 @@ func (h *Headscale) PollNetMapStream( Str("channel", "pollData"). Int("bytes", len(data)). Msg("Sending data received via pollData channel") - _, err := w.Write(data) + _, err := writer.Write(data) if err != nil { log.Error(). Str("handler", "PollNetMapStream"). @@ -305,7 +305,7 @@ func (h *Headscale) PollNetMapStream( Str("channel", "keepAlive"). Int("bytes", len(data)). Msg("Sending keep alive message") - _, err := w.Write(data) + _, err := writer.Write(data) if err != nil { log.Error(). Str("handler", "PollNetMapStream"). @@ -370,7 +370,7 @@ func (h *Headscale) PollNetMapStream( Err(err). Msg("Could not get the map update") } - _, err = w.Write(data) + _, err = writer.Write(data) if err != nil { log.Error(). Str("handler", "PollNetMapStream"). diff --git a/sharing.go b/sharing.go index 8f65414..741deb6 100644 --- a/sharing.go +++ b/sharing.go @@ -18,13 +18,16 @@ type SharedMachine struct { } // AddSharedMachineToNamespace adds a machine as a shared node to a namespace. -func (h *Headscale) AddSharedMachineToNamespace(m *Machine, ns *Namespace) error { - if m.NamespaceID == ns.ID { +func (h *Headscale) AddSharedMachineToNamespace( + machine *Machine, + namespace *Namespace, +) error { + if machine.NamespaceID == namespace.ID { return errorSameNamespace } sharedMachines := []SharedMachine{} - if err := h.db.Where("machine_id = ? AND namespace_id = ?", m.ID, ns.ID).Find(&sharedMachines).Error; err != nil { + if err := h.db.Where("machine_id = ? AND namespace_id = ?", machine.ID, namespace.ID).Find(&sharedMachines).Error; err != nil { return err } if len(sharedMachines) > 0 { @@ -32,10 +35,10 @@ func (h *Headscale) AddSharedMachineToNamespace(m *Machine, ns *Namespace) error } sharedMachine := SharedMachine{ - MachineID: m.ID, - Machine: *m, - NamespaceID: ns.ID, - Namespace: *ns, + MachineID: machine.ID, + Machine: *machine, + NamespaceID: namespace.ID, + Namespace: *namespace, } h.db.Save(&sharedMachine) @@ -43,14 +46,17 @@ func (h *Headscale) AddSharedMachineToNamespace(m *Machine, ns *Namespace) error } // RemoveSharedMachineFromNamespace removes a shared machine from a namespace. -func (h *Headscale) RemoveSharedMachineFromNamespace(m *Machine, ns *Namespace) error { - if m.NamespaceID == ns.ID { +func (h *Headscale) RemoveSharedMachineFromNamespace( + machine *Machine, + namespace *Namespace, +) error { + if machine.NamespaceID == namespace.ID { // Can't unshare from primary namespace return errorMachineNotShared } sharedMachine := SharedMachine{} - result := h.db.Where("machine_id = ? AND namespace_id = ?", m.ID, ns.ID). + result := h.db.Where("machine_id = ? AND namespace_id = ?", machine.ID, namespace.ID). Unscoped(). Delete(&sharedMachine) if result.Error != nil { @@ -61,7 +67,7 @@ func (h *Headscale) RemoveSharedMachineFromNamespace(m *Machine, ns *Namespace) return errorMachineNotShared } - err := h.RequestMapUpdates(ns.ID) + err := h.RequestMapUpdates(namespace.ID) if err != nil { return err } @@ -70,9 +76,9 @@ func (h *Headscale) RemoveSharedMachineFromNamespace(m *Machine, ns *Namespace) } // RemoveSharedMachineFromAllNamespaces removes a machine as a shared node from all namespaces. -func (h *Headscale) RemoveSharedMachineFromAllNamespaces(m *Machine) error { +func (h *Headscale) RemoveSharedMachineFromAllNamespaces(machine *Machine) error { sharedMachine := SharedMachine{} - if result := h.db.Where("machine_id = ?", m.ID).Unscoped().Delete(&sharedMachine); result.Error != nil { + if result := h.db.Where("machine_id = ?", machine.ID).Unscoped().Delete(&sharedMachine); result.Error != nil { return result.Error } diff --git a/swagger.go b/swagger.go index 01b2eb5..9e62d39 100644 --- a/swagger.go +++ b/swagger.go @@ -13,8 +13,8 @@ import ( //go:embed gen/openapiv2/headscale/v1/headscale.swagger.json var apiV1JSON []byte -func SwaggerUI(c *gin.Context) { - t := template.Must(template.New("swagger").Parse(` +func SwaggerUI(ctx *gin.Context) { + swaggerTemplate := template.Must(template.New("swagger").Parse(` @@ -47,12 +47,12 @@ func SwaggerUI(c *gin.Context) { `)) var payload bytes.Buffer - if err := t.Execute(&payload, struct{}{}); err != nil { + if err := swaggerTemplate.Execute(&payload, struct{}{}); err != nil { log.Error(). Caller(). Err(err). Msg("Could not render Swagger") - c.Data( + ctx.Data( http.StatusInternalServerError, "text/html; charset=utf-8", []byte("Could not render Swagger"), @@ -61,9 +61,9 @@ func SwaggerUI(c *gin.Context) { return } - c.Data(http.StatusOK, "text/html; charset=utf-8", payload.Bytes()) + ctx.Data(http.StatusOK, "text/html; charset=utf-8", payload.Bytes()) } -func SwaggerAPIv1(c *gin.Context) { - c.Data(http.StatusOK, "application/json; charset=utf-8", apiV1JSON) +func SwaggerAPIv1(ctx *gin.Context) { + ctx.Data(http.StatusOK, "application/json; charset=utf-8", apiV1JSON) } diff --git a/utils.go b/utils.go index 85e3ba2..803cfc5 100644 --- a/utils.go +++ b/utils.go @@ -36,7 +36,7 @@ func decode( func decodeMsg( msg []byte, - v interface{}, + output interface{}, pubKey *wgkey.Key, privKey *wgkey.Private, ) error { @@ -45,7 +45,7 @@ func decodeMsg( return err } // fmt.Println(string(decrypted)) - if err := json.Unmarshal(decrypted, v); err != nil { + if err := json.Unmarshal(decrypted, output); err != nil { return fmt.Errorf("response: %v", err) } @@ -78,13 +78,17 @@ func encode(v interface{}, pubKey *wgkey.Key, privKey *wgkey.Private) ([]byte, e return encodeMsg(b, pubKey, privKey) } -func encodeMsg(b []byte, pubKey *wgkey.Key, privKey *wgkey.Private) ([]byte, error) { +func encodeMsg( + payload []byte, + pubKey *wgkey.Key, + privKey *wgkey.Private, +) ([]byte, error) { var nonce [24]byte if _, err := io.ReadFull(rand.Reader, nonce[:]); err != nil { panic(err) } pub, pri := (*[32]byte)(pubKey), (*[32]byte)(privKey) - msg := box.Seal(nonce[:], b, &nonce, pub, pri) + msg := box.Seal(nonce[:], payload, &nonce, pub, pri) return msg, nil }