package notifier

import (
	"context"
	"fmt"
	"strings"
	"sync"

	"github.com/juanfont/headscale/hscontrol/types"
	"github.com/juanfont/headscale/hscontrol/util"
	"github.com/rs/zerolog/log"
	"tailscale.com/types/key"
)

type Notifier struct {
	l         sync.RWMutex
	nodes     map[string]chan<- types.StateUpdate
	connected map[key.MachinePublic]bool
}

func NewNotifier() *Notifier {
	return &Notifier{
		nodes:     make(map[string]chan<- types.StateUpdate),
		connected: make(map[key.MachinePublic]bool),
	}
}

func (n *Notifier) AddNode(machineKey key.MachinePublic, c chan<- types.StateUpdate) {
	log.Trace().Caller().Str("key", machineKey.ShortString()).Msg("acquiring lock to add node")
	defer log.Trace().
		Caller().
		Str("key", machineKey.ShortString()).
		Msg("releasing lock to add node")

	n.l.Lock()
	defer n.l.Unlock()

	n.nodes[machineKey.String()] = c
	n.connected[machineKey] = true

	log.Trace().
		Str("machine_key", machineKey.ShortString()).
		Int("open_chans", len(n.nodes)).
		Msg("Added new channel")
}

func (n *Notifier) RemoveNode(machineKey key.MachinePublic) {
	log.Trace().Caller().Str("key", machineKey.ShortString()).Msg("acquiring lock to remove node")
	defer log.Trace().
		Caller().
		Str("key", machineKey.ShortString()).
		Msg("releasing lock to remove node")

	n.l.Lock()
	defer n.l.Unlock()

	if len(n.nodes) == 0 {
		return
	}

	delete(n.nodes, machineKey.String())
	n.connected[machineKey] = false

	log.Trace().
		Str("machine_key", machineKey.ShortString()).
		Int("open_chans", len(n.nodes)).
		Msg("Removed channel")
}

// IsConnected reports if a node is connected to headscale and has a
// poll session open.
func (n *Notifier) IsConnected(machineKey key.MachinePublic) bool {
	n.l.RLock()
	defer n.l.RUnlock()

	return n.connected[machineKey]
}

// TODO(kradalby): This returns a pointer and can be dangerous.
func (n *Notifier) ConnectedMap() map[key.MachinePublic]bool {
	return n.connected
}

func (n *Notifier) NotifyAll(ctx context.Context, update types.StateUpdate) {
	n.NotifyWithIgnore(ctx, update)
}

func (n *Notifier) NotifyWithIgnore(
	ctx context.Context,
	update types.StateUpdate,
	ignore ...string,
) {
	log.Trace().Caller().Interface("type", update.Type).Msg("acquiring lock to notify")
	defer log.Trace().
		Caller().
		Interface("type", update.Type).
		Msg("releasing lock, finished notifying")

	n.l.RLock()
	defer n.l.RUnlock()

	for key, c := range n.nodes {
		if util.IsStringInSlice(ignore, key) {
			continue
		}

		select {
		case <-ctx.Done():
			log.Error().
				Err(ctx.Err()).
				Str("mkey", key).
				Any("origin", ctx.Value("origin")).
				Any("hostname", ctx.Value("hostname")).
				Msgf("update not sent, context cancelled")

			return
		case c <- update:
			log.Trace().
				Str("mkey", key).
				Any("origin", ctx.Value("origin")).
				Any("hostname", ctx.Value("hostname")).
				Msgf("update successfully sent on chan")
		}
	}
}

func (n *Notifier) NotifyByMachineKey(
	ctx context.Context,
	update types.StateUpdate,
	mKey key.MachinePublic,
) {
	log.Trace().Caller().Interface("type", update.Type).Msg("acquiring lock to notify")
	defer log.Trace().
		Caller().
		Interface("type", update.Type).
		Msg("releasing lock, finished notifying")

	n.l.RLock()
	defer n.l.RUnlock()

	if c, ok := n.nodes[mKey.String()]; ok {
		select {
		case <-ctx.Done():
			log.Error().
				Err(ctx.Err()).
				Str("mkey", mKey.String()).
				Any("origin", ctx.Value("origin")).
				Any("hostname", ctx.Value("hostname")).
				Msgf("update not sent, context cancelled")

			return
		case c <- update:
			log.Trace().
				Str("mkey", mKey.String()).
				Any("origin", ctx.Value("origin")).
				Any("hostname", ctx.Value("hostname")).
				Msgf("update successfully sent on chan")
		}
	}
}

func (n *Notifier) String() string {
	n.l.RLock()
	defer n.l.RUnlock()

	str := []string{"Notifier, in map:\n"}

	for k, v := range n.nodes {
		str = append(str, fmt.Sprintf("\t%s: %v\n", k, v))
	}

	return strings.Join(str, "")
}