/* ** Zabbix ** Copyright (C) 2001-2025 Zabbix SIA ** ** Licensed under the Apache License, Version 2.0 (the "License"); ** you may not use this file except in compliance with the License. ** You may obtain a copy of the License at ** ** http://www.apache.org/licenses/LICENSE-2.0 ** ** Unless required by applicable law or agreed to in writing, software ** distributed under the License is distributed on an "AS IS" BASIS, ** WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ** See the License for the specific language governing permissions and ** limitations under the License. **/ package dbconn import ( "context" "database/sql" "database/sql/driver" "errors" stdlog "log" "os" "sync" "testing" "github.com/DATA-DOG/go-sqlmock" "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" "golang.zabbix.com/plugin/mssql/plugin/params" "golang.zabbix.com/sdk/log" ) var ( _ driver.Driver = (*driverMock)(nil) _ driver.DriverContext = (*driverMock)(nil) _ driver.Connector = (*connectorMock)(nil) //nolint:gochecknoglobals // global driver instance. mockDriver = &driverMock{} ) type connectorMock struct { driver driver.Driver name string } type driverMock struct { openErr error driver driver.Driver } //nolint:gochecknoinits func init() { sql.Register("testdriver", mockDriver) } func (d *driverMock) Open(name string) (driver.Conn, error) { if d.openErr != nil { return nil, d.openErr } return d.driver.Open(name) } func (d *driverMock) OpenConnector(name string) (driver.Connector, error) { if d.openErr != nil { return nil, d.openErr } return &connectorMock{d.driver, name}, nil } func (d *driverMock) reset() { d.openErr = nil d.driver = nil } func (c *connectorMock) Connect(context.Context) (driver.Conn, error) { return c.driver.Open(c.name) } func (c *connectorMock) Driver() driver.Driver { return c.driver } func TestConnCollection_Init(t *testing.T) { t.Parallel() sampleLogr := &struct{ log.Logger }{} type fields struct { conns map[connConfig]*sql.DB keepAlive int queryTimeout int logr log.Logger driverName string } type args struct { keepAlive int queryTimetout int logr log.Logger } tests := []struct { name string fields fields args args want *ConnCollection }{ { "+valid", fields{}, args{10, 11, sampleLogr}, &ConnCollection{ conns: make(map[connConfig]*sql.DB), keepAlive: 10, queryTimeout: 11, logr: sampleLogr, driverName: "sqlserver", }, }, { "-overwrite", fields{ conns: map[connConfig]*sql.DB{{}: nil}, keepAlive: 3, queryTimeout: 4, logr: log.New("aaa"), driverName: "lol", }, args{10, 11, sampleLogr}, &ConnCollection{ conns: make(map[connConfig]*sql.DB), keepAlive: 10, queryTimeout: 11, logr: sampleLogr, driverName: "sqlserver", }, }, } for _, tt := range tests { tt := tt t.Run(tt.name, func(t *testing.T) { t.Parallel() c := &ConnCollection{ conns: tt.fields.conns, keepAlive: tt.fields.keepAlive, queryTimeout: tt.fields.queryTimeout, logr: tt.fields.logr, driverName: tt.fields.driverName, } c.Init(tt.args.keepAlive, tt.args.queryTimetout, tt.args.logr) if diff := cmp.Diff( tt.want, c, cmp.AllowUnexported(ConnCollection{}, sync.Mutex{}), ); diff != "" { t.Fatalf("ConnCollection.Init() = %s", diff) } }) } } //nolint:paralleltest func TestConnCollection_WithConnHandlerFunc(t *testing.T) { log.DefaultLogger = stdlog.New(os.Stdout, "", stdlog.LstdFlags) type fields struct { getErr error dsn string } type args struct { metricParams map[string]string extraParams []string } tests := []struct { name string fields fields args args wantMetricParams map[string]string wantExtraParams []string want any wantErr bool }{ { "+valid", fields{ dsn: "pigeon://8888:dddd@uri:1433?app+name=Zabbix+agent+2+MSSQL+plugin&keepAlive=0", }, args{ metricParams: map[string]string{ params.URI.Name(): "pigeon://uri", params.User.Name(): "8888", params.Password.Name(): "dddd", }, }, map[string]string{ params.URI.Name(): "pigeon://uri", params.User.Name(): "8888", params.Password.Name(): "dddd", }, nil, "handler called", false, }, { "+extraParams", fields{ dsn: "pigeon://7777:dddd@uri:1433?app+name=Zabbix+agent+2+MSSQL+plugin&keepAlive=0", }, args{ metricParams: map[string]string{ params.URI.Name(): "pigeon://uri", params.User.Name(): "7777", params.Password.Name(): "dddd", "extra": "param", }, extraParams: []string{"extra", "spicey"}, }, map[string]string{ params.URI.Name(): "pigeon://uri", params.User.Name(): "7777", params.Password.Name(): "dddd", "extra": "param", }, []string{"extra", "spicey"}, "handler called", false, }, { "-getErr", fields{ dsn: "pigeon://6666:dddd@uri:1433?app+name=Zabbix+agent+2+MSSQL+plugin&keepAlive=0", getErr: errors.New("fail"), }, args{ metricParams: map[string]string{ params.URI.Name(): "pigeon://uri", params.User.Name(): "6666", params.Password.Name(): "dddd", }, }, nil, nil, nil, true, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { //nolint:paralleltest c := &ConnCollection{ conns: map[connConfig]*sql.DB{}, driverName: "testdriver", logr: log.New("aaa"), queryTimeout: 1, } db, m, err := sqlmock.NewWithDSN( tt.fields.dsn, sqlmock.MonitorPingsOption(true), ) if err != nil { t.Fatalf("failed to open sqlmock: %s", err.Error()) } mockDriver.driver = db.Driver() defer mockDriver.reset() m.ExpectPing().WillReturnError(tt.fields.getErr) got, err := c.WithConnHandlerFunc( func( _ context.Context, db *sql.DB, metricParams map[string]string, extraParams ...string, ) (any, error) { if db == nil { t.Fatalf( "ConnCollection.WithConnHandlerFunc() db is nil", ) } if diff := cmp.Diff( tt.wantMetricParams, metricParams, ); diff != "" { t.Fatalf( "ConnCollection.WithConnHandlerFunc() = %s", diff, ) } if diff := cmp.Diff( tt.wantExtraParams, extraParams, ); diff != "" { t.Fatalf( "ConnCollection.WithConnHandlerFunc() = %s", diff, ) } return "handler called", nil }, )(tt.args.metricParams, tt.args.extraParams...) if (err != nil) != tt.wantErr { t.Fatalf( "ConnCollection.WithConnHandlerFunc() "+ "error = %v, wantErr %v", err, tt.wantErr, ) } if diff := cmp.Diff(tt.want, got); diff != "" { t.Fatalf("ConnCollection.WithConnHandlerFunc() = %s", diff) } }) } } //nolint:paralleltest func TestConnCollection_PingHandler(t *testing.T) { log.DefaultLogger = stdlog.New(os.Stdout, "", stdlog.LstdFlags) type expect struct { ping bool } type fields struct { getErr error pingErr error dsn string } type args struct { metricParams map[string]string } tests := []struct { name string expect expect fields fields args args want any wantErr bool }{ { "+valid", expect{true}, fields{ dsn: "pigeon://aaaa:dddd@uri:1433?app+name=Zabbix+agent+2+MSSQL+plugin&keepAlive=0", }, args{ metricParams: map[string]string{ params.URI.Name(): "pigeon://uri", params.User.Name(): "aaaa", params.Password.Name(): "dddd", }, }, 1, false, }, { "-getErr", expect{false}, fields{ getErr: errors.New("fail"), dsn: "pigeon://aaaa:bbbb@uri:1433?app+name=Zabbix+agent+2+MSSQL+plugin&keepAlive=0", }, args{ metricParams: map[string]string{ params.URI.Name(): "pigeon://uri", params.User.Name(): "aaaa", params.Password.Name(): "bbbb", }, }, 0, false, }, { "-pingErr", expect{true}, fields{ pingErr: errors.New("fail"), dsn: "pigeon://aaaa:cccc@uri:1433?app+name=Zabbix+agent+2+MSSQL+plugin&keepAlive=0", }, args{ metricParams: map[string]string{ params.URI.Name(): "pigeon://uri", params.User.Name(): "aaaa", params.Password.Name(): "cccc", }, }, 0, false, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { //nolint:paralleltest c := &ConnCollection{ conns: map[connConfig]*sql.DB{}, logr: log.New("test"), driverName: "testdriver", queryTimeout: 1, } db, m, err := sqlmock.NewWithDSN( tt.fields.dsn, sqlmock.MonitorPingsOption(true), ) if err != nil { t.Fatalf("failed to open sqlmock: %s", err.Error()) } mockDriver.driver = db.Driver() defer mockDriver.reset() m.ExpectPing().WillReturnError(tt.fields.getErr) if tt.expect.ping { m.ExpectPing().WillReturnError(tt.fields.pingErr) } got, err := c.PingHandler(tt.args.metricParams) if (err != nil) != tt.wantErr { t.Fatalf( "ConnCollection.PingHandler() error = %v, wantErr %v", err, tt.wantErr, ) } if diff := cmp.Diff(tt.want, got); diff != "" { t.Fatalf("ConnCollection.PingHandler() = %s", diff) } if err := m.ExpectationsWereMet(); err != nil { t.Fatalf( "ConnCollection.PingHandler() "+ "expectations where not met: %s", err.Error(), ) } }) } } //nolint:paralleltest,tparallel func TestConnCollection_Close(t *testing.T) { log.DefaultLogger = stdlog.New(os.Stdout, "", stdlog.LstdFlags) type fields struct { closeErr error } tests := []struct { name string fields fields }{ {"+valid", fields{}}, {"-closeErr", fields{errors.New("fail")}}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { //nolint:paralleltest t.Parallel() db, m, err := sqlmock.New() if err != nil { t.Fatalf("failed to open sqlmock: %s", err.Error()) } m.ExpectClose().WillReturnError(tt.fields.closeErr) c := &ConnCollection{ conns: map[connConfig]*sql.DB{{}: db}, logr: log.New("test"), } c.Close() if err := m.ExpectationsWereMet(); err != nil { t.Fatalf("ConnCollection.Close() = %s", err.Error()) } }) } } //nolint:paralleltest func TestConnCollection_get(t *testing.T) { log.DefaultLogger = stdlog.New(os.Stdout, "", stdlog.LstdFlags) type expect struct { newConn bool } type fields struct { conns map[connConfig]*sql.DB dsn string newConnErr error driverName string } type args struct { conf connConfig } tests := []struct { name string expect expect fields fields args args wantReceiver *ConnCollection wantNil bool wantErr bool }{ { "+validExisting", expect{false}, fields{ conns: map[connConfig]*sql.DB{ {URI: "pigeon://uri", User: "aaaa", Password: "bbbb"}: {}, }, driverName: "testdriver", }, args{ conf: connConfig{ URI: "pigeon://uri", User: "aaaa", Password: "bbbb", }, }, &ConnCollection{ conns: map[connConfig]*sql.DB{ {URI: "pigeon://uri", User: "aaaa", Password: "bbbb"}: {}, }, driverName: "testdriver", }, false, false, }, { "+validNew", expect{true}, fields{ conns: map[connConfig]*sql.DB{}, dsn: "pigeon://rrrr:tttt@uri:1433?app+name=Zabbix+agent+2+MSSQL+plugin&keepAlive=0", driverName: "testdriver", }, args{ conf: connConfig{ User: "rrrr", Password: "tttt", URI: "pigeon://uri", }, }, &ConnCollection{ conns: map[connConfig]*sql.DB{ {User: "rrrr", Password: "tttt", URI: "pigeon://uri"}: {}, }, driverName: "testdriver", }, false, false, }, { "+prevCons", expect{true}, fields{ conns: map[connConfig]*sql.DB{ {}: {}, }, dsn: "pigeon://jjjj:tttt@uri:1433?app+name=Zabbix+agent+2+MSSQL+plugin&keepAlive=0", driverName: "testdriver", }, args{ conf: connConfig{ User: "jjjj", Password: "tttt", URI: "pigeon://uri", }, }, &ConnCollection{ conns: map[connConfig]*sql.DB{ {}: {}, {User: "jjjj", Password: "tttt", URI: "pigeon://uri"}: {}, }, driverName: "testdriver", }, false, false, }, { "-newConnErr", expect{true}, fields{ conns: map[connConfig]*sql.DB{}, dsn: "pigeon://kkkk:tttt@uri:1433?app+name=Zabbix+agent+2+MSSQL+plugin&keepAlive=0", newConnErr: errors.New("fail"), driverName: "testdriver", }, args{ conf: connConfig{ User: "kkkk", Password: "tttt", URI: "pigeon://uri", }, }, &ConnCollection{ conns: map[connConfig]*sql.DB{}, driverName: "testdriver", }, true, true, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { //nolint:paralleltest var ( db *sql.DB m sqlmock.Sqlmock err error ) if tt.expect.newConn { db, m, err = sqlmock.NewWithDSN( tt.fields.dsn, sqlmock.MonitorPingsOption(true), ) if err != nil { t.Fatalf("failed to open sqlmock: %s", err.Error()) } mockDriver.driver = db.Driver() defer mockDriver.reset() m.ExpectPing().WillReturnError(tt.fields.newConnErr) } c := &ConnCollection{ conns: tt.fields.conns, driverName: tt.fields.driverName, logr: log.New("test"), } got, err := c.get(context.Background(), tt.args.conf) if (err != nil) != tt.wantErr { t.Fatalf( "ConnCollection.get() error = %v, wantErr %v", err, tt.wantErr, ) } if (got == nil) != tt.wantNil { t.Fatalf( "ConnCollection.get() got = %v, wantNil %v", got, tt.wantNil, ) } if diff := cmp.Diff( tt.wantReceiver, c, cmp.AllowUnexported(ConnCollection{}, sync.Mutex{}), cmp.Comparer( func(x, y *sql.DB) bool { return (x == nil) == (y == nil) }, ), cmpopts.SortMaps( func(x, y connConfig) bool { return x.User < y.User }, ), cmpopts.IgnoreFields(ConnCollection{}, "logr"), ); diff != "" { t.Fatalf("ConnCollection.get() = %s", diff) } if m != nil { if err := m.ExpectationsWereMet(); err != nil { t.Fatalf("ConnCollection.get() = %s", err.Error()) } } }, ) } } //nolint:paralleltest func TestConnCollection_newConn(t *testing.T) { log.DefaultLogger = stdlog.New(os.Stdout, "", stdlog.LstdFlags) type expect struct { open bool ping bool } type fields struct { keepAlive int openErr error pingErr error dsn string driverName string defaultScheme string } type args struct { conf *connConfig } tests := []struct { name string expect expect fields fields args args wantNil bool wantErr bool }{ { "+valid", expect{true, true}, fields{ keepAlive: 4, dsn: "pigeon://aaaa:bbbb@uri:1433?app+name=Zabbix+agent+2+MSSQL+plugin&keepAlive=4", driverName: "testdriver", }, args{&connConfig{ User: "aaaa", Password: "bbbb", URI: "pigeon://uri", }}, false, false, }, { "+named", expect{true, true}, fields{ keepAlive: 4, dsn: "pigeon://aaaa:bbbb@uri/InstanceName?app+name=Zabbix+agent+2+MSSQL+plugin&keepAlive=4", driverName: "testdriver", }, args{&connConfig{ User: "aaaa", Password: "bbbb", URI: "pigeon://uri/InstanceName", }}, false, false, }, { "+namedWithPort", expect{true, true}, fields{ keepAlive: 4, dsn: "pigeon://aaaa:bbbb@uri:1435/InstanceName?app+name=Zabbix+agent+2+MSSQL+plugin&keepAlive=4", driverName: "testdriver", }, args{&connConfig{ User: "aaaa", Password: "bbbb", URI: "pigeon://uri:1435/InstanceName", }}, false, false, }, { "+validWithTLS", expect{true, true}, fields{ keepAlive: 4, dsn: "pigeon://aaaa:bbbb@uri:1433?" + "TrustServerCertificate=false&" + "app+name=Zabbix+agent+2+MSSQL+plugin&" + "certificate=abc&" + "encrypt=true&" + "hostNameInCertificate=server&" + "keepAlive=4&" + "tlsMinVersion=1.3", driverName: "testdriver", }, args{&connConfig{ User: "aaaa", Password: "bbbb", URI: "pigeon://uri", CACertPath: "abc", TrustServerCertificate: "false", HostNameInCertificate: "server", Encrypt: "true", TLSMinVersion: "1.3", }}, false, false, }, { "-newConnURIErr", expect{false, false}, fields{ keepAlive: 4, openErr: nil, driverName: "testdriver", }, args{&connConfig{ User: "aaaa", Password: "bbbb", URI: "://", }}, true, true, }, { "-openErr", expect{true, false}, fields{ keepAlive: 4, openErr: errors.New("fail"), dsn: "pigeon://cccc:bbbb@uri:1433?app+name=Zabbix+agent+2+MSSQL+plugin&keepAlive=4", driverName: "testdriver", }, args{&connConfig{ User: "cccc", Password: "bbbb", URI: "pigeon://uri", }}, true, true, }, } for _, tt := range tests { t.Run( tt.name, func(t *testing.T) { //nolint:paralleltest if tt.fields.defaultScheme != "" { prevScheme := params.URIDefaults.Scheme defer func() { params.URIDefaults.Scheme = prevScheme }() params.URIDefaults.Scheme = tt.fields.defaultScheme } var ( db *sql.DB m sqlmock.Sqlmock err error ) if tt.expect.open { db, m, err = sqlmock.NewWithDSN( tt.fields.dsn, sqlmock.MonitorPingsOption(true), ) if err != nil { t.Fatalf("failed to open sqlmock: %s", err.Error()) } mockDriver.openErr = tt.fields.openErr mockDriver.driver = db.Driver() defer mockDriver.reset() if tt.expect.ping { m.ExpectPing().WillReturnError(tt.fields.pingErr) } } c := &ConnCollection{ keepAlive: tt.fields.keepAlive, driverName: tt.fields.driverName, logr: log.New("test"), } got, err := c.newConn(context.Background(), tt.args.conf) if (err != nil) != tt.wantErr { t.Fatalf( "ConnCollection.newConn() error = %v, wantErr %v", err, tt.wantErr, ) } if (got == nil) != tt.wantNil { t.Fatalf( "ConnCollection.newConn() got = %v, wantNil %v", got, tt.wantNil, ) } if m != nil { if err := m.ExpectationsWereMet(); err != nil { t.Fatalf( "ConnCollection.newConn() "+ "expectations where not met: %s", err.Error(), ) } } }, ) } } func Test_newConnConfig(t *testing.T) { t.Parallel() type args struct { metricParams map[string]string } tests := []struct { name string args args want connConfig }{ { "+valid", args{ map[string]string{ "URI": "pigeon://uri", "User": "aaaa", "Password": "bbbb", "CACertPath": "/a/b/c", "TrustServerCertificate": "false", "HostNameInCertificate": "true", "Encrypt": "false", "TLSMinVersion": "1.2", "Database": "testdb", }, }, connConfig{ URI: "pigeon://uri", User: "aaaa", Password: "bbbb", CACertPath: "/a/b/c", TrustServerCertificate: "false", HostNameInCertificate: "true", Encrypt: "false", TLSMinVersion: "1.2", Database: "testdb", }, }, } for _, tt := range tests { tt := tt t.Run(tt.name, func(t *testing.T) { t.Parallel() got := newConnConfig(tt.args.metricParams) if diff := cmp.Diff(tt.want, got); diff != "" { t.Fatalf("newConnConfig() = %s", diff) } }) } }