Add and fix stylecheck (golint replacement)

This commit is contained in:
Kristoffer Dalby 2021-11-15 17:24:24 +00:00
parent 0c005a6b01
commit 715542ac1c
No known key found for this signature in database
GPG key ID: 09F62DC067465735
21 changed files with 83 additions and 83 deletions

View file

@ -29,7 +29,6 @@ linters:
- gocritic - gocritic
# We should strive to enable these: # We should strive to enable these:
- stylecheck
- wrapcheck - wrapcheck
- goerr113 - goerr113
- forcetypeassert - forcetypeassert

24
acls.go
View file

@ -25,11 +25,11 @@ const (
) )
const ( const (
PORT_RANGE_BEGIN = 0 Base10 = 10
PORT_RANGE_END = 65535 BitSize16 = 16
BASE_10 = 10 portRangeBegin = 0
BIT_SIZE_16 = 16 portRangeEnd = 65535
EXPECTED_TOKEN_ITEMS = 2 expectedTokenItems = 2
) )
// LoadACLPolicy loads the ACL policy from the specify path, and generates the ACL rules. // LoadACLPolicy loads the ACL policy from the specify path, and generates the ACL rules.
@ -122,7 +122,7 @@ func (h *Headscale) generateACLPolicyDestPorts(
d string, d string,
) ([]tailcfg.NetPortRange, error) { ) ([]tailcfg.NetPortRange, error) {
tokens := strings.Split(d, ":") tokens := strings.Split(d, ":")
if len(tokens) < EXPECTED_TOKEN_ITEMS || len(tokens) > 3 { if len(tokens) < expectedTokenItems || len(tokens) > 3 {
return nil, errInvalidPortFormat return nil, errInvalidPortFormat
} }
@ -133,7 +133,7 @@ func (h *Headscale) generateACLPolicyDestPorts(
// tag:montreal-webserver:80,443 // tag:montreal-webserver:80,443
// tag:api-server:443 // tag:api-server:443
// example-host-1:* // example-host-1:*
if len(tokens) == EXPECTED_TOKEN_ITEMS { if len(tokens) == expectedTokenItems {
alias = tokens[0] alias = tokens[0]
} else { } else {
alias = fmt.Sprintf("%s:%s", tokens[0], tokens[1]) alias = fmt.Sprintf("%s:%s", tokens[0], tokens[1])
@ -257,7 +257,7 @@ func (h *Headscale) expandAlias(alias string) ([]string, error) {
func (h *Headscale) expandPorts(portsStr string) (*[]tailcfg.PortRange, error) { func (h *Headscale) expandPorts(portsStr string) (*[]tailcfg.PortRange, error) {
if portsStr == "*" { if portsStr == "*" {
return &[]tailcfg.PortRange{ return &[]tailcfg.PortRange{
{First: PORT_RANGE_BEGIN, Last: PORT_RANGE_END}, {First: portRangeBegin, Last: portRangeEnd},
}, nil }, nil
} }
@ -266,7 +266,7 @@ func (h *Headscale) expandPorts(portsStr string) (*[]tailcfg.PortRange, error) {
rang := strings.Split(portStr, "-") rang := strings.Split(portStr, "-")
switch len(rang) { switch len(rang) {
case 1: case 1:
port, err := strconv.ParseUint(rang[0], BASE_10, BIT_SIZE_16) port, err := strconv.ParseUint(rang[0], Base10, BitSize16)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -275,12 +275,12 @@ func (h *Headscale) expandPorts(portsStr string) (*[]tailcfg.PortRange, error) {
Last: uint16(port), Last: uint16(port),
}) })
case EXPECTED_TOKEN_ITEMS: case expectedTokenItems:
start, err := strconv.ParseUint(rang[0], BASE_10, BIT_SIZE_16) start, err := strconv.ParseUint(rang[0], Base10, BitSize16)
if err != nil { if err != nil {
return nil, err return nil, err
} }
last, err := strconv.ParseUint(rang[1], BASE_10, BIT_SIZE_16) last, err := strconv.ParseUint(rang[1], Base10, BitSize16)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -43,18 +43,18 @@ type ACLTest struct {
// UnmarshalJSON allows to parse the Hosts directly into netaddr objects. // UnmarshalJSON allows to parse the Hosts directly into netaddr objects.
func (hosts *Hosts) UnmarshalJSON(data []byte) error { func (hosts *Hosts) UnmarshalJSON(data []byte) error {
newHosts := Hosts{} newHosts := Hosts{}
hostIpPrefixMap := make(map[string]string) hostIPPrefixMap := make(map[string]string)
ast, err := hujson.Parse(data) ast, err := hujson.Parse(data)
if err != nil { if err != nil {
return err return err
} }
ast.Standardize() ast.Standardize()
data = ast.Pack() data = ast.Pack()
err = json.Unmarshal(data, &hostIpPrefixMap) err = json.Unmarshal(data, &hostIPPrefixMap)
if err != nil { if err != nil {
return err return err
} }
for host, prefixStr := range hostIpPrefixMap { for host, prefixStr := range hostIPPrefixMap {
if !strings.Contains(prefixStr, "/") { if !strings.Contains(prefixStr, "/") {
prefixStr += "/32" prefixStr += "/32"
} }

6
api.go
View file

@ -18,7 +18,7 @@ import (
"tailscale.com/types/wgkey" "tailscale.com/types/wgkey"
) )
const RESERVED_RESPONSE_HEADER_SIZE = 4 const reservedResponseHeaderSize = 4
// KeyHandler provides the Headscale pub key // KeyHandler provides the Headscale pub key
// Listens in /key. // Listens in /key.
@ -367,7 +367,7 @@ func (h *Headscale) getMapResponse(
} }
} }
// declare the incoming size on the first 4 bytes // declare the incoming size on the first 4 bytes
data := make([]byte, RESERVED_RESPONSE_HEADER_SIZE) data := make([]byte, reservedResponseHeaderSize)
binary.LittleEndian.PutUint32(data, uint32(len(respBody))) binary.LittleEndian.PutUint32(data, uint32(len(respBody)))
data = append(data, respBody...) data = append(data, respBody...)
@ -397,7 +397,7 @@ func (h *Headscale) getMapKeepAliveResponse(
return nil, err return nil, err
} }
} }
data := make([]byte, RESERVED_RESPONSE_HEADER_SIZE) data := make([]byte, reservedResponseHeaderSize)
binary.LittleEndian.PutUint32(data, uint32(len(respBody))) binary.LittleEndian.PutUint32(data, uint32(len(respBody)))
data = append(data, respBody...) data = append(data, respBody...)

24
app.go
View file

@ -47,11 +47,11 @@ import (
) )
const ( const (
AUTH_PREFIX = "Bearer " AuthPrefix = "Bearer "
POSTGRESQL = "postgresql" Postgres = "postgresql"
SQLITE = "sqlite3" Sqlite = "sqlite3"
UPDATE_RATE_MILLISECONDS = 5000 updateInterval = 5000
HTTP_READ_TIMEOUT = 30 * time.Second HTTPReadTimeout = 30 * time.Second
) )
// Config contains the initial Headscale configuration. // Config contains the initial Headscale configuration.
@ -154,7 +154,7 @@ func NewHeadscale(cfg Config) (*Headscale, error) {
var dbString string var dbString string
switch cfg.DBtype { switch cfg.DBtype {
case POSTGRESQL: case Postgres:
dbString = fmt.Sprintf( dbString = fmt.Sprintf(
"host=%s port=%d dbname=%s user=%s password=%s sslmode=disable", "host=%s port=%d dbname=%s user=%s password=%s sslmode=disable",
cfg.DBhost, cfg.DBhost,
@ -163,7 +163,7 @@ func NewHeadscale(cfg Config) (*Headscale, error) {
cfg.DBuser, cfg.DBuser,
cfg.DBpass, cfg.DBpass,
) )
case SQLITE: case Sqlite:
dbString = cfg.DBpath dbString = cfg.DBpath
default: default:
return nil, errors.New("unsupported DB") return nil, errors.New("unsupported DB")
@ -321,7 +321,7 @@ func (h *Headscale) grpcAuthenticationInterceptor(ctx context.Context,
token := authHeader[0] token := authHeader[0]
if !strings.HasPrefix(token, AUTH_PREFIX) { if !strings.HasPrefix(token, AuthPrefix) {
log.Error(). log.Error().
Caller(). Caller().
Str("client_address", client.Addr.String()). Str("client_address", client.Addr.String()).
@ -363,7 +363,7 @@ func (h *Headscale) httpAuthenticationMiddleware(ctx *gin.Context) {
authHeader := ctx.GetHeader("authorization") authHeader := ctx.GetHeader("authorization")
if !strings.HasPrefix(authHeader, AUTH_PREFIX) { if !strings.HasPrefix(authHeader, AuthPrefix) {
log.Error(). log.Error().
Caller(). Caller().
Str("client_address", ctx.ClientIP()). Str("client_address", ctx.ClientIP()).
@ -511,13 +511,13 @@ func (h *Headscale) Serve() error {
} }
// I HATE THIS // I HATE THIS
go h.watchForKVUpdates(UPDATE_RATE_MILLISECONDS) go h.watchForKVUpdates(updateInterval)
go h.expireEphemeralNodes(UPDATE_RATE_MILLISECONDS) go h.expireEphemeralNodes(updateInterval)
httpServer := &http.Server{ httpServer := &http.Server{
Addr: h.cfg.Addr, Addr: h.cfg.Addr,
Handler: router, Handler: router,
ReadTimeout: HTTP_READ_TIMEOUT, ReadTimeout: HTTPReadTimeout,
// Go does not handle timeouts in HTTP very well, and there is // Go does not handle timeouts in HTTP very well, and there is
// no good way to handle streaming timeouts, therefore we need to // no good way to handle streaming timeouts, therefore we need to
// keep this at unlimited and be careful to clean up connections // keep this at unlimited and be careful to clean up connections

View file

@ -55,7 +55,7 @@ func (h *Headscale) AppleMobileConfig(ctx *gin.Context) {
<p>Or</p> <p>Or</p>
<p>Use your terminal to configure the default setting for Tailscale by issuing:</p> <p>Use your terminal to configure the default setting for Tailscale by issuing:</p>
<code>defaults write io.tailscale.ipn.macos ControlURL {{.Url}}</code> <code>defaults write io.tailscale.ipn.macos ControlURL {{.URL}}</code>
<p>Restart Tailscale.app and log in.</p> <p>Restart Tailscale.app and log in.</p>
@ -63,7 +63,7 @@ func (h *Headscale) AppleMobileConfig(ctx *gin.Context) {
</html>`)) </html>`))
config := map[string]interface{}{ config := map[string]interface{}{
"Url": h.cfg.ServerURL, "URL": h.cfg.ServerURL,
} }
var payload bytes.Buffer var payload bytes.Buffer
@ -102,7 +102,7 @@ func (h *Headscale) ApplePlatformConfig(ctx *gin.Context) {
return return
} }
contentId, err := uuid.NewV4() contentID, err := uuid.NewV4()
if err != nil { if err != nil {
log.Error(). log.Error().
Str("handler", "ApplePlatformConfig"). Str("handler", "ApplePlatformConfig").
@ -118,8 +118,8 @@ func (h *Headscale) ApplePlatformConfig(ctx *gin.Context) {
} }
platformConfig := AppleMobilePlatformConfig{ platformConfig := AppleMobilePlatformConfig{
UUID: contentId, UUID: contentID,
Url: h.cfg.ServerURL, URL: h.cfg.ServerURL,
} }
var payload bytes.Buffer var payload bytes.Buffer
@ -165,7 +165,7 @@ func (h *Headscale) ApplePlatformConfig(ctx *gin.Context) {
config := AppleMobileConfig{ config := AppleMobileConfig{
UUID: id, UUID: id,
Url: h.cfg.ServerURL, URL: h.cfg.ServerURL,
Payload: payload.String(), Payload: payload.String(),
} }
@ -193,13 +193,13 @@ func (h *Headscale) ApplePlatformConfig(ctx *gin.Context) {
type AppleMobileConfig struct { type AppleMobileConfig struct {
UUID uuid.UUID UUID uuid.UUID
Url string URL string
Payload string Payload string
} }
type AppleMobilePlatformConfig struct { type AppleMobilePlatformConfig struct {
UUID uuid.UUID UUID uuid.UUID
Url string URL string
} }
var commonTemplate = template.Must( var commonTemplate = template.Must(
@ -212,7 +212,7 @@ var commonTemplate = template.Must(
<key>PayloadDisplayName</key> <key>PayloadDisplayName</key>
<string>Headscale</string> <string>Headscale</string>
<key>PayloadDescription</key> <key>PayloadDescription</key>
<string>Configure Tailscale login server to: {{.Url}}</string> <string>Configure Tailscale login server to: {{.URL}}</string>
<key>PayloadIdentifier</key> <key>PayloadIdentifier</key>
<string>com.github.juanfont.headscale</string> <string>com.github.juanfont.headscale</string>
<key>PayloadRemovalDisallowed</key> <key>PayloadRemovalDisallowed</key>
@ -243,7 +243,7 @@ var iosTemplate = template.Must(template.New("iosTemplate").Parse(`
<true/> <true/>
<key>ControlURL</key> <key>ControlURL</key>
<string>{{.Url}}</string> <string>{{.URL}}</string>
</dict> </dict>
`)) `))
@ -261,6 +261,6 @@ var macosTemplate = template.Must(template.New("macosTemplate").Parse(`
<true/> <true/>
<key>ControlURL</key> <key>ControlURL</key>
<string>{{.Url}}</string> <string>{{.URL}}</string>
</dict> </dict>
`)) `))

View file

@ -29,7 +29,7 @@ var createNamespaceCmd = &cobra.Command{
Short: "Creates a new namespace", Short: "Creates a new namespace",
Args: func(cmd *cobra.Command, args []string) error { Args: func(cmd *cobra.Command, args []string) error {
if len(args) < 1 { if len(args) < 1 {
return fmt.Errorf("Missing parameters") return fmt.Errorf("missing parameters")
} }
return nil return nil
@ -71,7 +71,7 @@ var destroyNamespaceCmd = &cobra.Command{
Short: "Destroys a namespace", Short: "Destroys a namespace",
Args: func(cmd *cobra.Command, args []string) error { Args: func(cmd *cobra.Command, args []string) error {
if len(args) < 1 { if len(args) < 1 {
return fmt.Errorf("Missing parameters") return fmt.Errorf("missing parameters")
} }
return nil return nil
@ -197,7 +197,7 @@ var renameNamespaceCmd = &cobra.Command{
Args: func(cmd *cobra.Command, args []string) error { Args: func(cmd *cobra.Command, args []string) error {
expectedArguments := 2 expectedArguments := 2
if len(args) < expectedArguments { if len(args) < expectedArguments {
return fmt.Errorf("Missing parameters") return fmt.Errorf("missing parameters")
} }
return nil return nil

View file

@ -451,7 +451,7 @@ func nodesToPtables(
tableData = append( tableData = append(
tableData, tableData,
[]string{ []string{
strconv.FormatUint(machine.Id, headscale.BASE_10), strconv.FormatUint(machine.Id, headscale.Base10),
machine.Name, machine.Name,
nodeKey.ShortString(), nodeKey.ShortString(),
namespace, namespace,

View file

@ -13,7 +13,7 @@ import (
) )
const ( const (
DEFAULT_PRE_AUTH_KEY_EXPIRY = 24 * time.Hour DefaultPreAuthKeyExpiry = 24 * time.Hour
) )
func init() { func init() {
@ -31,7 +31,7 @@ func init() {
createPreAuthKeyCmd.PersistentFlags(). createPreAuthKeyCmd.PersistentFlags().
Bool("ephemeral", false, "Preauthkey for ephemeral nodes") Bool("ephemeral", false, "Preauthkey for ephemeral nodes")
createPreAuthKeyCmd.Flags(). createPreAuthKeyCmd.Flags().
DurationP("expiration", "e", DEFAULT_PRE_AUTH_KEY_EXPIRY, "Human-readable expiration of the key (30m, 24h, 365d...)") DurationP("expiration", "e", DefaultPreAuthKeyExpiry, "Human-readable expiration of the key (30m, 24h, 365d...)")
} }
var preauthkeysCmd = &cobra.Command{ var preauthkeysCmd = &cobra.Command{

View file

@ -45,7 +45,7 @@ var listRoutesCmd = &cobra.Command{
Run: func(cmd *cobra.Command, args []string) { Run: func(cmd *cobra.Command, args []string) {
output, _ := cmd.Flags().GetString("output") output, _ := cmd.Flags().GetString("output")
machineId, err := cmd.Flags().GetUint64("identifier") machineID, err := cmd.Flags().GetUint64("identifier")
if err != nil { if err != nil {
ErrorOutput( ErrorOutput(
err, err,
@ -61,7 +61,7 @@ var listRoutesCmd = &cobra.Command{
defer conn.Close() defer conn.Close()
request := &v1.GetMachineRouteRequest{ request := &v1.GetMachineRouteRequest{
MachineId: machineId, MachineId: machineID,
} }
response, err := client.GetMachineRoute(ctx, request) response, err := client.GetMachineRoute(ctx, request)
@ -111,7 +111,8 @@ omit the route you do not want to enable.
`, `,
Run: func(cmd *cobra.Command, args []string) { Run: func(cmd *cobra.Command, args []string) {
output, _ := cmd.Flags().GetString("output") output, _ := cmd.Flags().GetString("output")
machineId, err := cmd.Flags().GetUint64("identifier")
machineID, err := cmd.Flags().GetUint64("identifier")
if err != nil { if err != nil {
ErrorOutput( ErrorOutput(
err, err,
@ -138,7 +139,7 @@ omit the route you do not want to enable.
defer conn.Close() defer conn.Close()
request := &v1.EnableMachineRoutesRequest{ request := &v1.EnableMachineRoutesRequest{
MachineId: machineId, MachineId: machineID,
Routes: routes, Routes: routes,
} }

View file

@ -53,7 +53,7 @@ func LoadConfig(path string) error {
viper.SetDefault("cli.timeout", "5s") viper.SetDefault("cli.timeout", "5s")
if err := viper.ReadInConfig(); err != nil { if err := viper.ReadInConfig(); err != nil {
return fmt.Errorf("Fatal error reading config file: %w", err) return fmt.Errorf("fatal error reading config file: %w", err)
} }
// Collect any validation errors and return them all at once // Collect any validation errors and return them all at once
@ -306,7 +306,7 @@ func getHeadscaleApp() (*headscale.Headscale, error) {
minInactivityTimeout, _ := time.ParseDuration("65s") minInactivityTimeout, _ := time.ParseDuration("65s")
if viper.GetDuration("ephemeral_node_inactivity_timeout") <= minInactivityTimeout { if viper.GetDuration("ephemeral_node_inactivity_timeout") <= minInactivityTimeout {
err := fmt.Errorf( err := fmt.Errorf(
"ephemeral_node_inactivity_timeout (%s) is set too low, must be more than %s\n", "ephemeral_node_inactivity_timeout (%s) is set too low, must be more than %s",
viper.GetString("ephemeral_node_inactivity_timeout"), viper.GetString("ephemeral_node_inactivity_timeout"),
minInactivityTimeout, minInactivityTimeout,
) )

6
db.go
View file

@ -24,7 +24,7 @@ func (h *Headscale) initDB() error {
} }
h.db = db h.db = db
if h.dbType == POSTGRESQL { if h.dbType == Postgres {
db.Exec("create extension if not exists \"uuid-ossp\";") db.Exec("create extension if not exists \"uuid-ossp\";")
} }
err = db.AutoMigrate(&Machine{}) err = db.AutoMigrate(&Machine{})
@ -66,12 +66,12 @@ func (h *Headscale) openDB() (*gorm.DB, error) {
} }
switch h.dbType { switch h.dbType {
case SQLITE: case Sqlite:
db, err = gorm.Open(sqlite.Open(h.dbString), &gorm.Config{ db, err = gorm.Open(sqlite.Open(h.dbString), &gorm.Config{
DisableForeignKeyConstraintWhenMigrating: true, DisableForeignKeyConstraintWhenMigrating: true,
Logger: log, Logger: log,
}) })
case POSTGRESQL: case Postgres:
db, err = gorm.Open(postgres.Open(h.dbString), &gorm.Config{ db, err = gorm.Open(postgres.Open(h.dbString), &gorm.Config{
DisableForeignKeyConstraintWhenMigrating: true, DisableForeignKeyConstraintWhenMigrating: true,
Logger: log, Logger: log,

View file

@ -32,7 +32,7 @@ func loadDERPMapFromPath(path string) (*tailcfg.DERPMap, error) {
} }
func loadDERPMapFromURL(addr url.URL) (*tailcfg.DERPMap, error) { func loadDERPMapFromURL(addr url.URL) (*tailcfg.DERPMap, error) {
ctx, cancel := context.WithTimeout(context.Background(), HTTP_READ_TIMEOUT) ctx, cancel := context.WithTimeout(context.Background(), HTTPReadTimeout)
defer cancel() defer cancel()
req, err := http.NewRequestWithContext(ctx, "GET", addr.String(), nil) req, err := http.NewRequestWithContext(ctx, "GET", addr.String(), nil)
@ -41,7 +41,7 @@ func loadDERPMapFromURL(addr url.URL) (*tailcfg.DERPMap, error) {
} }
client := http.Client{ client := http.Client{
Timeout: HTTP_READ_TIMEOUT, Timeout: HTTPReadTimeout,
} }
resp, err := client.Do(req) resp, err := client.Do(req)

6
dns.go
View file

@ -11,7 +11,7 @@ import (
) )
const ( const (
BYTE_SIZE = 8 ByteSize = 8
) )
// generateMagicDNSRootDomains generates a list of DNS entries to be included in `Routes` in `MapResponse`. // generateMagicDNSRootDomains generates a list of DNS entries to be included in `Routes` in `MapResponse`.
@ -47,10 +47,10 @@ func generateMagicDNSRootDomains(
maskBits, _ := netRange.Mask.Size() maskBits, _ := netRange.Mask.Size()
// lastOctet is the last IP byte covered by the mask // lastOctet is the last IP byte covered by the mask
lastOctet := maskBits / BYTE_SIZE lastOctet := maskBits / ByteSize
// wildcardBits is the number of bits not under the mask in the lastOctet // wildcardBits is the number of bits not under the mask in the lastOctet
wildcardBits := BYTE_SIZE - maskBits%BYTE_SIZE wildcardBits := ByteSize - maskBits%ByteSize
// min is the value in the lastOctet byte of the IP // min is the value in the lastOctet byte of the IP
// max is basically 2^wildcardBits - i.e., the value when all the wildcardBits are set to 1 // max is basically 2^wildcardBits - i.e., the value when all the wildcardBits are set to 1

View file

@ -338,7 +338,7 @@ func (api headscaleV1APIServer) DebugCreateMachine(
return nil, err return nil, err
} }
routes, err := stringToIpPrefix(request.GetRoutes()) routes, err := stringToIPPrefix(request.GetRoutes())
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -529,7 +529,7 @@ func (machine Machine) toNode(
node := tailcfg.Node{ node := tailcfg.Node{
ID: tailcfg.NodeID(machine.ID), // this is the actual ID ID: tailcfg.NodeID(machine.ID), // this is the actual ID
StableID: tailcfg.StableNodeID( StableID: tailcfg.StableNodeID(
strconv.FormatUint(machine.ID, BASE_10), strconv.FormatUint(machine.ID, Base10),
), // in headscale, unlike tailcontrol server, IDs are permanent ), // in headscale, unlike tailcontrol server, IDs are permanent
Name: hostname, Name: hostname,
User: tailcfg.UserID(machine.NamespaceID), User: tailcfg.UserID(machine.NamespaceID),
@ -736,7 +736,7 @@ func (h *Headscale) EnableRoutes(machine *Machine, routeStrs ...string) error {
} }
for _, newRoute := range newRoutes { for _, newRoute := range newRoutes {
if !containsIpPrefix(availableRoutes, newRoute) { if !containsIPPrefix(availableRoutes, newRoute) {
return fmt.Errorf( return fmt.Errorf(
"route (%s) is not available on node %s", "route (%s) is not available on node %s",
machine.Name, machine.Name,

View file

@ -321,7 +321,7 @@ func getMapResponseUserProfiles(machine Machine, peers Machines) []tailcfg.UserP
func (n *Namespace) toProto() *v1.Namespace { func (n *Namespace) toProto() *v1.Namespace {
return &v1.Namespace{ return &v1.Namespace{
Id: strconv.FormatUint(uint64(n.ID), BASE_10), Id: strconv.FormatUint(uint64(n.ID), Base10),
Name: n.Name, Name: n.Name,
CreatedAt: timestamppb.New(n.CreatedAt), CreatedAt: timestamppb.New(n.CreatedAt),
} }

20
oidc.go
View file

@ -18,9 +18,9 @@ import (
) )
const ( const (
OIDC_STATE_CACHE_EXPIRATION = time.Minute * 5 oidcStateCacheExpiration = time.Minute * 5
OIDC_STATE_CACHE_CLEANUP_INTERVAL = time.Minute * 10 oidcStateCacheCleanupInterval = time.Minute * 10
RANDOM_BYTE_SIZE = 16 randomByteSize = 16
) )
type IDTokenClaims struct { type IDTokenClaims struct {
@ -57,8 +57,8 @@ func (h *Headscale) initOIDC() error {
// init the state cache if it hasn't been already // init the state cache if it hasn't been already
if h.oidcStateCache == nil { if h.oidcStateCache == nil {
h.oidcStateCache = cache.New( h.oidcStateCache = cache.New(
OIDC_STATE_CACHE_EXPIRATION, oidcStateCacheExpiration,
OIDC_STATE_CACHE_CLEANUP_INTERVAL, oidcStateCacheCleanupInterval,
) )
} }
@ -76,7 +76,7 @@ func (h *Headscale) RegisterOIDC(ctx *gin.Context) {
return return
} }
randomBlob := make([]byte, RANDOM_BYTE_SIZE) randomBlob := make([]byte, randomByteSize)
if _, err := rand.Read(randomBlob); err != nil { if _, err := rand.Read(randomBlob); err != nil {
log.Error().Msg("could not read 16 bytes from rand") log.Error().Msg("could not read 16 bytes from rand")
ctx.String(http.StatusInternalServerError, "could not read 16 bytes from rand") ctx.String(http.StatusInternalServerError, "could not read 16 bytes from rand")
@ -87,12 +87,12 @@ func (h *Headscale) RegisterOIDC(ctx *gin.Context) {
stateStr := hex.EncodeToString(randomBlob)[:32] stateStr := hex.EncodeToString(randomBlob)[:32]
// place the machine key into the state cache, so it can be retrieved later // place the machine key into the state cache, so it can be retrieved later
h.oidcStateCache.Set(stateStr, mKeyStr, OIDC_STATE_CACHE_EXPIRATION) h.oidcStateCache.Set(stateStr, mKeyStr, oidcStateCacheExpiration)
authUrl := h.oauth2Config.AuthCodeURL(stateStr) authURL := h.oauth2Config.AuthCodeURL(stateStr)
log.Debug().Msgf("Redirecting to %s for authentication", authUrl) log.Debug().Msgf("Redirecting to %s for authentication", authURL)
ctx.Redirect(http.StatusFound, authUrl) ctx.Redirect(http.StatusFound, authURL)
} }
// OIDCCallback handles the callback from the OIDC endpoint // OIDCCallback handles the callback from the OIDC endpoint

View file

@ -16,8 +16,8 @@ import (
) )
const ( const (
KEEP_ALIVE_INTERVAL = 60 * time.Second keepAliveInterval = 60 * time.Second
UPDATE_CHECK_INTERVAL = 10 * time.Second updateCheckInterval = 10 * time.Second
) )
// PollNetMapHandler takes care of /machine/:id/map // PollNetMapHandler takes care of /machine/:id/map
@ -495,8 +495,8 @@ func (h *Headscale) scheduledPollWorker(
mapRequest tailcfg.MapRequest, mapRequest tailcfg.MapRequest,
machine *Machine, machine *Machine,
) { ) {
keepAliveTicker := time.NewTicker(KEEP_ALIVE_INTERVAL) keepAliveTicker := time.NewTicker(keepAliveInterval)
updateCheckerTicker := time.NewTicker(UPDATE_CHECK_INTERVAL) updateCheckerTicker := time.NewTicker(updateCheckInterval)
for { for {
select { select {

View file

@ -156,7 +156,7 @@ func (h *Headscale) generateKey() (string, error) {
func (key *PreAuthKey) toProto() *v1.PreAuthKey { func (key *PreAuthKey) toProto() *v1.PreAuthKey {
protoKey := v1.PreAuthKey{ protoKey := v1.PreAuthKey{
Namespace: key.Namespace.Name, Namespace: key.Namespace.Name,
Id: strconv.FormatUint(key.ID, BASE_10), Id: strconv.FormatUint(key.ID, Base10),
Key: key.Key, Key: key.Key,
Ephemeral: key.Ephemeral, Ephemeral: key.Ephemeral,
Reusable: key.Reusable, Reusable: key.Reusable,

View file

@ -197,7 +197,7 @@ func ipPrefixToString(prefixes []netaddr.IPPrefix) []string {
return result return result
} }
func stringToIpPrefix(prefixes []string) ([]netaddr.IPPrefix, error) { func stringToIPPrefix(prefixes []string) ([]netaddr.IPPrefix, error) {
result := make([]netaddr.IPPrefix, len(prefixes)) result := make([]netaddr.IPPrefix, len(prefixes))
for index, prefixStr := range prefixes { for index, prefixStr := range prefixes {
@ -212,7 +212,7 @@ func stringToIpPrefix(prefixes []string) ([]netaddr.IPPrefix, error) {
return result, nil return result, nil
} }
func containsIpPrefix(prefixes []netaddr.IPPrefix, prefix netaddr.IPPrefix) bool { func containsIPPrefix(prefixes []netaddr.IPPrefix, prefix netaddr.IPPrefix) bool {
for _, p := range prefixes { for _, p := range prefixes {
if prefix == p { if prefix == p {
return true return true