/*
** Zabbix
** Copyright (C) 2001-2025 Zabbix SIA
**
** This program is free software; you can redistribute it and/or modify
** it under the terms of the GNU General Public License as published by
** the Free Software Foundation; either version 2 of the License, or
** (at your option) any later version.
**
** This program is distributed in the hope that it will be useful,
** but WITHOUT ANY WARRANTY; without even the implied warranty of
** MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
** GNU General Public License for more details.
**
** You should have received a copy of the GNU General Public License
** along with this program; if not, write to the Free Software
** Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
**/

package postgres

import (
	"context"
	"crypto/tls"
	"database/sql"
	"fmt"
	"net"
	"net/url"
	"path/filepath"
	"strings"
	"sync"
	"time"

	"github.com/jackc/pgx/v4/pgxpool"
	"github.com/jackc/pgx/v4/stdlib"
	"github.com/omeid/go-yarn"
	"zabbix.com/pkg/log"
	"zabbix.com/pkg/tlsconfig"
	"zabbix.com/pkg/uri"
	"zabbix.com/pkg/zbxerr"
)

// MinSupportedPGVersion is the lowest value of the server_version_num setting that is accepted when
// a connection is created; servers reporting a smaller number are rejected.
// NOTE(review): 100000 presumably corresponds to PostgreSQL 10 - confirm against the server docs.
const MinSupportedPGVersion = 100000

// PostgresClient is the query interface that *PGConn implements in this file.
type PostgresClient interface {
	// Query executes the given SQL text and returns the resulting rows.
	Query(ctx context.Context, query string, args ...interface{}) (rows *sql.Rows, err error)
	// QueryByName executes the stored query registered under queryName and returns the resulting rows.
	QueryByName(ctx context.Context, queryName string, args ...interface{}) (rows *sql.Rows, err error)
	// QueryRow executes the given SQL text and returns a single row.
	QueryRow(ctx context.Context, query string, args ...interface{}) (row *sql.Row, err error)
	// QueryRowByName executes the stored query registered under queryName and returns a single row.
	QueryRowByName(ctx context.Context, queryName string, args ...interface{}) (row *sql.Row, err error)
	// PostgresVersion returns the numeric version of the connected server.
	PostgresVersion() int
}

// PGConn wraps the *sql.DB handle of one PostgreSQL instance together with per-connection metadata.
type PGConn struct {
	// client is the database/sql handle every query goes through.
	client *sql.DB
	// callTimeout is copied from ConnManager.callTimeout when the connection is created.
	callTimeout time.Duration
	// ctx is the context the connection was created with (context.Background() in create).
	ctx context.Context
	// lastTimeAccess is refreshed by updateAccessTime and compared with the keepalive
	// interval by ConnManager.closeUnused.
	lastTimeAccess time.Time
	// version is the server_version_num value read right after connecting.
	version int
	// queryStorage points to the manager's storage of named SQL queries.
	queryStorage *yarn.Yarn
	// address is the value of uri.Addr() for the URI the connection was created from.
	address string
}

// errorQueryNotFound is the fmt.Errorf format used when a named query is missing from queryStorage.
var errorQueryNotFound = "query %q not found"

// Query executes query through the underlying *sql.DB and returns the resulting rows.
func (conn *PGConn) Query(ctx context.Context, query string, args ...interface{}) (rows *sql.Rows, err error) { rows, err = conn.client.QueryContext(ctx, query, args...) if ctxErr := ctx.Err(); ctxErr != nil { err = ctxErr } return } // QueryByName executes a query from queryStorage by its name and returns a singe row. func (conn *PGConn) QueryByName(ctx context.Context, queryName string, args ...interface{}) (rows *sql.Rows, err error) { if sql, ok := (*conn.queryStorage).Get(queryName + sqlExt); ok { normalizedSQL := strings.TrimRight(strings.TrimSpace(sql), ";") return conn.Query(ctx, normalizedSQL, args...) } return nil, fmt.Errorf(errorQueryNotFound, queryName) } // QueryRow wraps pgxpool.QueryRow. func (conn *PGConn) QueryRow(ctx context.Context, query string, args ...interface{}) (row *sql.Row, err error) { row = conn.client.QueryRowContext(ctx, query, args...) if ctxErr := ctx.Err(); ctxErr != nil { err = ctxErr } return } // QueryRowByName executes a query from queryStorage by its name and returns a singe row. func (conn *PGConn) QueryRowByName(ctx context.Context, queryName string, args ...interface{}) (row *sql.Row, err error) { if sql, ok := (*conn.queryStorage).Get(queryName + sqlExt); ok { normalizedSQL := strings.TrimRight(strings.TrimSpace(sql), ";") return conn.QueryRow(ctx, normalizedSQL, args...) } return nil, fmt.Errorf(errorQueryNotFound, queryName) } // GetPostgresVersion exec SQL query to retrieve the version of PostgreSQL server we are currently connected to. func getPostgresVersion(ctx context.Context, conn *sql.DB) (version int, err error) { err = conn.QueryRowContext(ctx, `select current_setting('server_version_num');`).Scan(&version) return } // PostgresVersion returns the version of PostgreSQL server we are currently connected to. func (conn *PGConn) PostgresVersion() int { return conn.version } // updateAccessTime updates the last time a connection was accessed. 
func (conn *PGConn) updateAccessTime() { conn.lastTimeAccess = time.Now() } // ConnManager is a thread-safe structure for manage connections. type ConnManager struct { sync.Mutex connMutex sync.Mutex connections map[uri.URI]*PGConn keepAlive time.Duration connectTimeout time.Duration callTimeout time.Duration Destroy context.CancelFunc queryStorage yarn.Yarn } // NewConnManager initializes connManager structure and runs Go Routine that watches for unused connections. func NewConnManager(keepAlive, connectTimeout, callTimeout, hkInterval time.Duration, queryStorage yarn.Yarn) *ConnManager { ctx, cancel := context.WithCancel(context.Background()) connMgr := &ConnManager{ connections: make(map[uri.URI]*PGConn), keepAlive: keepAlive, connectTimeout: connectTimeout, callTimeout: callTimeout, Destroy: cancel, // Destroy stops originated goroutines and closes connections. queryStorage: queryStorage, } go connMgr.housekeeper(ctx, hkInterval) return connMgr } // closeUnused closes each connection that has not been accessed at least within the keepalive interval. func (c *ConnManager) closeUnused() { c.connMutex.Lock() defer c.connMutex.Unlock() for uri, conn := range c.connections { if time.Since(conn.lastTimeAccess) > c.keepAlive { conn.client.Close() delete(c.connections, uri) log.Debugf("[%s] Closed unused connection: %s", pluginName, uri.Addr()) } } } // closeAll closes all existed connections. func (c *ConnManager) closeAll() { c.connMutex.Lock() for uri, conn := range c.connections { conn.client.Close() delete(c.connections, uri) } c.connMutex.Unlock() } // housekeeper repeatedly checks for unused connections and closes them. func (c *ConnManager) housekeeper(ctx context.Context, interval time.Duration) { ticker := time.NewTicker(interval) for { select { case <-ctx.Done(): ticker.Stop() c.closeAll() return case <-ticker.C: c.closeUnused() } } } // create creates a new connection with given credentials. 
func (c *ConnManager) create(uri uri.URI, details tlsconfig.Details) (*PGConn, error) { c.connMutex.Lock() defer c.connMutex.Unlock() if _, ok := c.connections[uri]; ok { // Should never happen. panic("connection already exists") } ctx := context.Background() host := uri.Host() port := uri.Port() if uri.Scheme() == "unix" { socket := uri.Addr() host = filepath.Dir(socket) ext := filepath.Ext(filepath.Base(socket)) if len(ext) <= 1 { return nil, fmt.Errorf("incorrect socket: %q", socket) } port = ext[1:] } dbname, err := url.QueryUnescape(uri.GetParam("dbname")) if err != nil { return nil, err } dsn := fmt.Sprintf("host=%s port=%s dbname=%s user=%s", host, port, dbname, uri.User()) if uri.Password() != "" { dsn += " password=" + uri.Password() } client, err := createTLSClient(dsn, c.connectTimeout, details) if err != nil { return nil, err } serverVersion, err := getPostgresVersion(ctx, client) if err != nil { client.Close() return nil, err } if serverVersion < MinSupportedPGVersion { client.Close() return nil, fmt.Errorf("postgres version %d is not supported", serverVersion) } c.connections[uri] = &PGConn{ client: client, callTimeout: c.callTimeout, version: serverVersion, lastTimeAccess: time.Now(), ctx: ctx, queryStorage: &c.queryStorage, address: uri.Addr(), } log.Debugf("[%s] Created new connection: %s", pluginName, uri.Addr()) return c.connections[uri], nil } func createTLSClient(dsn string, timeout time.Duration, details tlsconfig.Details) (*sql.DB, error) { config, err := pgxpool.ParseConfig(dsn) if err != nil { return nil, err } config.ConnConfig.DialFunc = func(ctx context.Context, network, addr string) (net.Conn, error) { d := net.Dialer{} ctxTimeout, cancel := context.WithTimeout(context.Background(), timeout) defer cancel() conn, err := d.DialContext(ctxTimeout, network, addr) return conn, err } config.ConnConfig.TLSConfig, err = getTLSConfig(details) if err != nil { return nil, err } return stdlib.OpenDB(*config.ConnConfig), nil } func 
getTLSConfig(details tlsconfig.Details) (*tls.Config, error) { switch details.TlsConnect { case "required": return &tls.Config{InsecureSkipVerify: true}, nil case "verify_ca": return tlsconfig.CreateConfig(details, true) case "verify_full": return tlsconfig.CreateConfig(details, false) } return nil, nil } // get returns a connection with given uri if it exists and also updates lastTimeAccess, otherwise returns nil. func (c *ConnManager) get(uri uri.URI) *PGConn { c.connMutex.Lock() defer c.connMutex.Unlock() if conn, ok := c.connections[uri]; ok { conn.updateAccessTime() return conn } return nil } // GetConnection returns an existing connection or creates a new one. func (c *ConnManager) GetConnection(uri uri.URI, details tlsconfig.Details) (conn *PGConn, err error) { c.Lock() defer c.Unlock() conn = c.get(uri) if conn == nil { conn, err = c.create(uri, details) } if err != nil { err = zbxerr.ErrorConnectionFailed.Wrap(err) } return }