Switch from gRPC localhost to socket
This commit changes the way CLI and grpc-gateway communicates with the gRPC backend to socket, instead of localhost. Unauthenticated access now goes on the socket, while the network interface will require API key (in the future).
This commit is contained in:
parent
72fd2a2780
commit
6aacada852
6 changed files with 50 additions and 134 deletions
|
@ -14,3 +14,5 @@ docker-compose*
|
||||||
README.md
|
README.md
|
||||||
LICENSE
|
LICENSE
|
||||||
.vscode
|
.vscode
|
||||||
|
|
||||||
|
*.sock
|
||||||
|
|
34
app.go
34
app.go
|
@ -39,8 +39,6 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
LOCALHOST_V4 = "127.0.0.1"
|
|
||||||
LOCALHOST_V6 = "[::1]"
|
|
||||||
AUTH_PREFIX = "Bearer "
|
AUTH_PREFIX = "Bearer "
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -75,6 +73,8 @@ type Config struct {
|
||||||
ACMEEmail string
|
ACMEEmail string
|
||||||
|
|
||||||
DNSConfig *tailcfg.DNSConfig
|
DNSConfig *tailcfg.DNSConfig
|
||||||
|
|
||||||
|
UnixSocket string
|
||||||
}
|
}
|
||||||
|
|
||||||
type DERPConfig struct {
|
type DERPConfig struct {
|
||||||
|
@ -233,8 +233,9 @@ func (h *Headscale) grpcAuthenticationInterceptor(ctx context.Context,
|
||||||
// the server
|
// the server
|
||||||
p, _ := peer.FromContext(ctx)
|
p, _ := peer.FromContext(ctx)
|
||||||
|
|
||||||
if IsLocalhost(p.Addr.String()) {
|
// TODO(kradalby): Figure out what @ means (socket wise) and if it can be exploited
|
||||||
log.Trace().Caller().Str("client_address", p.Addr.String()).Msg("Client connected from localhost")
|
if p.Addr.String() == "@" {
|
||||||
|
log.Trace().Caller().Str("client_address", p.Addr.String()).Msg("Client connecting over socket")
|
||||||
|
|
||||||
return handler(ctx, req)
|
return handler(ctx, req)
|
||||||
}
|
}
|
||||||
|
@ -326,14 +327,19 @@ func (h *Headscale) Serve() error {
|
||||||
|
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
l, err := net.Listen("tcp", h.cfg.Addr)
|
socketListener, err := net.Listen("unix", h.cfg.UnixSocket)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
networkListener, err := net.Listen("tcp", h.cfg.Addr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create the cmux object that will multiplex 2 protocols on the same port.
|
// Create the cmux object that will multiplex 2 protocols on the same port.
|
||||||
// The two following listeners will be served on the same port below gracefully.
|
// The two following listeners will be served on the same port below gracefully.
|
||||||
m := cmux.New(l)
|
m := cmux.New(networkListener)
|
||||||
// Match gRPC requests here
|
// Match gRPC requests here
|
||||||
grpcListener := m.MatchWithWriters(
|
grpcListener := m.MatchWithWriters(
|
||||||
cmux.HTTP2MatchHeaderFieldSendSettings("content-type", "application/grpc"),
|
cmux.HTTP2MatchHeaderFieldSendSettings("content-type", "application/grpc"),
|
||||||
|
@ -344,16 +350,23 @@ func (h *Headscale) Serve() error {
|
||||||
|
|
||||||
grpcGatewayMux := runtime.NewServeMux()
|
grpcGatewayMux := runtime.NewServeMux()
|
||||||
|
|
||||||
grpcDialOptions := []grpc.DialOption{grpc.WithInsecure()}
|
// Make the grpc-gateway connect to grpc over socket
|
||||||
|
grpcGatewayConn, err := grpc.Dial(
|
||||||
_, port, err := net.SplitHostPort(h.cfg.Addr)
|
h.cfg.UnixSocket,
|
||||||
|
[]grpc.DialOption{
|
||||||
|
grpc.WithInsecure(),
|
||||||
|
grpc.WithDialer(func(addr string, timeout time.Duration) (net.Conn, error) {
|
||||||
|
return net.DialTimeout("unix", addr, timeout)
|
||||||
|
}),
|
||||||
|
}...,
|
||||||
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Connect to the gRPC server over localhost to skip
|
// Connect to the gRPC server over localhost to skip
|
||||||
// the authentication.
|
// the authentication.
|
||||||
err = apiV1.RegisterHeadscaleServiceHandlerFromEndpoint(ctx, grpcGatewayMux, LOCALHOST_V4+":"+port, grpcDialOptions)
|
err = apiV1.RegisterHeadscaleServiceHandler(ctx, grpcGatewayMux, grpcGatewayConn)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -432,6 +445,7 @@ func (h *Headscale) Serve() error {
|
||||||
|
|
||||||
g := new(errgroup.Group)
|
g := new(errgroup.Group)
|
||||||
|
|
||||||
|
g.Go(func() error { return grpcServer.Serve(socketListener) })
|
||||||
g.Go(func() error { return grpcServer.Serve(grpcListener) })
|
g.Go(func() error { return grpcServer.Serve(grpcListener) })
|
||||||
g.Go(func() error { return httpServer.Serve(httpListener) })
|
g.Go(func() error { return httpServer.Serve(httpListener) })
|
||||||
g.Go(func() error { return m.Serve() })
|
g.Go(func() error { return m.Serve() })
|
||||||
|
|
|
@ -48,6 +48,8 @@ func LoadConfig(path string) error {
|
||||||
|
|
||||||
viper.SetDefault("dns_config", nil)
|
viper.SetDefault("dns_config", nil)
|
||||||
|
|
||||||
|
viper.SetDefault("unix_socket", "/var/run/headscale.sock")
|
||||||
|
|
||||||
err := viper.ReadInConfig()
|
err := viper.ReadInConfig()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("Fatal error reading config file: %s \n", err)
|
return fmt.Errorf("Fatal error reading config file: %s \n", err)
|
||||||
|
@ -242,6 +244,8 @@ func getHeadscaleConfig() headscale.Config {
|
||||||
|
|
||||||
ACMEEmail: viper.GetString("acme_email"),
|
ACMEEmail: viper.GetString("acme_email"),
|
||||||
ACMEURL: viper.GetString("acme_url"),
|
ACMEURL: viper.GetString("acme_url"),
|
||||||
|
|
||||||
|
UnixSocket: viper.GetString("unix_socket"),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -282,11 +286,11 @@ func getHeadscaleApp() (*headscale.Headscale, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func getHeadscaleGRPCClient() (apiV1.HeadscaleServiceClient, *grpc.ClientConn) {
|
func getHeadscaleGRPCClient() (apiV1.HeadscaleServiceClient, *grpc.ClientConn) {
|
||||||
|
// TODO(kradalby): Make configurable
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
grpcOptions := []grpc.DialOption{
|
grpcOptions := []grpc.DialOption{
|
||||||
// TODO(kradalby): Make configurable
|
|
||||||
grpc.WithBlock(),
|
grpc.WithBlock(),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -294,19 +298,24 @@ func getHeadscaleGRPCClient() (apiV1.HeadscaleServiceClient, *grpc.ClientConn) {
|
||||||
|
|
||||||
// If the address is not set, we assume that we are on the server hosting headscale.
|
// If the address is not set, we assume that we are on the server hosting headscale.
|
||||||
if address == "" {
|
if address == "" {
|
||||||
log.Debug().Msgf("HEADSCALE_ADDRESS environment is not set, connecting to localhost.")
|
|
||||||
|
|
||||||
cfg := getHeadscaleConfig()
|
cfg := getHeadscaleConfig()
|
||||||
|
|
||||||
_, port, _ := net.SplitHostPort(cfg.Addr)
|
log.Debug().
|
||||||
|
Str("socket", cfg.UnixSocket).
|
||||||
|
Msgf("HEADSCALE_ADDRESS environment is not set, connecting to unix socket.")
|
||||||
|
|
||||||
address = "127.0.0.1" + ":" + port
|
address = cfg.UnixSocket
|
||||||
|
|
||||||
grpcOptions = append(grpcOptions, grpc.WithInsecure())
|
|
||||||
}
|
|
||||||
|
|
||||||
|
grpcOptions = append(
|
||||||
|
grpcOptions,
|
||||||
|
grpc.WithInsecure(),
|
||||||
|
grpc.WithDialer(func(addr string, timeout time.Duration) (net.Conn, error) {
|
||||||
|
return net.DialTimeout("unix", addr, timeout)
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
} else {
|
||||||
// If we are not connecting to a local server, require an API key for authentication
|
// If we are not connecting to a local server, require an API key for authentication
|
||||||
if !headscale.IsLocalhost(address) {
|
|
||||||
apiKey := os.Getenv("HEADSCALE_API_KEY")
|
apiKey := os.Getenv("HEADSCALE_API_KEY")
|
||||||
if apiKey == "" {
|
if apiKey == "" {
|
||||||
log.Fatal().Msgf("HEADSCALE_API_KEY environment variable needs to be set.")
|
log.Fatal().Msgf("HEADSCALE_API_KEY environment variable needs to be set.")
|
||||||
|
|
|
@ -64,3 +64,8 @@ dns_config:
|
||||||
|
|
||||||
magic_dns: true
|
magic_dns: true
|
||||||
base_domain: example.com
|
base_domain: example.com
|
||||||
|
|
||||||
|
# Unix socket used for the CLI to connect without authentication
|
||||||
|
# Note: for local development, you probably want to change this to:
|
||||||
|
# unix_socket: ./headscale.sock
|
||||||
|
unix_socket: /var/run/headscale.sock
|
||||||
|
|
|
@ -1,106 +0,0 @@
|
||||||
syntax = "proto3";
|
|
||||||
package headscale.v1;
|
|
||||||
option go_package = "github.com/juanfont/headscale/gen/go/v1";
|
|
||||||
|
|
||||||
import "google/protobuf/timestamp.proto";
|
|
||||||
import "google/api/annotations.proto";
|
|
||||||
|
|
||||||
enum RegisterMethod {
|
|
||||||
REGISTER_METHOD_UNSPECIFIED = 0;
|
|
||||||
REGISTER_METHOD_AUTH_KEY = 1;
|
|
||||||
REGISTER_METHOD_CLI = 2;
|
|
||||||
REGISTER_METHOD_OIDC = 3;
|
|
||||||
}
|
|
||||||
|
|
||||||
// message PreAuthKey {
|
|
||||||
// uint64 id = 1;
|
|
||||||
// string key = 2;
|
|
||||||
// uint32 namespace_id = 3;
|
|
||||||
// Namespace namespace = 4;
|
|
||||||
// bool reusable = 5;
|
|
||||||
// bool ephemeral = 6;
|
|
||||||
// bool used = 7;
|
|
||||||
//
|
|
||||||
// google.protobuf.Timestamp created_at = 8;
|
|
||||||
// google.protobuf.Timestamp expiration = 9;
|
|
||||||
// }
|
|
||||||
|
|
||||||
message GetMachineRequest {
|
|
||||||
uint64 machine_id = 1;
|
|
||||||
}
|
|
||||||
|
|
||||||
message GetMachineResponse {
|
|
||||||
uint64 id = 1;
|
|
||||||
string machine_key = 2;
|
|
||||||
string node_key = 3;
|
|
||||||
string disco_key = 4;
|
|
||||||
string ip_address = 5;
|
|
||||||
string name = 6;
|
|
||||||
uint32 namespace_id = 7;
|
|
||||||
|
|
||||||
bool registered = 8;
|
|
||||||
RegisterMethod register_method = 9;
|
|
||||||
uint32 auth_key_id = 10;
|
|
||||||
// PreAuthKey auth_key = 11;
|
|
||||||
|
|
||||||
google.protobuf.Timestamp last_seen = 12;
|
|
||||||
google.protobuf.Timestamp last_successful_update = 13;
|
|
||||||
google.protobuf.Timestamp expiry = 14;
|
|
||||||
|
|
||||||
// bytes host_info = 15;
|
|
||||||
// bytes endpoints = 16;
|
|
||||||
// bytes enabled_routes = 17;
|
|
||||||
|
|
||||||
// google.protobuf.Timestamp created_at = 18;
|
|
||||||
// google.protobuf.Timestamp updated_at = 19;
|
|
||||||
// google.protobuf.Timestamp deleted_at = 20;
|
|
||||||
}
|
|
||||||
|
|
||||||
message CreateNamespaceRequest {
|
|
||||||
string name = 1;
|
|
||||||
}
|
|
||||||
|
|
||||||
message CreateNamespaceResponse {
|
|
||||||
string name = 1;
|
|
||||||
}
|
|
||||||
|
|
||||||
message DeleteNamespaceRequest {
|
|
||||||
string name = 1;
|
|
||||||
}
|
|
||||||
|
|
||||||
message DeleteNamespaceResponse {
|
|
||||||
}
|
|
||||||
|
|
||||||
message ListNamespacesRequest {
|
|
||||||
}
|
|
||||||
|
|
||||||
message ListNamespacesResponse {
|
|
||||||
repeated string namespaces = 1;
|
|
||||||
}
|
|
||||||
|
|
||||||
service HeadscaleService {
|
|
||||||
rpc GetMachine(GetMachineRequest) returns(GetMachineResponse) {
|
|
||||||
option(google.api.http) = {
|
|
||||||
get : "/api/v1/machine/{machine_id}"
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
rpc CreateNamespace(CreateNamespaceRequest) returns(CreateNamespaceResponse) {
|
|
||||||
option(google.api.http) = {
|
|
||||||
post : "/api/v1/namespace"
|
|
||||||
body : "*"
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
rpc DeleteNamespace(DeleteNamespaceRequest) returns(DeleteNamespaceResponse) {
|
|
||||||
option(google.api.http) = {
|
|
||||||
delete : "/api/v1/namespace"
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
rpc ListNamespaces(ListNamespacesRequest) returns(ListNamespacesResponse) {
|
|
||||||
option(google.api.http) = {
|
|
||||||
get : "/api/v1/namespace"
|
|
||||||
};
|
|
||||||
}
|
|
||||||
}
|
|
8
utils.go
8
utils.go
|
@ -156,11 +156,3 @@ func tailNodesToString(nodes []*tailcfg.Node) string {
|
||||||
func tailMapResponseToString(resp tailcfg.MapResponse) string {
|
func tailMapResponseToString(resp tailcfg.MapResponse) string {
|
||||||
return fmt.Sprintf("{ Node: %s, Peers: %s }", resp.Node.Name, tailNodesToString(resp.Peers))
|
return fmt.Sprintf("{ Node: %s, Peers: %s }", resp.Node.Name, tailNodesToString(resp.Peers))
|
||||||
}
|
}
|
||||||
|
|
||||||
func IsLocalhost(host string) bool {
|
|
||||||
if strings.Contains(host, LOCALHOST_V4) || strings.Contains(host, LOCALHOST_V6) {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
Loading…
Reference in a new issue