/* ** Zabbix ** Copyright (C) 2001-2023 Zabbix SIA ** ** This program is free software; you can redistribute it and/or modify ** it under the terms of the GNU General Public License as published by ** the Free Software Foundation; either version 2 of the License, or ** (at your option) any later version. ** ** This program is distributed in the hope that it will be useful, ** but WITHOUT ANY WARRANTY; without even the implied warranty of ** MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the ** GNU General Public License for more details. ** ** You should have received a copy of the GNU General Public License ** along with this program; if not, write to the Free Software ** Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA. **/ package zbxcomms import ( "bytes" "compress/zlib" "encoding/binary" "errors" "fmt" "io" "net" "time" "git.zabbix.com/ap/plugin-support/log" "zabbix.com/pkg/tls" ) const ( TimeoutModeFixed = iota TimeoutModeShift ) const headerSize = 4 + 1 + 4 + 4 const tcpProtocol = byte(0x01) const zlibCompress = byte(0x02) const ( connStateAccept = iota + 1 connStateConnect connStateEstablished ) type Connection struct { conn net.Conn tlsConfig *tls.Config state int compress bool timeout time.Duration timeoutMode int } type Listener struct { listener net.Listener tlsconfig *tls.Config } func open(address string, localAddr *net.Addr, timeout time.Duration, connect_timeout time.Duration, timeoutMode int, args ...interface{}) (c *Connection, err error) { c = &Connection{state: connStateConnect, compress: true, timeout: timeout, timeoutMode: timeoutMode} d := net.Dialer{Timeout: connect_timeout, LocalAddr: *localAddr} c.conn, err = d.Dial("tcp", address) if nil != err { return } if err = c.conn.SetDeadline(time.Now().Add(timeout)); err != nil { return } var tlsconfig *tls.Config if len(args) > 0 { var ok bool if tlsconfig, ok = args[0].(*tls.Config); !ok { return nil, fmt.Errorf("invalid TLS configuration parameter of type %T", args[0]) } if tlsconfig != nil { c.conn, err = tls.NewClient(c.conn, tlsconfig, timeout, timeoutMode == TimeoutModeShift, address) } } return } func (c *Connection) write(w io.Writer, data []byte) (err error) { var buf bytes.Buffer flags := tcpProtocol if c.compress { z := zlib.NewWriter(&buf) if _, err = z.Write(data); err != nil { return } z.Close() flags |= zlibCompress } else { buf.Write(data) } var b bytes.Buffer b.Grow(buf.Len() + headerSize) b.Write([]byte{'Z', 'B', 'X', 'D', flags}) if err = binary.Write(&b, binary.LittleEndian, uint32(buf.Len())); nil != err { return err } if err = binary.Write(&b, binary.LittleEndian, uint32(len(data))); nil != err { return err } b.Write(buf.Bytes()) _, err = w.Write(b.Bytes()) return err } func (c *Connection) Write(data []byte) error { if c.timeoutMode == TimeoutModeShift { if err := c.conn.SetWriteDeadline(time.Now().Add(c.timeout)); err != nil { return err } } return c.write(c.conn, data) } func (c *Connection) WriteString(s string) error { return c.Write([]byte(s)) } func (c *Connection) read(r io.Reader, pending []byte) ([]byte, error) { const maxRecvDataSize = 128 * 1048576 var total int var b [2048]byte var reservedSize uint32 s := b[:] if pending != nil { total = len(pending) if total > len(b) { return nil, errors.New("pending data exceeds limit of 2KB bytes") } copy(s, pending) } for total < headerSize { n, err := r.Read(s[total:]) if err != nil && err != io.EOF { return nil, fmt.Errorf("Cannot read message: '%s'", err) } if n == 0 { break } total += n } if total < 13 { if total == 0 { return []byte{}, nil } return nil, fmt.Errorf("Message is missing header.") } if !bytes.Equal(s[:4], []byte{'Z', 'B', 'X', 'D'}) { return nil, fmt.Errorf("Message is using unsupported protocol.") } flags := s[4] if 0 == (flags & tcpProtocol) { return nil, fmt.Errorf("Message is using unsupported protocol version.") } expectedSize := binary.LittleEndian.Uint32(s[5:9]) if expectedSize > maxRecvDataSize { return nil, fmt.Errorf("Message size %d exceeds the maximum size %d bytes.", expectedSize, maxRecvDataSize) } if int(expectedSize) < total-headerSize { return nil, fmt.Errorf("Message is longer than expected.") } if 0 != (flags & zlibCompress) { reservedSize = binary.LittleEndian.Uint32(s[9:13]) } if int(expectedSize) == total-headerSize { if 0 != (flags & zlibCompress) { return c.uncompress(s[headerSize:total], reservedSize) } return s[headerSize:total], nil } sTmp := make([]byte, expectedSize+1) if total > headerSize { copy(sTmp, s[headerSize:total]) } s = sTmp total = total - headerSize for total < int(expectedSize) { n, err := r.Read(s[total:]) if err != nil { return nil, err } if n == 0 { break } total += n } if total != int(expectedSize) { return nil, fmt.Errorf("Message size is shorted or longer than expected.") } if 0 != (flags & zlibCompress) { return c.uncompress(s[:total], reservedSize) } return s[:total], nil } func (c *Connection) uncompress(data []byte, expLen uint32) ([]byte, error) { var b bytes.Buffer b.Grow(int(expLen)) z, err := zlib.NewReader(bytes.NewReader(data)) if nil != err { return nil, fmt.Errorf("Unable to uncompress message: '%s'", err) } len, err := b.ReadFrom(z) z.Close() if nil != err { return nil, fmt.Errorf("Unable to uncompress message: '%s'", err) } if len != int64(expLen) { return nil, fmt.Errorf("Uncompressed message size %d instead of expected %d.", len, expLen) } return b.Bytes(), nil } func (c *Connection) Read() (data []byte, err error) { if c.timeoutMode == TimeoutModeShift { if err = c.conn.SetReadDeadline(time.Now().Add(c.timeout)); err != nil { return } } if c.state == connStateAccept && c.tlsConfig != nil { c.state = connStateEstablished b := make([]byte, 1) var n int if n, err = c.conn.Read(b); err != nil { return } if n == 0 { return nil, errors.New("connection closed") } if b[0] != '\x16' { // unencrypted connection if c.tlsConfig.Accept&tls.ConnUnencrypted == 0 { return nil, errors.New("cannot accept unencrypted connection") } return c.read(c.conn, b) } if c.tlsConfig.Accept&(tls.ConnPSK|tls.ConnCert) == 0 { return nil, errors.New("cannot accept encrypted connection") } var tlsConn net.Conn if tlsConn, err = tls.NewServer(c.conn, c.tlsConfig, b, c.timeout, c.timeoutMode == TimeoutModeShift); err != nil { return } c.conn = tlsConn } return c.read(c.conn, nil) } func (c *Connection) RemoteIP() string { addr, _, _ := net.SplitHostPort(c.conn.RemoteAddr().String()) return addr } func (l *Listener) Accept(timeout time.Duration, timeoutMode int) (c *Connection, err error) { var conn net.Conn if conn, err = l.listener.Accept(); err != nil { return } else { c = &Connection{conn: conn, tlsConfig: l.tlsconfig, state: connStateAccept, timeout: timeout, timeoutMode: timeoutMode} } return } func (c *Connection) Close() (err error) { if c.conn != nil { err = c.conn.Close() } return } func (c *Connection) SetCompress(compress bool) { c.compress = compress } func (c *Listener) Close() (err error) { return c.listener.Close() } func Exchange(addresses *[]string, localAddr *net.Addr, timeout time.Duration, connect_timeout time.Duration, data []byte, args ...interface{}) (b []byte, errs []error, errRead error) { log.Tracef("connecting to %s [timeout:%s, connection timeout:%s]", *addresses, timeout, connect_timeout) var tlsconfig *tls.Config var err error var c *Connection var no_response = false if len(args) > 0 { var ok bool if tlsconfig, ok = args[0].(*tls.Config); !ok { errs = append(errs, fmt.Errorf("invalid TLS configuration parameter of type %T", args[0])) log.Tracef("%s", errs[len(errs)-1]) return nil, errs, nil } if len(args) > 1 { if no_response, ok = args[1].(bool); !ok { errs = append(errs, fmt.Errorf("invalid response handling flag of type %T", args[1])) log.Tracef("%s", errs[len(errs)-1]) return nil, errs, nil } } } for i := 0; i < len(*addresses); i++ { c, err = open((*addresses)[0], localAddr, timeout, connect_timeout, TimeoutModeFixed, tlsconfig) if err == nil { break } errs = append(errs, fmt.Errorf("cannot connect to [%s]: %s", (*addresses)[0], err)) log.Tracef("%s", errs[len(errs)-1]) tmp := (*addresses)[0] *addresses = (*addresses)[1:] *addresses = append(*addresses, tmp) } if err != nil { return nil, errs, nil } defer c.Close() log.Tracef("sending [%s] to [%s]", string(data), (*addresses)[0]) err = c.Write(data) if err != nil { errs = append(errs, fmt.Errorf("cannot send to [%s]: %s", (*addresses)[0], err)) log.Tracef("%s", errs[len(errs)-1]) return nil, errs, nil } log.Tracef("receiving data from [%s]", (*addresses)[0]) b, err = c.Read() if err != nil { errs = append(errs, fmt.Errorf("cannot receive data from [%s]: %s", (*addresses)[0], err)) log.Tracef("%s", errs[len(errs)-1]) return nil, errs, errs[len(errs)-1] } log.Tracef("received [%s] from [%s]", string(b), (*addresses)[0]) if len(b) == 0 && false == no_response { errs = append(errs, fmt.Errorf("connection closed")) log.Tracef("%s", errs[len(errs)-1]) return nil, errs, errs[len(errs)-1] } return b, nil, nil }