diff --git a/config-example.yaml b/config-example.yaml index 40e5c8e..44e36b8 100644 --- a/config-example.yaml +++ b/config-example.yaml @@ -140,6 +140,23 @@ ephemeral_node_inactivity_timeout: 30m database: type: sqlite + # Enable debug mode. This setting requires the log.level to be set to "debug" or "trace". + debug: false + + # GORM configuration settings. + gorm: + # Enable prepared statements. + prepare_stmt: true + + # Enable parameterized queries. + parameterized_queries: true + + # Skip logging "record not found" errors. + skip_err_record_not_found: true + + # Threshold for slow queries in milliseconds. + slow_threshold: 1000 + # SQLite config sqlite: path: /var/lib/headscale/db.sqlite diff --git a/hscontrol/db/db.go b/hscontrol/db/db.go index c190813..331dba5 100644 --- a/hscontrol/db/db.go +++ b/hscontrol/db/db.go @@ -426,7 +426,7 @@ func openDB(cfg types.DatabaseConfig) (*gorm.DB, error) { // TODO(kradalby): Integrate this with zerolog var dbLogger logger.Interface if cfg.Debug { - dbLogger = logger.Default + dbLogger = util.NewDBLogWrapper(&log.Logger, cfg.Gorm.SlowThreshold, cfg.Gorm.SkipErrRecordNotFound, cfg.Gorm.ParameterizedQueries) } else { dbLogger = logger.Default.LogMode(logger.Silent) } @@ -447,7 +447,8 @@ func openDB(cfg types.DatabaseConfig) (*gorm.DB, error) { db, err := gorm.Open( sqlite.Open(cfg.Sqlite.Path), &gorm.Config{ - Logger: dbLogger, + PrepareStmt: cfg.Gorm.PrepareStmt, + Logger: dbLogger, }, ) diff --git a/hscontrol/types/config.go b/hscontrol/types/config.go index e938768..bff8099 100644 --- a/hscontrol/types/config.go +++ b/hscontrol/types/config.go @@ -120,11 +120,22 @@ type PostgresConfig struct { ConnMaxIdleTimeSecs int } +type GormConfig struct { + Debug bool + SlowThreshold time.Duration + SkipErrRecordNotFound bool + ParameterizedQueries bool + PrepareStmt bool +} + type DatabaseConfig struct { // Type sets the database type, either "sqlite3" or "postgres" Type string Debug bool + // Type sets the gorm configuration + Gorm GormConfig + Sqlite SqliteConfig Postgres PostgresConfig } @@ -486,6 +497,11 @@ func GetDatabaseConfig() DatabaseConfig { type_ := viper.GetString("database.type") + skipErrRecordNotFound := viper.GetBool("database.gorm.skip_err_record_not_found") + slowThreshold := viper.GetDuration("database.gorm.slow_threshold") * time.Millisecond + parameterizedQueries := viper.GetBool("database.gorm.parameterized_queries") + prepareStmt := viper.GetBool("database.gorm.prepare_stmt") + switch type_ { case DatabaseSqlite, DatabasePostgres: break @@ -499,6 +515,13 @@ func GetDatabaseConfig() DatabaseConfig { return DatabaseConfig{ Type: type_, Debug: debug, + Gorm: GormConfig{ + Debug: debug, + SkipErrRecordNotFound: skipErrRecordNotFound, + SlowThreshold: slowThreshold, + ParameterizedQueries: parameterizedQueries, + PrepareStmt: prepareStmt, + }, Sqlite: SqliteConfig{ Path: util.AbsolutePathFromConfigPath( viper.GetString("database.sqlite.path"), diff --git a/hscontrol/util/log.go b/hscontrol/util/log.go index 41d667d..12f646b 100644 --- a/hscontrol/util/log.go +++ b/hscontrol/util/log.go @@ -1,7 +1,14 @@ package util import ( + "context" + "errors" + "time" + + "github.com/rs/zerolog" "github.com/rs/zerolog/log" + "gorm.io/gorm" + gormLogger "gorm.io/gorm/logger" "tailscale.com/types/logger" ) @@ -14,3 +21,71 @@ func TSLogfWrapper() logger.Logf { log.Debug().Caller().Msgf(format, args...) } } + +type DBLogWrapper struct { + Logger *zerolog.Logger + Level zerolog.Level + Event *zerolog.Event + SlowThreshold time.Duration + SkipErrRecordNotFound bool + ParameterizedQueries bool +} + +func NewDBLogWrapper(origin *zerolog.Logger, slowThreshold time.Duration, skipErrRecordNotFound bool, parameterizedQueries bool) *DBLogWrapper { + l := &DBLogWrapper{ + Logger: origin, + Level: origin.GetLevel(), + SlowThreshold: slowThreshold, + SkipErrRecordNotFound: skipErrRecordNotFound, + ParameterizedQueries: parameterizedQueries, + } + + return l +} + +type DBLogWrapperOption func(*DBLogWrapper) + +func (l *DBLogWrapper) LogMode(gormLogger.LogLevel) gormLogger.Interface { + return l +} + +func (l *DBLogWrapper) Info(ctx context.Context, msg string, data ...interface{}) { + l.Logger.Info().Msgf(msg, data...) +} + +func (l *DBLogWrapper) Warn(ctx context.Context, msg string, data ...interface{}) { + l.Logger.Warn().Msgf(msg, data...) +} + +func (l *DBLogWrapper) Error(ctx context.Context, msg string, data ...interface{}) { + l.Logger.Error().Msgf(msg, data...) +} + +func (l *DBLogWrapper) Trace(ctx context.Context, begin time.Time, fc func() (sql string, rowsAffected int64), err error) { + elapsed := time.Since(begin) + sql, rowsAffected := fc() + fields := map[string]interface{}{ + "duration": elapsed, + "sql": sql, + "rowsAffected": rowsAffected, + } + + if err != nil && !(errors.Is(err, gorm.ErrRecordNotFound) && l.SkipErrRecordNotFound) { + l.Logger.Error().Err(err).Fields(fields).Msgf("") + return + } + + if l.SlowThreshold != 0 && elapsed > l.SlowThreshold { + l.Logger.Warn().Fields(fields).Msgf("") + return + } + + l.Logger.Debug().Fields(fields).Msgf("") +} + +func (l *DBLogWrapper) ParamsFilter(ctx context.Context, sql string, params ...interface{}) (string, []interface{}) { + if l.ParameterizedQueries { + return sql, nil + } + return sql, params +}