/*
** Copyright (C) 2001-2025 Zabbix SIA
**
** This program is free software: you can redistribute it and/or modify it under the terms of
** the GNU Affero General Public License as published by the Free Software Foundation, version 3.
**
** This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY;
** without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
** See the GNU Affero General Public License for more details.
**
** You should have received a copy of the GNU Affero General Public License along with this program.
** If not, see <https://www.gnu.org/licenses/>.
**/

package mysql

import (
	"context"
	"crypto/tls"
	"database/sql"
	"fmt"
	"strings"
	"sync"
	"time"

	"github.com/go-sql-driver/mysql"
	"github.com/omeid/go-yarn"
	"golang.zabbix.com/sdk/log"
	"golang.zabbix.com/sdk/tlsconfig"
	"golang.zabbix.com/sdk/uri"
	"golang.zabbix.com/sdk/zbxerr"
)

const (
	// connType: the TLS connection modes this file recognizes. createConnKey
	// falls back to disable when no mode is supplied, and getTLSDetails passes
	// all four values on to tlsconfig.NewDetails.
	disable    = "disabled"
	require    = "required"
	verifyCa   = "verify_ca"
	verifyFull = "verify_full"
)

// MyClient is the set of query methods exposed by a connection.
// MyConn provides methods with these signatures.
type MyClient interface {
	Query(ctx context.Context, query string, args ...interface{}) (rows *sql.Rows, err error)
	QueryByName(ctx context.Context, queryName string, args ...interface{}) (rows *sql.Rows, err error)
	QueryRow(ctx context.Context, query string, args ...interface{}) (row *sql.Row, err error)
}

// MyConn is a cached database handle together with the bookkeeping needed to
// expire it when it is no longer used.
type MyConn struct {
	// client is the underlying database/sql handle.
	client *sql.DB
	// lastAccessTime is the moment the connection was last handed out;
	// it is read and written only while holding lastAccessTimeMu.
	lastAccessTime   time.Time
	lastAccessTimeMu sync.Mutex
	// queryStorage holds the named queries resolved by QueryByName.
	queryStorage *yarn.Yarn
}

// ConnManager is a thread-safe structure for managing connections.
type ConnManager struct {
	// connectionsMu guards connections.
	connectionsMu sync.Mutex
	// connections caches one MyConn per connKey.
	connections map[connKey]*MyConn
	// keepAlive is how long a connection may stay unused before closeUnused closes it.
	keepAlive time.Duration
	// connectTimeout and callTimeout are handed to getMySQLConfig when a connection is created.
	connectTimeout time.Duration
	callTimeout    time.Duration
	// Destroy cancels the context the housekeeper goroutine runs under.
	Destroy context.CancelFunc
	// queryStorage is shared by pointer with every connection created by this manager.
	queryStorage yarn.Yarn
	log          log.Logger
}

// connKey identifies a cached connection: the target URI plus the TLS
// parameters it was requested with.
type connKey struct {
	uri        uri.URI
	rawUri     string
	tlsConnect string
	tlsCA      string
	tlsCert    string
	tlsKey     string
}

// Query wraps DB.QueryContext.
func (conn *MyConn) Query(ctx context.Context, query string, args ...interface{}) (rows *sql.Rows, err error) { rows, err = conn.client.QueryContext(ctx, query, args...) if ctxErr := ctx.Err(); ctxErr != nil { err = ctxErr } return } // QueryByName wraps DB.QueryRowContext. func (conn *MyConn) QueryByName(ctx context.Context, name string, args ...interface{}) (rows *sql.Rows, err error) { if sql, ok := (*conn.queryStorage).Get(name + sqlExt); ok { normalizedSQL := strings.TrimRight(strings.TrimSpace(sql), ";") return conn.Query(ctx, normalizedSQL, args...) } return nil, fmt.Errorf("query %s not found", name) } // QueryRow wraps DB.QueryRowContext. func (conn *MyConn) QueryRow(ctx context.Context, query string, args ...interface{}) (row *sql.Row, err error) { row = conn.client.QueryRowContext(ctx, query, args...) if ctxErr := ctx.Err(); ctxErr != nil { err = ctxErr } return } // GetConnection returns an existing connection or creates a new one. func (c *ConnManager) GetConnection(uri uri.URI, params map[string]string) (*MyConn, error) { ck := createConnKey(uri, params) conn := c.getConn(ck) if conn != nil { c.log.Tracef("connection found for host: %s", uri.Host()) conn.updateLastAccessTime() return conn, nil } c.log.Tracef("creating new connection for host: %s", uri.Host()) conn, err := c.create(ck) if err != nil { return nil, err } return c.setConn(ck, conn), nil } // NewConnManager initializes connManager structure and runs Go Routine that watches for unused connections. func NewConnManager( keepAlive, connectTimeout, callTimeout, hkInterval time.Duration, queryStorage yarn.Yarn, logger log.Logger, ) *ConnManager { ctx, cancel := context.WithCancel(context.Background()) connMgr := &ConnManager{ connections: make(map[connKey]*MyConn), keepAlive: keepAlive, connectTimeout: connectTimeout, callTimeout: callTimeout, Destroy: cancel, // Destroy stops originated goroutines and closes connections. 
queryStorage: queryStorage, log: logger, } go connMgr.housekeeper(ctx, hkInterval) return connMgr } // updateLastAccessTime updates the last time a connection was accessed. func (conn *MyConn) updateLastAccessTime() { conn.lastAccessTimeMu.Lock() defer conn.lastAccessTimeMu.Unlock() conn.lastAccessTime = time.Now() } func (conn *MyConn) getLastAccessTime() time.Time { conn.lastAccessTimeMu.Lock() defer conn.lastAccessTimeMu.Unlock() return conn.lastAccessTime } // closeUnused closes each connection that has not been accessed at least within the keepalive interval. func (c *ConnManager) closeUnused() { c.connectionsMu.Lock() defer c.connectionsMu.Unlock() for ck, conn := range c.connections { if time.Since(conn.getLastAccessTime()) > c.keepAlive { conn.client.Close() delete(c.connections, ck) log.Debugf("[%s] Closed unused connection: %s", pluginName, ck.uri.Addr()) } } } // closeAll closes all existed connections. func (c *ConnManager) closeAll() { c.connectionsMu.Lock() defer c.connectionsMu.Unlock() for uri, conn := range c.connections { conn.client.Close() delete(c.connections, uri) } } // housekeeper repeatedly checks for unused connections and closes them. func (c *ConnManager) housekeeper(ctx context.Context, interval time.Duration) { ticker := time.NewTicker(interval) for { select { case <-ctx.Done(): ticker.Stop() c.closeAll() return case <-ticker.C: c.closeUnused() } } } // create creates a new connection with given credentials. 
func (c *ConnManager) create(ck connKey) (*MyConn, error) { details, err := getTLSDetails(ck) if err != nil { return nil, err } tlsConfig, err := c.getTLSConfig(details) if err != nil { return nil, err } config, err := getMySQLConfig(ck.uri, tlsConfig, c.connectTimeout, c.callTimeout) if err != nil { return nil, err } connector, err := mysql.NewConnector(config) if err != nil { return nil, zbxerr.New("failed to create mysql connector").Wrap(err) } client := sql.OpenDB(connector) log.Debugf("[%s] Created new connection: %s", pluginName, ck.uri.Addr()) return &MyConn{client: client, lastAccessTime: time.Now(), queryStorage: &c.queryStorage}, nil } // getConn concurrent connections cache getter. func (c *ConnManager) getConn(ck connKey) *MyConn { //nolint:gocritic c.connectionsMu.Lock() defer c.connectionsMu.Unlock() conn, ok := c.connections[ck] if !ok { return nil } return conn } // setConn concurrent connections cache setter. // // Returns the cached connection. If the provider connection is already present // in cache, it is closed. 
// //nolint:gocritic func (c *ConnManager) setConn(ck connKey, conn *MyConn) *MyConn { c.connectionsMu.Lock() defer c.connectionsMu.Unlock() existingConn, ok := c.connections[ck] if ok { defer conn.client.Close() //nolint:errcheck log.Debugf("[%s] Closed redundant connection: %s", pluginName, ck.uri.Addr()) return existingConn } c.connections[ck] = conn return conn } func getMySQLConfig(uri uri.URI, tls *tls.Config, connectTimeout, callTimeout time.Duration) (*mysql.Config, error) { config := mysql.NewConfig() config.User = uri.User() config.Passwd = uri.Password() config.Net = uri.Scheme() config.Addr = uri.Addr() config.Timeout = connectTimeout config.ReadTimeout = callTimeout config.InterpolateParams = true if tls == nil { return config, nil } err := mysql.RegisterTLSConfig(uri.String(), tls) if err != nil { return nil, zbxerr.New("failed to register TLS config").Wrap(err) } config.TLSConfig = uri.String() return config, nil } func (c *ConnManager) getTLSConfig(details *tlsconfig.Details) (*tls.Config, error) { var ( tlsConf *tls.Config err error ) switch details.TlsConnect { case "required": tlsConf, err = c.getRequiredTLSConfig(details) if err != nil { return nil, zbxerr.New("failed to get TLS config for required connection").Wrap(err) } case "verify_ca": tlsConf, err = details.GetTLSConfig(true) if err != nil { return nil, zbxerr.New("failed to get TLS config for verify_ca connection").Wrap(err) } tlsConf.VerifyPeerCertificate = tlsconfig.VerifyPeerCertificateFunc("", tlsConf.RootCAs) case "verify_full": tlsConf, err = details.GetTLSConfig(false) if err != nil { return nil, zbxerr.New("failed to get TLS config for verify_full connection").Wrap(err) } tlsConf.VerifyPeerCertificate = tlsconfig.VerifyPeerCertificateFunc(tlsConf.ServerName, tlsConf.RootCAs) default: return nil, nil } return tlsConf, nil } func (c *ConnManager) getRequiredTLSConfig(details *tlsconfig.Details) (*tls.Config, error) { if details.TlsCaFile != "" { c.log.Warningf("server CA will not be 
verified for %s", details.TlsConnect) } clientCerts, err := details.LoadCertificates() if err != nil { return nil, err } return &tls.Config{Certificates: clientCerts, InsecureSkipVerify: true}, nil } func createConnKey(uri uri.URI, params map[string]string) connKey { tlsType := params[tlsConnectParam] if tlsType == "" { tlsType = disable } return connKey{ uri: uri, rawUri: params[uriParam], tlsConnect: tlsType, tlsCA: params[tlsCAParam], tlsCert: params[tlsCertParam], tlsKey: params[tlsKeyParam], } } func getTLSDetails(ck connKey) (*tlsconfig.Details, error) { var ( validateCA = true validateClient = false ) details := tlsconfig.NewDetails( "", ck.tlsConnect, ck.tlsCA, ck.tlsCert, ck.tlsKey, ck.rawUri, disable, require, verifyCa, verifyFull, ) if ck.tlsConnect == disable || ck.tlsConnect == require { validateCA = false } if details.TlsKeyFile != "" || details.TlsCertFile != "" { validateClient = true } err := details.Validate(validateCA, validateClient, validateClient) if err != nil { return nil, zbxerr.ErrorInvalidConfiguration.Wrap(err) } return &details, nil }