package mongodb

import (
	"errors"
	"fmt"
	"sort"
	"time"

	"git.zabbix.com/ap/plugin-support/zbxerr"
	"gopkg.in/mgo.v2"
	"gopkg.in/mgo.v2/bson"
)

// mustFail is a sentinel name. Databases and collections with this name, and
// Run calls on a database with this name, report an error so failure paths
// can be exercised.
const mustFail = "mustFail"

// MockConn is an in-memory stand-in for a MongoDB connection. It remembers
// every database handed out by DB, so repeated calls see the same object.
type MockConn struct {
	// dbs maps a database name to its mock; entries are created lazily by DB.
	dbs map[string]*MockMongoDatabase
}

// NewMockConn returns a MockConn that holds no databases yet; they are
// created on demand by DB.
func NewMockConn() *MockConn {
	conn := new(MockConn)
	conn.dbs = map[string]*MockMongoDatabase{}

	return conn
}

// DB returns the mock database called name. On first use it creates an empty
// one (no collections) and stores it, so later calls return the same instance.
func (conn *MockConn) DB(name string) Database {
	db, exists := conn.dbs[name]
	if !exists {
		db = &MockMongoDatabase{
			name:        name,
			collections: make(map[string]*MockMongoCollection),
		}
		conn.dbs[name] = db
	}

	return db
}

// DatabaseNames returns, in ascending order, the names of all databases
// created so far through DB.
//
// The names are sorted for two reasons: mgo's Session.DatabaseNames sorts its
// result, and Go's map iteration order is randomised, which would otherwise
// make assertions on the returned slice flaky. If a database named mustFail
// exists, it returns zbxerr.ErrorCannotFetchData and nil names. When no
// databases exist, names is nil.
func (conn *MockConn) DatabaseNames() (names []string, err error) {
	for _, db := range conn.dbs {
		if db.name == mustFail {
			return nil, zbxerr.ErrorCannotFetchData
		}

		names = append(names, db.name)
	}

	sort.Strings(names)

	return names, nil
}

// Ping always succeeds; the mock has no server to contact.
func (conn *MockConn) Ping() (err error) {
	return
}

// GetMaxTimeMS reports the mock's fixed maximum query time of 3000
// (milliseconds, per the method name).
func (conn *MockConn) GetMaxTimeMS() int64 {
	const maxTimeMS int64 = 3000

	return maxTimeMS
}

// MockSession lists the connection-level methods that MockConn provides:
// database lookup and listing, the query time limit, and a liveness check.
type MockSession interface {
	Ping() error
	GetMaxTimeMS() int64
	DatabaseNames() ([]string, error)
	DB(name string) Database
}

// MockMongoDatabase is an in-memory mock database, as handed out by
// MockConn.DB.
type MockMongoDatabase struct {
	// name is the database name; it is passed to RunFunc by Run.
	name        string
	// collections maps a collection name to its mock; entries are created
	// lazily by C and listed by CollectionNames.
	collections map[string]*MockMongoCollection
	// RunFunc produces the raw BSON reply for Run, given the database name and
	// the command name. If nil, Run installs a default that fails when the
	// database is named mustFail and otherwise replies {"ok": 1}.
	RunFunc     func(dbName, cmd string) ([]byte, error)
}

// C returns the mock collection called name. On first use it creates one with
// no cached queries and stores it, so later calls return the same instance.
func (d *MockMongoDatabase) C(name string) Collection {
	col, known := d.collections[name]
	if !known {
		col = &MockMongoCollection{
			name:    name,
			queries: map[interface{}]*MockMongoQuery{},
		}
		d.collections[name] = col
	}

	return col
}

// CollectionNames returns, in ascending order, the names of all collections
// created so far through C.
//
// Sorting matches mgo's Database.CollectionNames and makes the result
// independent of Go's randomised map iteration order, so tests can compare
// the returned slice directly. If a collection named mustFail exists, it
// returns an error and nil names. When no collections exist, names is nil.
func (d *MockMongoDatabase) CollectionNames() (names []string, err error) {
	for _, col := range d.collections {
		if col.name == mustFail {
			return nil, errors.New("fail")
		}

		names = append(names, col.name)
	}

	sort.Strings(names)

	return names, nil
}

// Run emulates a database command.
//
// If RunFunc is nil, Run first installs a default and stores it on d, so it
// persists for later calls. The default fails for a database named mustFail
// and otherwise replies {"ok": 1}. If result is nil, Run returns nil without
// invoking RunFunc. Otherwise it extracts the command name from cmd, passes
// it to RunFunc together with the database name, and unmarshals the returned
// BSON into result.
//
// As with mgo's Database.Run, cmd may be a command-name string or a command
// document (bson.D or *bson.D) whose first element names the command. Any
// other shape, a nil *bson.D, or an empty document yields an error instead
// of a panic.
func (d *MockMongoDatabase) Run(cmd, result interface{}) error {
	if d.RunFunc == nil {
		d.RunFunc = func(dbName, _ string) ([]byte, error) {
			if dbName == mustFail {
				return nil, errors.New("fail")
			}

			return bson.Marshal(map[string]int{"ok": 1})
		}
	}

	if result == nil {
		return nil
	}

	cmdName, err := mockCommandName(cmd)
	if err != nil {
		return err
	}

	data, err := d.RunFunc(d.name, cmdName)
	if err != nil {
		return err
	}

	return bson.Unmarshal(data, result)
}

// mockCommandName returns the command name carried by cmd: the string itself,
// or the key of the first element of a bson.D / *bson.D.
func mockCommandName(cmd interface{}) (string, error) {
	var doc bson.D

	switch c := cmd.(type) {
	case string:
		return c, nil
	case bson.D:
		doc = c
	case *bson.D:
		if c == nil {
			return "", errors.New("mock run: nil command document")
		}
		doc = *c
	default:
		return "", fmt.Errorf("mock run: unsupported command type %T", cmd)
	}

	if len(doc) == 0 {
		return "", errors.New("mock run: empty command document")
	}

	return doc[0].Name, nil
}

// MockMongoCollection is an in-memory mock collection, as handed out by
// MockMongoDatabase.C.
type MockMongoCollection struct {
	// name is the collection name; CollectionNames fails if it is "mustFail".
	name    string
	// queries caches the MockMongoQuery returned by Find. The key is the fmt
	// "%v" rendering of the query, stored as a string.
	queries map[interface{}]*MockMongoQuery
}

// Find returns the mock query for query, creating it on first use.
// Queries whose "%v" renderings are equal share a single MockMongoQuery, so a
// DataFunc set on one applies to all of them.
func (c *MockMongoCollection) Find(query interface{}) Query {
	key := fmt.Sprintf("%v", query)

	q, cached := c.queries[key]
	if !cached {
		q = &MockMongoQuery{collection: c.name, query: query}
		c.queries[key] = q
	}

	return q
}

// MockMongoQuery is the mock query returned by MockMongoCollection.Find.
// The documents produced by One and All come from DataFunc.
type MockMongoQuery struct {
	// collection is the name of the collection the query was created on.
	collection string
	// query is the filter originally passed to Find.
	query      interface{}
	// sortFields holds the keys most recently passed to Sort.
	sortFields []string
	// DataFunc returns the raw BSON document that One and All unmarshal into
	// their result argument; it receives the collection name, the query and
	// the sort fields. When nil, One and All return mgo.ErrNotFound.
	DataFunc   func(collection string, query interface{}, sortFields ...string) ([]byte, error)
}

// retrieve backs One and All. It returns mgo.ErrNotFound when no DataFunc is
// configured, does nothing for a nil result, and otherwise unmarshals the BSON
// produced by DataFunc into result. Errors from DataFunc are returned as-is.
func (q *MockMongoQuery) retrieve(result interface{}) error {
	switch {
	case q.DataFunc == nil:
		return mgo.ErrNotFound
	case result == nil:
		return nil
	}

	raw, err := q.DataFunc(q.collection, q.query, q.sortFields...)
	if err != nil {
		return err
	}

	return bson.Unmarshal(raw, result)
}

// All fills result through retrieve. The mock does not distinguish between
// fetching one document and fetching many.
func (q *MockMongoQuery) All(result interface{}) (err error) {
	err = q.retrieve(result)
	return
}

// Count always reports exactly one matching document and never fails.
func (q *MockMongoQuery) Count() (n int, err error) {
	n = 1
	return
}

// Limit is a no-op in the mock. The requested limit is ignored and the
// receiver is returned so calls can be chained.
func (q *MockMongoQuery) Limit(_ int) Query {
	return Query(q)
}

// One fills result through retrieve, exactly like All.
func (q *MockMongoQuery) One(result interface{}) (err error) {
	err = q.retrieve(result)
	return
}

// SetMaxTime ignores the given duration and returns the receiver so calls can
// be chained.
func (q *MockMongoQuery) SetMaxTime(time.Duration) Query {
	return Query(q)
}

// Sort records fields, replacing any earlier value, so that retrieval passes
// them on to DataFunc. The slice is stored as given, without copying, and the
// receiver is returned so calls can be chained.
func (q *MockMongoQuery) Sort(fields ...string) (sorted Query) {
	q.sortFields, sorted = fields, q
	return sorted
}