Implement namespace matching
This commit is contained in:
parent
a347d276bd
commit
677bd9b657
5 changed files with 267 additions and 55 deletions
4
api.go
4
api.go
|
@ -170,7 +170,7 @@ func (h *Headscale) RegistrationHandler(c *gin.Context) {
|
||||||
Str("machine", m.Name).
|
Str("machine", m.Name).
|
||||||
Msg("Machine registration has expired. Sending a authurl to register")
|
Msg("Machine registration has expired. Sending a authurl to register")
|
||||||
|
|
||||||
if h.cfg.OIDCIssuer != "" {
|
if h.cfg.OIDC.Issuer != "" {
|
||||||
resp.AuthURL = fmt.Sprintf("%s/oidc/register/%s",
|
resp.AuthURL = fmt.Sprintf("%s/oidc/register/%s",
|
||||||
strings.TrimSuffix(h.cfg.ServerURL, "/"), mKey.HexString())
|
strings.TrimSuffix(h.cfg.ServerURL, "/"), mKey.HexString())
|
||||||
} else {
|
} else {
|
||||||
|
@ -225,7 +225,7 @@ func (h *Headscale) RegistrationHandler(c *gin.Context) {
|
||||||
Str("handler", "Registration").
|
Str("handler", "Registration").
|
||||||
Str("machine", m.Name).
|
Str("machine", m.Name).
|
||||||
Msg("The node is sending us a new NodeKey, sending auth url")
|
Msg("The node is sending us a new NodeKey, sending auth url")
|
||||||
if h.cfg.OIDCIssuer != "" {
|
if h.cfg.OIDC.Issuer != "" {
|
||||||
resp.AuthURL = fmt.Sprintf("%s/oidc/register/%s", strings.TrimSuffix(h.cfg.ServerURL, "/"), mKey.HexString())
|
resp.AuthURL = fmt.Sprintf("%s/oidc/register/%s", strings.TrimSuffix(h.cfg.ServerURL, "/"), mKey.HexString())
|
||||||
} else {
|
} else {
|
||||||
resp.AuthURL = fmt.Sprintf("%s/register?key=%s",
|
resp.AuthURL = fmt.Sprintf("%s/register?key=%s",
|
||||||
|
|
22
app.go
22
app.go
|
@ -3,9 +3,6 @@ package headscale
|
||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/coreos/go-oidc/v3/oidc"
|
|
||||||
"github.com/patrickmn/go-cache"
|
|
||||||
"golang.org/x/oauth2"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
"sort"
|
"sort"
|
||||||
|
@ -13,6 +10,10 @@ import (
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/coreos/go-oidc/v3/oidc"
|
||||||
|
"github.com/patrickmn/go-cache"
|
||||||
|
"golang.org/x/oauth2"
|
||||||
|
|
||||||
"github.com/rs/zerolog/log"
|
"github.com/rs/zerolog/log"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
|
@ -57,14 +58,19 @@ type Config struct {
|
||||||
|
|
||||||
DNSConfig *tailcfg.DNSConfig
|
DNSConfig *tailcfg.DNSConfig
|
||||||
|
|
||||||
OIDCIssuer string
|
OIDC OIDCConfig
|
||||||
OIDCClientID string
|
|
||||||
OIDCClientSecret string
|
|
||||||
|
|
||||||
MaxMachineRegistrationDuration time.Duration
|
MaxMachineRegistrationDuration time.Duration
|
||||||
DefaultMachineRegistrationDuration time.Duration
|
DefaultMachineRegistrationDuration time.Duration
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type OIDCConfig struct {
|
||||||
|
Issuer string
|
||||||
|
ClientID string
|
||||||
|
ClientSecret string
|
||||||
|
MatchMap map[string]string
|
||||||
|
}
|
||||||
|
|
||||||
// Headscale represents the base app of the service
|
// Headscale represents the base app of the service
|
||||||
type Headscale struct {
|
type Headscale struct {
|
||||||
cfg Config
|
cfg Config
|
||||||
|
@ -122,7 +128,7 @@ func NewHeadscale(cfg Config) (*Headscale, error) {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if cfg.OIDCIssuer != "" {
|
if cfg.OIDC.Issuer != "" {
|
||||||
err = h.initOIDC()
|
err = h.initOIDC()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -294,7 +300,6 @@ func (h *Headscale) getLastStateChange(namespaces ...string) time.Time {
|
||||||
|
|
||||||
times = append(times, lastChange)
|
times = append(times, lastChange)
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
sort.Slice(times, func(i, j int) bool {
|
sort.Slice(times, func(i, j int) bool {
|
||||||
|
@ -305,7 +310,6 @@ func (h *Headscale) getLastStateChange(namespaces ...string) time.Time {
|
||||||
|
|
||||||
if len(times) == 0 {
|
if len(times) == 0 {
|
||||||
return time.Now().UTC()
|
return time.Now().UTC()
|
||||||
|
|
||||||
} else {
|
} else {
|
||||||
return times[0]
|
return times[0]
|
||||||
}
|
}
|
||||||
|
|
|
@ -7,6 +7,7 @@ import (
|
||||||
"io"
|
"io"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
|
"regexp"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
@ -73,7 +74,6 @@ func LoadConfig(path string) error {
|
||||||
} else {
|
} else {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetDNSConfig() (*tailcfg.DNSConfig, string) {
|
func GetDNSConfig() (*tailcfg.DNSConfig, string) {
|
||||||
|
@ -206,15 +206,19 @@ func getHeadscaleApp() (*headscale.Headscale, error) {
|
||||||
ACMEEmail: viper.GetString("acme_email"),
|
ACMEEmail: viper.GetString("acme_email"),
|
||||||
ACMEURL: viper.GetString("acme_url"),
|
ACMEURL: viper.GetString("acme_url"),
|
||||||
|
|
||||||
OIDCIssuer: viper.GetString("oidc_issuer"),
|
OIDC: headscale.OIDCConfig{
|
||||||
OIDCClientID: viper.GetString("oidc_client_id"),
|
Issuer: viper.GetString("oidc.issuer"),
|
||||||
OIDCClientSecret: viper.GetString("oidc_client_secret"),
|
ClientID: viper.GetString("oidc.client_id"),
|
||||||
|
ClientSecret: viper.GetString("oidc.client_secret"),
|
||||||
|
},
|
||||||
|
|
||||||
MaxMachineRegistrationDuration: maxMachineRegistrationDuration, // the maximum duration a client may request for expiry time
|
MaxMachineRegistrationDuration: maxMachineRegistrationDuration, // the maximum duration a client may request for expiry time
|
||||||
DefaultMachineRegistrationDuration: defaultMachineRegistrationDuration, // if a client does not request a specific expiry time, use this duration
|
DefaultMachineRegistrationDuration: defaultMachineRegistrationDuration, // if a client does not request a specific expiry time, use this duration
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
cfg.OIDC.MatchMap = loadOIDCMatchMap()
|
||||||
|
|
||||||
h, err := headscale.NewHeadscale(cfg)
|
h, err := headscale.NewHeadscale(cfg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -291,3 +295,15 @@ func HasJsonOutputFlag() bool {
|
||||||
}
|
}
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// loadOIDCMatchMap is a wrapper around viper to verifies that the keys in
|
||||||
|
// the match map is valid regex strings.
|
||||||
|
func loadOIDCMatchMap() map[string]string {
|
||||||
|
strMap := viper.GetStringMapString("oidc.domain_map")
|
||||||
|
|
||||||
|
for oidcMatcher := range strMap {
|
||||||
|
_ = regexp.MustCompile(oidcMatcher)
|
||||||
|
}
|
||||||
|
|
||||||
|
return strMap
|
||||||
|
}
|
||||||
|
|
41
oidc.go
41
oidc.go
|
@ -5,14 +5,16 @@ import (
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"regexp"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/coreos/go-oidc/v3/oidc"
|
"github.com/coreos/go-oidc/v3/oidc"
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/patrickmn/go-cache"
|
"github.com/patrickmn/go-cache"
|
||||||
"github.com/rs/zerolog/log"
|
"github.com/rs/zerolog/log"
|
||||||
"golang.org/x/oauth2"
|
"golang.org/x/oauth2"
|
||||||
"net/http"
|
|
||||||
"strings"
|
|
||||||
"time"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type IDTokenClaims struct {
|
type IDTokenClaims struct {
|
||||||
|
@ -26,7 +28,7 @@ func (h *Headscale) initOIDC() error {
|
||||||
var err error
|
var err error
|
||||||
// grab oidc config if it hasn't been already
|
// grab oidc config if it hasn't been already
|
||||||
if h.oauth2Config == nil {
|
if h.oauth2Config == nil {
|
||||||
h.oidcProvider, err = oidc.NewProvider(context.Background(), h.cfg.OIDCIssuer)
|
h.oidcProvider, err = oidc.NewProvider(context.Background(), h.cfg.OIDC.Issuer)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().Msgf("Could not retrieve OIDC Config: %s", err.Error())
|
log.Error().Msgf("Could not retrieve OIDC Config: %s", err.Error())
|
||||||
|
@ -34,8 +36,8 @@ func (h *Headscale) initOIDC() error {
|
||||||
}
|
}
|
||||||
|
|
||||||
h.oauth2Config = &oauth2.Config{
|
h.oauth2Config = &oauth2.Config{
|
||||||
ClientID: h.cfg.OIDCClientID,
|
ClientID: h.cfg.OIDC.ClientID,
|
||||||
ClientSecret: h.cfg.OIDCClientSecret,
|
ClientSecret: h.cfg.OIDC.ClientSecret,
|
||||||
Endpoint: h.oidcProvider.Endpoint(),
|
Endpoint: h.oidcProvider.Endpoint(),
|
||||||
RedirectURL: fmt.Sprintf("%s/oidc/callback", strings.TrimSuffix(h.cfg.ServerURL, "/")),
|
RedirectURL: fmt.Sprintf("%s/oidc/callback", strings.TrimSuffix(h.cfg.ServerURL, "/")),
|
||||||
Scopes: []string{oidc.ScopeOpenID, "profile", "email"},
|
Scopes: []string{oidc.ScopeOpenID, "profile", "email"},
|
||||||
|
@ -62,7 +64,6 @@ func (h *Headscale) RegisterOIDC(c *gin.Context) {
|
||||||
|
|
||||||
b := make([]byte, 16)
|
b := make([]byte, 16)
|
||||||
_, err := rand.Read(b)
|
_, err := rand.Read(b)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().Msg("could not read 16 bytes from rand")
|
log.Error().Msg("could not read 16 bytes from rand")
|
||||||
c.String(http.StatusInternalServerError, "could not read 16 bytes from rand")
|
c.String(http.StatusInternalServerError, "could not read 16 bytes from rand")
|
||||||
|
@ -86,7 +87,6 @@ func (h *Headscale) RegisterOIDC(c *gin.Context) {
|
||||||
// TODO: Add groups information from OIDC tokens into machine HostInfo
|
// TODO: Add groups information from OIDC tokens into machine HostInfo
|
||||||
// Listens in /oidc/callback
|
// Listens in /oidc/callback
|
||||||
func (h *Headscale) OIDCCallback(c *gin.Context) {
|
func (h *Headscale) OIDCCallback(c *gin.Context) {
|
||||||
|
|
||||||
code := c.Query("code")
|
code := c.Query("code")
|
||||||
state := c.Query("state")
|
state := c.Query("state")
|
||||||
|
|
||||||
|
@ -109,7 +109,7 @@ func (h *Headscale) OIDCCallback(c *gin.Context) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
verifier := h.oidcProvider.Verifier(&oidc.Config{ClientID: h.cfg.OIDCClientID})
|
verifier := h.oidcProvider.Verifier(&oidc.Config{ClientID: h.cfg.OIDC.ClientID})
|
||||||
|
|
||||||
idToken, err := verifier.Verify(context.Background(), rawIDToken)
|
idToken, err := verifier.Verify(context.Background(), rawIDToken)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -149,7 +149,6 @@ func (h *Headscale) OIDCCallback(c *gin.Context) {
|
||||||
|
|
||||||
// retrieve machine information
|
// retrieve machine information
|
||||||
m, err := h.GetMachineByMachineKey(mKeyStr)
|
m, err := h.GetMachineByMachineKey(mKeyStr)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().Msg("machine key not found in database")
|
log.Error().Msg("machine key not found in database")
|
||||||
c.String(http.StatusInternalServerError, "could not get machine info from database")
|
c.String(http.StatusInternalServerError, "could not get machine info from database")
|
||||||
|
@ -158,9 +157,9 @@ func (h *Headscale) OIDCCallback(c *gin.Context) {
|
||||||
|
|
||||||
now := time.Now().UTC()
|
now := time.Now().UTC()
|
||||||
|
|
||||||
|
if nsName, ok := h.getNamespaceFromEmail(claims.Email); ok {
|
||||||
// register the machine if it's new
|
// register the machine if it's new
|
||||||
if !m.Registered {
|
if !m.Registered {
|
||||||
nsName := strings.ReplaceAll(claims.Email, "@", "-") // TODO: Implement a better email sanitisation
|
|
||||||
|
|
||||||
log.Debug().Msg("Registering new machine after successful callback")
|
log.Debug().Msg("Registering new machine after successful callback")
|
||||||
|
|
||||||
|
@ -202,4 +201,24 @@ func (h *Headscale) OIDCCallback(c *gin.Context) {
|
||||||
</html>
|
</html>
|
||||||
|
|
||||||
`, claims.Email)))
|
`, claims.Email)))
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Error().
|
||||||
|
Str("email", claims.Email).
|
||||||
|
Str("username", claims.Username).
|
||||||
|
Str("machine", m.Name).
|
||||||
|
Msg("Email could not be mapped to a namespace")
|
||||||
|
c.String(http.StatusBadRequest, "email from claim could not be mapped to a namespace")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Headscale) getNamespaceFromEmail(email string) (string, bool) {
|
||||||
|
for match, namespace := range h.cfg.OIDC.MatchMap {
|
||||||
|
regex := regexp.MustCompile(match)
|
||||||
|
if regex.MatchString(email) {
|
||||||
|
return namespace, true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return "", false
|
||||||
}
|
}
|
||||||
|
|
173
oidc_test.go
Normal file
173
oidc_test.go
Normal file
|
@ -0,0 +1,173 @@
|
||||||
|
package headscale
|
||||||
|
|
||||||
|
import (
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/coreos/go-oidc/v3/oidc"
|
||||||
|
"github.com/patrickmn/go-cache"
|
||||||
|
"golang.org/x/oauth2"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
"tailscale.com/tailcfg"
|
||||||
|
"tailscale.com/types/wgkey"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestHeadscale_getNamespaceFromEmail(t *testing.T) {
|
||||||
|
type fields struct {
|
||||||
|
cfg Config
|
||||||
|
db *gorm.DB
|
||||||
|
dbString string
|
||||||
|
dbType string
|
||||||
|
dbDebug bool
|
||||||
|
publicKey *wgkey.Key
|
||||||
|
privateKey *wgkey.Private
|
||||||
|
aclPolicy *ACLPolicy
|
||||||
|
aclRules *[]tailcfg.FilterRule
|
||||||
|
lastStateChange sync.Map
|
||||||
|
oidcProvider *oidc.Provider
|
||||||
|
oauth2Config *oauth2.Config
|
||||||
|
oidcStateCache *cache.Cache
|
||||||
|
}
|
||||||
|
type args struct {
|
||||||
|
email string
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
fields fields
|
||||||
|
args args
|
||||||
|
want string
|
||||||
|
want1 bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "match all",
|
||||||
|
fields: fields{
|
||||||
|
cfg: Config{
|
||||||
|
OIDC: OIDCConfig{
|
||||||
|
MatchMap: map[string]string{
|
||||||
|
".*": "space",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
args: args{
|
||||||
|
email: "test@example.no",
|
||||||
|
},
|
||||||
|
want: "space",
|
||||||
|
want1: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "match user",
|
||||||
|
fields: fields{
|
||||||
|
cfg: Config{
|
||||||
|
OIDC: OIDCConfig{
|
||||||
|
MatchMap: map[string]string{
|
||||||
|
"specific@user\\.no": "user-namespace",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
args: args{
|
||||||
|
email: "specific@user.no",
|
||||||
|
},
|
||||||
|
want: "user-namespace",
|
||||||
|
want1: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "match domain",
|
||||||
|
fields: fields{
|
||||||
|
cfg: Config{
|
||||||
|
OIDC: OIDCConfig{
|
||||||
|
MatchMap: map[string]string{
|
||||||
|
".*@example\\.no": "example",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
args: args{
|
||||||
|
email: "test@example.no",
|
||||||
|
},
|
||||||
|
want: "example",
|
||||||
|
want1: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "multi match domain",
|
||||||
|
fields: fields{
|
||||||
|
cfg: Config{
|
||||||
|
OIDC: OIDCConfig{
|
||||||
|
MatchMap: map[string]string{
|
||||||
|
".*@example\\.no": "exammple",
|
||||||
|
".*@gmail\\.com": "gmail",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
args: args{
|
||||||
|
email: "someuser@gmail.com",
|
||||||
|
},
|
||||||
|
want: "gmail",
|
||||||
|
want1: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "no match domain",
|
||||||
|
fields: fields{
|
||||||
|
cfg: Config{
|
||||||
|
OIDC: OIDCConfig{
|
||||||
|
MatchMap: map[string]string{
|
||||||
|
".*@dontknow.no": "never",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
args: args{
|
||||||
|
email: "test@wedontknow.no",
|
||||||
|
},
|
||||||
|
want: "",
|
||||||
|
want1: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "multi no match domain",
|
||||||
|
fields: fields{
|
||||||
|
cfg: Config{
|
||||||
|
OIDC: OIDCConfig{
|
||||||
|
MatchMap: map[string]string{
|
||||||
|
".*@dontknow.no": "never",
|
||||||
|
".*@wedontknow.no": "other",
|
||||||
|
".*\\.no": "stuffy",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
args: args{
|
||||||
|
email: "tasy@nonofthem.com",
|
||||||
|
},
|
||||||
|
want: "",
|
||||||
|
want1: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
h := &Headscale{
|
||||||
|
cfg: tt.fields.cfg,
|
||||||
|
db: tt.fields.db,
|
||||||
|
dbString: tt.fields.dbString,
|
||||||
|
dbType: tt.fields.dbType,
|
||||||
|
dbDebug: tt.fields.dbDebug,
|
||||||
|
publicKey: tt.fields.publicKey,
|
||||||
|
privateKey: tt.fields.privateKey,
|
||||||
|
aclPolicy: tt.fields.aclPolicy,
|
||||||
|
aclRules: tt.fields.aclRules,
|
||||||
|
lastStateChange: tt.fields.lastStateChange,
|
||||||
|
oidcProvider: tt.fields.oidcProvider,
|
||||||
|
oauth2Config: tt.fields.oauth2Config,
|
||||||
|
oidcStateCache: tt.fields.oidcStateCache,
|
||||||
|
}
|
||||||
|
got, got1 := h.getNamespaceFromEmail(tt.args.email)
|
||||||
|
if got != tt.want {
|
||||||
|
t.Errorf("Headscale.getNamespaceFromEmail() got = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
if got1 != tt.want1 {
|
||||||
|
t.Errorf("Headscale.getNamespaceFromEmail() got1 = %v, want %v", got1, tt.want1)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in a new issue