/*
** Zabbix
** Copyright (C) 2001-2025 Zabbix SIA
**
** Licensed under the Apache License, Version 2.0 (the "License");
** you may not use this file except in compliance with the License.
** You may obtain a copy of the License at
**
**     http://www.apache.org/licenses/LICENSE-2.0
**
** Unless required by applicable law or agreed to in writing, software
** distributed under the License is distributed on an "AS IS" BASIS,
** WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
** See the License for the specific language governing permissions and
** limitations under the License.
**/

package dbconn

import (
	"context"
	"database/sql"
	"fmt"
	"net/url"
	"path/filepath"
	"strconv"
	"sync"
	"time"

	"golang.zabbix.com/plugin/mssql/plugin/handlers"
	"golang.zabbix.com/plugin/mssql/plugin/params"
	"golang.zabbix.com/sdk/errs"
	"golang.zabbix.com/sdk/log"
	"golang.zabbix.com/sdk/uri"
)

// Ensure the handler constructors keep matching handlers.HandlerFunc.
// WithConnHandlerFunc is called here on a nil receiver, which is safe only
// because it does not dereference the receiver until the returned closure runs.
var _ handlers.HandlerFunc = (*ConnCollection)(nil).WithConnHandlerFunc(nil)

var _ handlers.HandlerFunc = (*ConnCollection)(nil).PingHandler

// connConfig holds every metric parameter that identifies a distinct database
// connection. All fields are strings, so the struct is comparable and is used
// directly as the key of the connection cache: two metrics share a cached
// connection only when all of these values match.
type connConfig struct {
	// Target URI and credentials.
	URI, User, Password string
	// TLS-related options, passed through to the driver when non-empty.
	CACertPath, TrustServerCertificate, HostNameInCertificate string
	Encrypt, TLSMinVersion                                    string
	// Database to connect to; omitted from the DSN when empty.
	Database string
}

// ConnCollection caches *sql.DB handles keyed by their connection
// configuration, so repeated metric requests with identical connection
// parameters reuse the same handle. Call Init before use.
type ConnCollection struct {
	// mu guards conns.
	mu    sync.Mutex
	conns map[connConfig]*sql.DB
	// keepAlive and queryTimeout are both in seconds: keepAlive is passed to
	// the driver and used as the idle timeout, queryTimeout bounds each
	// handler invocation.
	keepAlive, queryTimeout int
	logr                    log.Logger
	// driverName is the database/sql driver to open; always "sqlserver" in
	// production, overridable in unit tests.
	driverName string
}

// Init prepares an already allocated ConnCollection for use: it creates an
// empty connection cache, stores the keep-alive and query timeouts (both in
// seconds) and the logger, and selects the "sqlserver" driver.
func (c *ConnCollection) Init(keepAlive, queryTimeout int, logr log.Logger) {
	c.driverName = "sqlserver"
	c.logr = logr
	c.keepAlive, c.queryTimeout = keepAlive, queryTimeout
	c.conns = map[connConfig]*sql.DB{}
}

// WithConnHandlerFunc wraps handler so that every invocation first obtains a
// database connection for the metric parameters (reusing a cached one when the
// connection parameters match) and then calls handler with it.
//
// A single deadline of c.queryTimeout seconds, started when the wrapper is
// invoked, bounds both connection setup and the handler's own work.
func (c *ConnCollection) WithConnHandlerFunc(
	handler handlers.ConnHandlerFunc,
) handlers.HandlerFunc {
	wrapped := func(
		metricParams map[string]string, extraParams ...string,
	) (any, error) {
		// Read c only here, at call time: this constructor is also invoked on a
		// nil receiver for the package-level type assertion.
		timeout := time.Duration(c.queryTimeout) * time.Second
		ctx, cancel := context.WithTimeout(context.Background(), timeout)
		defer cancel()

		db, err := c.get(ctx, newConnConfig(metricParams))
		if err != nil {
			return nil, errs.Wrap(err, "failed to get conn")
		}

		return handler(ctx, db, metricParams, extraParams...)
	}

	return wrapped
}

// PingHandler reports whether the database described by metricParams is
// reachable. It returns 1 when a connection can be obtained and pinged, and 0
// otherwise.
//
// Both steps share one deadline of c.queryTimeout seconds. Failures are only
// logged at info level and never surface as errors, so the returned error is
// always nil.
func (c *ConnCollection) PingHandler(
	metricParams map[string]string, _ ...string,
) (any, error) {
	ctx, cancel := context.WithTimeout(
		context.Background(),
		time.Duration(c.queryTimeout)*time.Second,
	)
	defer cancel()

	conn, err := c.get(ctx, newConnConfig(metricParams))
	if err != nil {
		c.logr.Infof("Failed to get connection for ping: %s", err.Error())

		return 0, nil
	}

	err = conn.PingContext(ctx)
	if err != nil {
		c.logr.Infof("Failed to ping: %s", err.Error())

		return 0, nil
	}

	return 1, nil
}

// Close closes every cached connection and removes it from the cache, holding
// c.mu for the whole operation. Close errors are logged, not returned, so one
// failing connection does not prevent the others from being closed.
func (c *ConnCollection) Close() {
	c.mu.Lock()
	defer c.mu.Unlock()

	// Deleting the current key while ranging over a map is permitted in Go.
	for key, db := range c.conns {
		if err := db.Close(); err != nil {
			c.logr.Errf("Failed to close connection: %s", err.Error())
		}

		delete(c.conns, key)
	}
}

// get returns a connection for conf, preferring one already in the cache.
//
// On a cache miss, a new connection is opened (bounded by ctx) and handed to
// setConn. If a concurrent caller cached a connection for the same conf in the
// meantime, setConn returns that one instead.
//
//nolint:gocritic // need conf by value.
func (c *ConnCollection) get(
	ctx context.Context,
	conf connConfig,
) (*sql.DB, error) {
	if cached := c.getConn(conf); cached != nil {
		return cached, nil
	}

	fresh, err := c.newConn(ctx, &conf)
	if err != nil {
		return nil, errs.Wrap(err, "failed to create conn")
	}

	return c.setConn(conf, fresh), nil
}

// getConn looks up conf in the connection cache under c.mu and returns the
// cached handle, or nil when there is none.
func (c *ConnCollection) getConn(conf connConfig) *sql.DB { //nolint:gocritic
	c.mu.Lock()
	defer c.mu.Unlock()

	// A missing key yields the zero value of *sql.DB, which is nil.
	return c.conns[conf]
}

// setConn stores conn in the connection cache under conf and returns the
// connection callers should use.
//
// If a connection for conf is already cached (for example because a
// concurrent caller created one first), the cached connection is returned and
// the provided conn is closed. The redundant connection is closed after c.mu
// has been released, so tearing it down does not block other cache users.
//
//nolint:gocritic
func (c *ConnCollection) setConn(
	conf connConfig,
	conn *sql.DB,
) *sql.DB {
	existingConn, found := func() (*sql.DB, bool) {
		c.mu.Lock()
		defer c.mu.Unlock()

		cached, ok := c.conns[conf]
		if !ok {
			c.conns[conf] = conn
		}

		return cached, ok
	}()

	if !found {
		return conn
	}

	err := conn.Close()
	if err != nil {
		c.logr.Debugf("Failed to close redundant connection: %s", err.Error())
	}

	c.logr.Debugf("Closed redundant connection: %s", conf.URI)

	return existingConn
}

// newConn opens a new database handle for conf and verifies it with a ping
// bounded by ctx.
//
// The DSN is built from conf.URI with the scheme defaulted. The default port
// is added only when the URI has neither a port nor a path. The application
// name, keep-alive, and any non-empty database and TLS options from conf are
// appended as query parameters. If the initial ping fails, the handle is
// closed before returning the error, because the caller never receives it and
// would otherwise leak its connection pool.
//
//nolint:gocyclo,cyclop
//revive:disable:cognitive-complexity
func (c *ConnCollection) newConn(
	ctx context.Context,
	conf *connConfig,
) (*sql.DB, error) {
	c.logr.Infof(
		"Creating new connection to %q, with user %q to database %q, "+
			"with CA certificate %q, "+
			"trust server certificate %q, host name in certificate %q, "+
			"encrypt %q, TLS min version %q",
		conf.URI,
		conf.User,
		conf.Database,
		conf.CACertPath,
		conf.TrustServerCertificate,
		conf.HostNameInCertificate,
		conf.Encrypt,
		conf.TLSMinVersion,
	)

	connURI, err := uri.NewWithCreds(
		conf.URI, conf.User, conf.Password, &uri.Defaults{Scheme: params.URIDefaults.Scheme},
	)
	if err != nil {
		return nil, errs.Wrap(err, "failed to set URI defaults")
	}

	u, err := url.Parse(connURI.String())
	if err != nil {
		return nil, errs.Wrap(err, "failed to parse URI")
	}

	// Only fall back to the default port when the URI specifies neither a
	// port nor a path.
	if connURI.Port() == "" && connURI.Path() == "" {
		u.Host = fmt.Sprintf("%s:%s", u.Hostname(), params.URIDefaults.Port)
	}

	u.Path = connURI.Path()

	queryParams := u.Query()
	queryParams.Add("app name", "Zabbix agent 2 MSSQL plugin")
	queryParams.Add("keepAlive", strconv.Itoa(c.keepAlive))

	if conf.Database != "" {
		queryParams.Add("database", conf.Database)
	}

	if conf.CACertPath != "" {
		queryParams.Add("certificate", filepath.Clean(conf.CACertPath))
	}

	if conf.TrustServerCertificate != "" {
		queryParams.Add("TrustServerCertificate", conf.TrustServerCertificate)
	}

	if conf.HostNameInCertificate != "" {
		queryParams.Add("hostNameInCertificate", conf.HostNameInCertificate)
	}

	if conf.Encrypt != "" {
		queryParams.Add("encrypt", conf.Encrypt)
	}

	if conf.TLSMinVersion != "" {
		queryParams.Add("tlsMinVersion", conf.TLSMinVersion)
	}

	u.RawQuery = queryParams.Encode()

	db, err := sql.Open(c.driverName, u.String())
	if err != nil {
		return nil, errs.Wrap(err, "failed to open DB connection")
	}

	db.SetConnMaxIdleTime(time.Duration(c.keepAlive) * time.Second)

	err = db.PingContext(ctx)
	if err != nil {
		// The handle is discarded on this path; close it so its pool and
		// background resources are released instead of leaking on every
		// failed connection attempt.
		closeErr := db.Close()
		if closeErr != nil {
			c.logr.Debugf("Failed to close unreachable connection: %s", closeErr.Error())
		}

		return nil, errs.Wrap(err, "failed to ping")
	}

	return db, nil
}

// newConnConfig extracts the connection-identifying values from metricParams,
// looking each one up by its parameter name. A missing parameter leaves the
// corresponding field as the empty string.
func newConnConfig(metricParams map[string]string) connConfig {
	var conf connConfig

	conf.URI = metricParams[params.URI.Name()]
	conf.User = metricParams[params.User.Name()]
	conf.Password = metricParams[params.Password.Name()]
	conf.CACertPath = metricParams[params.CACertPath.Name()]
	conf.TrustServerCertificate = metricParams[params.TrustServerCertificate.Name()]
	conf.HostNameInCertificate = metricParams[params.HostNameInCertificate.Name()]
	conf.Encrypt = metricParams[params.Encrypt.Name()]
	conf.TLSMinVersion = metricParams[params.TLSMinVersion.Name()]
	conf.Database = metricParams[params.Database.Name()]

	return conf
}