diff --git a/acls.go b/acls.go index 0b365c1..bb8c4bd 100644 --- a/acls.go +++ b/acls.go @@ -10,6 +10,7 @@ import ( "path/filepath" "strconv" "strings" + "time" "github.com/rs/zerolog/log" "github.com/tailscale/hujson" @@ -120,6 +121,16 @@ func (h *Headscale) UpdateACLRules() error { log.Trace().Interface("ACL", rules).Msg("ACL rules generated") h.aclRules = rules + sshRules, err := h.generateSSHRules() + if err != nil { + return err + } + log.Trace().Interface("SSH", sshRules).Msg("SSH rules generated") + if h.sshPolicy == nil { + h.sshPolicy = &tailcfg.SSHPolicy{} + } + h.sshPolicy.Rules = sshRules + return nil } @@ -187,6 +198,111 @@ func (h *Headscale) generateACLRules() ([]tailcfg.FilterRule, error) { return rules, nil } +func (h *Headscale) generateSSHRules() ([]*tailcfg.SSHRule, error) { + rules := []*tailcfg.SSHRule{} + + if h.aclPolicy == nil { + return nil, errEmptyPolicy + } + + machines, err := h.ListMachines() + if err != nil { + return nil, err + } + + acceptAction := tailcfg.SSHAction{ + Message: "", + Reject: false, + Accept: true, + SessionDuration: 0, + AllowAgentForwarding: false, + HoldAndDelegate: "", + AllowLocalPortForwarding: true, + } + + rejectAction := tailcfg.SSHAction{ + Message: "", + Reject: true, + Accept: false, + SessionDuration: 0, + AllowAgentForwarding: false, + HoldAndDelegate: "", + AllowLocalPortForwarding: false, + } + + for index, sshACL := range h.aclPolicy.SSHs { + action := rejectAction + switch sshACL.Action { + case "accept": + action = acceptAction + case "check": + checkAction, err := sshCheckAction(sshACL.CheckPeriod) + if err != nil { + log.Error(). + Msgf("Error parsing SSH %d, check action with unparsable duration '%s'", index, sshACL.CheckPeriod) + } else { + action = *checkAction + } + default: + log.Error(). + Msgf("Error parsing SSH %d, unknown action '%s'", index, sshACL.Action) + + return nil, err + } + + principals := make([]*tailcfg.SSHPrincipal, 0, len(sshACL.Sources)) + for innerIndex, rawSrc := range sshACL.Sources { + expandedSrcs, err := expandAlias( + machines, + *h.aclPolicy, + rawSrc, + h.cfg.OIDC.StripEmaildomain, + ) + if err != nil { + log.Error(). + Msgf("Error parsing SSH %d, Source %d", index, innerIndex) + + return nil, err + } + for _, expandedSrc := range expandedSrcs { + principals = append(principals, &tailcfg.SSHPrincipal{ + NodeIP: expandedSrc, + }) + } + } + + userMap := make(map[string]string, len(sshACL.Users)) + for _, user := range sshACL.Users { + userMap[user] = "=" + } + rules = append(rules, &tailcfg.SSHRule{ + RuleExpires: nil, + Principals: principals, + SSHUsers: userMap, + Action: &action, + }) + } + + return rules, nil +} + +func sshCheckAction(duration string) (*tailcfg.SSHAction, error) { + sessionLength, err := time.ParseDuration(duration) + if err != nil { + return nil, err + } + + return &tailcfg.SSHAction{ + Message: "", + Reject: false, + Accept: true, + SessionDuration: sessionLength, + AllowAgentForwarding: false, + HoldAndDelegate: "", + AllowLocalPortForwarding: true, + }, nil +} + func (h *Headscale) generateACLPolicySrcIP( machines []Machine, aclPolicy ACLPolicy, diff --git a/acls_test.go b/acls_test.go index 41f8d39..0f3c262 100644 --- a/acls_test.go +++ b/acls_test.go @@ -73,6 +73,79 @@ func (s *Suite) TestInvalidAction(c *check.C) { c.Assert(errors.Is(err, errInvalidAction), check.Equals, true) } +func (s *Suite) TestSshRules(c *check.C) { + namespace, err := app.CreateNamespace("user1") + c.Assert(err, check.IsNil) + + pak, err := app.CreatePreAuthKey(namespace.Name, false, false, nil, nil) + c.Assert(err, check.IsNil) + + _, err = app.GetMachine("user1", "testmachine") + c.Assert(err, check.NotNil) + hostInfo := tailcfg.Hostinfo{ + OS: "centos", + Hostname: "testmachine", + RequestTags: []string{"tag:test"}, + } + + machine := Machine{ + ID: 0, + MachineKey: "foo", + NodeKey: "bar", + DiscoKey: "faa", + Hostname: "testmachine", + IPAddresses: MachineAddresses{netip.MustParseAddr("100.64.0.1")}, + NamespaceID: namespace.ID, + RegisterMethod: RegisterMethodAuthKey, + AuthKeyID: uint(pak.ID), + HostInfo: HostInfo(hostInfo), + } + app.db.Save(&machine) + + app.aclPolicy = &ACLPolicy{ + Groups: Groups{ + "group:test": []string{"user1"}, + }, + Hosts: Hosts{ + "client": netip.PrefixFrom(netip.MustParseAddr("100.64.99.42"), 32), + }, + ACLs: []ACL{ + { + Action: "accept", + Sources: []string{"*"}, + Destinations: []string{"*:*"}, + }, + }, + SSHs: []SSH{ + { + Action: "accept", + Sources: []string{"group:test"}, + Destinations: []string{"client"}, + Users: []string{"autogroup:nonroot"}, + }, + { + Action: "accept", + Sources: []string{"*"}, + Destinations: []string{"client"}, + Users: []string{"autogroup:nonroot"}, + }, + }, + } + + err = app.UpdateACLRules() + + c.Assert(err, check.IsNil) + c.Assert(app.sshPolicy, check.NotNil) + c.Assert(app.sshPolicy.Rules, check.HasLen, 2) + c.Assert(app.sshPolicy.Rules[0].SSHUsers, check.HasLen, 1) + c.Assert(app.sshPolicy.Rules[0].Principals, check.HasLen, 1) + c.Assert(app.sshPolicy.Rules[0].Principals[0].NodeIP, check.Matches, "100.64.0.1") + + c.Assert(app.sshPolicy.Rules[1].SSHUsers, check.HasLen, 1) + c.Assert(app.sshPolicy.Rules[1].Principals, check.HasLen, 1) + c.Assert(app.sshPolicy.Rules[1].Principals[0].NodeIP, check.Matches, "*") +} + func (s *Suite) TestInvalidGroupInGroup(c *check.C) { // this ACL is wrong because the group in Sources sections doesn't exist app.aclPolicy = &ACLPolicy{ diff --git a/acls_types.go b/acls_types.go index 638a456..da981d3 100644 --- a/acls_types.go +++ b/acls_types.go @@ -17,6 +17,7 @@ type ACLPolicy struct { ACLs []ACL `json:"acls" yaml:"acls"` Tests []ACLTest `json:"tests" yaml:"tests"` AutoApprovers AutoApprovers `json:"autoApprovers" yaml:"autoApprovers"` + SSHs []SSH `json:"ssh" yaml:"ssh"` } // ACL is a basic rule for the ACL Policy. @@ -50,6 +51,15 @@ type AutoApprovers struct { ExitNode []string `json:"exitNode" yaml:"exitNode"` } +// SSH controls who can ssh into which machines. +type SSH struct { + Action string `json:"action" yaml:"action"` + Sources []string `json:"src" yaml:"src"` + Destinations []string `json:"dst" yaml:"dst"` + Users []string `json:"users" yaml:"users"` + CheckPeriod string `json:"checkPeriod,omitempty" yaml:"checkPeriod,omitempty"` +} + // UnmarshalJSON allows to parse the Hosts directly into netip objects. func (hosts *Hosts) UnmarshalJSON(data []byte) error { newHosts := Hosts{} diff --git a/api_common.go b/api_common.go index 9e7ff48..c4f9c79 100644 --- a/api_common.go +++ b/api_common.go @@ -62,6 +62,7 @@ func (h *Headscale) generateMapResponse( DNSConfig: dnsConfig, Domain: h.cfg.BaseDomain, PacketFilter: h.aclRules, + SSHPolicy: h.sshPolicy, DERPMap: h.DERPMap, UserProfiles: profiles, Debug: &tailcfg.Debug{ diff --git a/app.go b/app.go index 69d4079..aec8d68 100644 --- a/app.go +++ b/app.go @@ -88,6 +88,7 @@ type Headscale struct { aclPolicy *ACLPolicy aclRules []tailcfg.FilterRule + sshPolicy *tailcfg.SSHPolicy lastStateChange *xsync.MapOf[string, time.Time] diff --git a/machine.go b/machine.go index b688be6..e9dbb60 100644 --- a/machine.go +++ b/machine.go @@ -744,7 +744,11 @@ func (machine Machine) toNode( KeepAlive: true, MachineAuthorized: !machine.isExpired(), - Capabilities: []string{tailcfg.CapabilityFileSharing}, + Capabilities: []string{ + tailcfg.CapabilityFileSharing, + tailcfg.CapabilityAdmin, + tailcfg.CapabilitySSH, + }, } return &node, nil