/*
** Zabbix
** Copyright 2001-2024 Zabbix SIA
**
** Licensed under the Apache License, Version 2.0 (the "License");
** you may not use this file except in compliance with the License.
** You may obtain a copy of the License at
**
**     http://www.apache.org/licenses/LICENSE-2.0
**
** Unless required by applicable law or agreed to in writing, software
** distributed under the License is distributed on an "AS IS" BASIS,
** WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
** See the License for the specific language governing permissions and
** limitations under the License.
**/

package plugin

import (
	"context"
	"database/sql"
	"fmt"
	"net"
	"net/url"
	"path/filepath"
	"strings"
	"sync"
	"time"

	"github.com/jackc/pgx/v4/pgxpool"
	"github.com/jackc/pgx/v4/stdlib"
	"github.com/omeid/go-yarn"
	"golang.zabbix.com/sdk/metric"
	"golang.zabbix.com/sdk/tlsconfig"
	"golang.zabbix.com/sdk/uri"
	"golang.zabbix.com/sdk/zbxerr"
)

const (
	// Keyword names used when building the pgx DSN (see createDNS).
	password  = "password"
	sslMode   = "sslmode"
	rootCA    = "sslrootcert"
	cert      = "sslcert"
	key       = "sslkey"
	cacheMode = "statement_cache_mode"

	// TLS connection types, in the spelling passed to the driver as sslmode.
	disable    = "disable"
	require    = "require"
	verifyCa   = "verify-ca"
	verifyFull = "verify-full"

	// MinSupportedPGVersion is the lowest server_version_num a server may
	// report for a connection to be accepted; create rejects anything below.
	// NOTE(review): 100000 should correspond to PostgreSQL 10.0 - confirm.
	MinSupportedPGVersion = 100000
)

// PostgresClient is the set of query operations handlers use to talk to a
// PostgreSQL server. PGConn provides methods with these signatures.
type PostgresClient interface {
	// Query runs the given SQL text and returns the resulting rows.
	Query(ctx context.Context, query string, args ...interface{}) (rows *sql.Rows, err error)
	// QueryByName runs the stored query registered under queryName and returns the resulting rows.
	QueryByName(ctx context.Context, queryName string, args ...interface{}) (rows *sql.Rows, err error)
	// QueryRow runs the given SQL text and returns a single row.
	QueryRow(ctx context.Context, query string, args ...interface{}) (row *sql.Row, err error)
	// QueryRowByName runs the stored query registered under queryName and returns a single row.
	QueryRowByName(ctx context.Context, queryName string, args ...interface{}) (row *sql.Row, err error)
	// PostgresVersion returns the numeric version of the connected server.
	PostgresVersion() int
}

// PGConn holds a pointer to the connection pool of a PostgreSQL instance.
type PGConn struct { client *sql.DB callTimeout time.Duration ctx context.Context lastTimeAccess time.Time version int queryStorage *yarn.Yarn address string } type connID struct { uri uri.URI cacheMode string } var errorQueryNotFound = "query %q not found" // Query wraps pgxpool.Query. func (conn *PGConn) Query(ctx context.Context, query string, args ...interface{}) (rows *sql.Rows, err error) { rows, err = conn.client.QueryContext(ctx, query, args...) if ctxErr := ctx.Err(); ctxErr != nil { err = ctxErr } return } // QueryByName executes a query from queryStorage by its name and returns a single row. func (conn *PGConn) QueryByName(ctx context.Context, queryName string, args ...interface{}) (rows *sql.Rows, err error) { if sql, ok := (*conn.queryStorage).Get(queryName + sqlExt); ok { normalizedSQL := strings.TrimRight(strings.TrimSpace(sql), ";") return conn.Query(ctx, normalizedSQL, args...) } return nil, fmt.Errorf(errorQueryNotFound, queryName) } // QueryRow wraps pgxpool.QueryRow. func (conn *PGConn) QueryRow(ctx context.Context, query string, args ...interface{}) (row *sql.Row, err error) { row = conn.client.QueryRowContext(ctx, query, args...) if ctxErr := ctx.Err(); ctxErr != nil { err = ctxErr } return } // QueryRowByName executes a query from queryStorage by its name and returns a single row. func (conn *PGConn) QueryRowByName(ctx context.Context, queryName string, args ...interface{}) (row *sql.Row, err error) { if sql, ok := (*conn.queryStorage).Get(queryName + sqlExt); ok { normalizedSQL := strings.TrimRight(strings.TrimSpace(sql), ";") return conn.QueryRow(ctx, normalizedSQL, args...) } return nil, fmt.Errorf(errorQueryNotFound, queryName) } // GetPostgresVersion exec SQL query to retrieve the version of PostgreSQL server we are currently connected to. 
func getPostgresVersion(ctx context.Context, conn *sql.DB) (version int, err error) { err = conn.QueryRowContext(ctx, `select current_setting('server_version_num');`).Scan(&version) return } // PostgresVersion returns the version of PostgreSQL server we are currently connected to. func (conn *PGConn) PostgresVersion() int { return conn.version } // updateAccessTime updates the last time a connection was accessed. func (conn *PGConn) updateAccessTime() { conn.lastTimeAccess = time.Now() } // ConnManager is a thread-safe structure for manage connections. type ConnManager struct { sync.Mutex connMutex sync.Mutex connections map[connID]*PGConn keepAlive time.Duration connectTimeout time.Duration callTimeout time.Duration Destroy context.CancelFunc queryStorage yarn.Yarn } // NewConnManager initializes connManager structure and runs Go Routine that watches for unused connections. func NewConnManager(keepAlive, connectTimeout, callTimeout, hkInterval time.Duration, queryStorage yarn.Yarn) *ConnManager { ctx, cancel := context.WithCancel(context.Background()) connMgr := &ConnManager{ connections: make(map[connID]*PGConn), keepAlive: keepAlive, connectTimeout: connectTimeout, callTimeout: callTimeout, Destroy: cancel, // Destroy stops originated goroutines and closes connections. queryStorage: queryStorage, } go connMgr.housekeeper(ctx, hkInterval) return connMgr } // closeUnused closes each connection that has not been accessed at least within the keepalive interval. func (c *ConnManager) closeUnused() { c.connMutex.Lock() defer c.connMutex.Unlock() for ci, conn := range c.connections { if time.Since(conn.lastTimeAccess) > c.keepAlive { conn.client.Close() delete(c.connections, ci) Impl.Debugf("[%s] Closed unused connection: %s", Name, ci.uri.Addr()) } } } // closeAll closes all existed connections. 
func (c *ConnManager) closeAll() { c.connMutex.Lock() for ci, conn := range c.connections { conn.client.Close() delete(c.connections, ci) } c.connMutex.Unlock() } // housekeeper repeatedly checks for unused connections and closes them. func (c *ConnManager) housekeeper(ctx context.Context, interval time.Duration) { ticker := time.NewTicker(interval) for { select { case <-ctx.Done(): ticker.Stop() c.closeAll() return case <-ticker.C: c.closeUnused() } } } // create creates a new connection with given credentials. func (c *ConnManager) create(ci connID, details tlsconfig.Details) (*PGConn, error) { c.connMutex.Lock() defer c.connMutex.Unlock() if _, ok := c.connections[ci]; ok { // Should never happen. panic("connection already exists") } ctx := context.Background() host := ci.uri.Host() port := ci.uri.Port() if ci.uri.Scheme() == "unix" { socket := ci.uri.Addr() host = filepath.Dir(socket) ext := filepath.Ext(filepath.Base(socket)) if len(ext) <= 1 { return nil, fmt.Errorf("incorrect socket: %q", socket) } port = ext[1:] } dbname, err := url.QueryUnescape(ci.uri.GetParam("dbname")) if err != nil { return nil, err } client, err := createClient( createDNS(host, port, dbname, ci.uri.User(), ci.uri.Password(), ci.cacheMode, details), c.connectTimeout, ) if err != nil { return nil, err } serverVersion, err := getPostgresVersion(ctx, client) if err != nil { client.Close() return nil, err } if serverVersion < MinSupportedPGVersion { client.Close() return nil, fmt.Errorf("PostgreSQL version %d is not supported", serverVersion) } c.connections[ci] = &PGConn{ client: client, callTimeout: c.callTimeout, version: serverVersion, lastTimeAccess: time.Now(), ctx: ctx, queryStorage: &c.queryStorage, address: ci.uri.Addr(), } Impl.Debugf("[%s] Created new connection: %s", Name, ci.uri.Addr()) return c.connections[ci], nil } func createDNS(host, port, dbname, user, pass, mode string, details tlsconfig.Details) string { dsn := fmt.Sprintf("host=%s port=%s dbname=%s user=%s", host, 
port, dbname, user) tmp := map[string]string{ password: pass, sslMode: details.TlsConnect, rootCA: details.TlsCaFile, cert: details.TlsCertFile, key: details.TlsKeyFile, cacheMode: mode, } for k, v := range tmp { if v != "" { dsn = fmt.Sprintf("%s %s=%s", dsn, k, v) } } return dsn } func renameTLS(in string) string { switch in { case "required": return "require" case "verify_ca": return "verify-ca" case "verify_full": return "verify-full" default: return in } } func createClient(dsn string, timeout time.Duration) (*sql.DB, error) { config, err := pgxpool.ParseConfig(dsn) if err != nil { return nil, err } config.ConnConfig.DialFunc = func(ctx context.Context, network, addr string) (net.Conn, error) { d := net.Dialer{} ctxTimeout, cancel := context.WithTimeout(context.Background(), timeout) defer cancel() conn, err := d.DialContext(ctxTimeout, network, addr) return conn, err } return stdlib.OpenDB(*config.ConnConfig), nil } // get returns a connection with given uri if it exists and also updates lastTimeAccess, otherwise returns nil. func (c *ConnManager) get(cd connID) *PGConn { c.connMutex.Lock() defer c.connMutex.Unlock() if conn, ok := c.connections[cd]; ok { conn.updateAccessTime() return conn } return nil } // GetConnection returns an existing connection or creates a new one. 
func (c *ConnManager) GetConnection(ci connID, params map[string]string) (conn *PGConn, err error) { c.Lock() defer c.Unlock() conn = c.get(ci) if conn != nil { return } details, err := getTlsDetails(params) if err != nil { return nil, err } conn, err = c.create(ci, details) if err != nil { err = zbxerr.ErrorConnectionFailed.Wrap(err) } return } func getTlsDetails(params map[string]string) (tlsconfig.Details, error) { tlsType := renameTLS(params[tlsConnectParam]) validateCA := true if tlsType == "" { tlsType = disable } details := tlsconfig.NewDetails( params[metric.SessionParam], tlsType, params[tlsCAParam], params[tlsCertParam], params[tlsKeyParam], params[uriParam], disable, require, verifyCa, verifyFull, ) if tlsType == disable || tlsType == require { validateCA = false } err := details.Validate(validateCA, false, false) return details, err } func createConnID(params map[string]string) (connID, error) { u, err := uri.NewWithCreds( fmt.Sprintf("%s?dbname=%s", params[uriParam], url.QueryEscape(params[databaseParam])), params[userParam], params[passwordParam], uriDefaults, ) if err != nil { return connID{}, err } return connID{uri: *u, cacheMode: params[cacheModeParam]}, nil }