/*
** Zabbix
** Copyright (C) 2001-2024 Zabbix SIA
**
** Licensed under the Apache License, Version 2.0 (the "License");
** you may not use this file except in compliance with the License.
** You may obtain a copy of the License at
**
**     http://www.apache.org/licenses/LICENSE-2.0
**
** Unless required by applicable law or agreed to in writing, software
** distributed under the License is distributed on an "AS IS" BASIS,
** WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
** See the License for the specific language governing permissions and
** limitations under the License.
**/

package nvmlmock

import (
	"testing"

	"github.com/google/go-cmp/cmp"
	"golang.zabbix.com/plugin/nvidia/pkg/nvml"
)

var (
	_ nvml.Device = (*MockDevice)(nil)
	_ Mocker      = (*MockDevice)(nil)
)

// MockDevice is mock for NVML device.
//
// It embeds nvml.Device so that the type satisfies the whole interface.
// Methods defined for MockDevice intercept calls and check them, in order,
// against the expectations slice. callIdx is the index of the next
// expectation to consume. t receives the failures reported by the mock.
//
// NOTE(review): NewMockDevice leaves the embedded nvml.Device nil, so calling
// an interface method that MockDevice does not define itself would panic
// with a nil dereference. Confirm that every method used by tests is mocked.
type MockDevice struct {
	nvml.Device
	expectations []*Expectation
	callIdx      int
	t            *testing.T
}

// NewMockDevice creates a MockDevice that reports failures to t. It starts
// with an empty (non-nil) expectation list; register the expected call
// sequence with ExpectCalls.
func NewMockDevice(t *testing.T) *MockDevice {
	t.Helper()

	dev := &MockDevice{t: t}
	dev.expectations = []*Expectation{}

	return dev
}

// ExpectCalls replaces the mock's expected call sequence with expectations,
// in the order the calls must arrive, and returns the mock for chaining.
// It does not reset the count of calls already received.
func (m *MockDevice) ExpectCalls(expectations ...*Expectation) *MockDevice {
	m.expectations = expectations
	return m
}

// ExpectedCallsDone reports whether the number of calls received equals
// the number of expectations registered on this mock.
//
// Only this mock's own expectations are checked; submocks (see SubMocks)
// are not inspected here. On a mismatch, it reports a test error for each
// expected call that never arrived, plus a summary of the counts, and
// returns false.
func (m *MockDevice) ExpectedCallsDone() bool {
	m.t.Helper()

	expected := len(m.expectations)
	received := m.callIdx

	if expected == received {
		return true
	}

	// received can exceed expected when ExpectCalls replaced the list with a
	// shorter one after calls were made; slicing past the end would panic.
	if received < expected {
		for _, e := range m.expectations[received:] {
			m.t.Errorf("Not called %q", e.funcName)
		}
	}

	m.t.Errorf("received %d out of %d expected calls", received, expected)

	return false
}

// SubMocks collects every expectation output that is itself a Mocker.
// Each one is followed immediately by its own submocks, found recursively.
// Order follows the expectations and their outputs. The result is nil when
// no output is a Mocker.
func (m *MockDevice) SubMocks() []Mocker {
	var found []Mocker

	for _, e := range m.expectations {
		for _, o := range e.out {
			if sm, isMocker := o.(Mocker); isMocker {
				found = append(append(found, sm), sm.SubMocks()...)
			}
		}
	}

	return found
}

// GetUUID is the mock implementation of GetUUID. It consumes the next
// expectation, which must be a GetUUID call. It returns that expectation's
// first output as a string together with the expectation's error. The test
// is aborted if the output is not a string.
func (m *MockDevice) GetUUID() (string, error) {
	m.t.Helper()

	exp, err := m.handleFunctionCall("GetUUID")
	first := exp.out[0]

	uuid, isString := first.(string)
	if !isString {
		m.t.Fatalf("expected string in GetUUID, got %T", first)
	}

	return uuid, err
}

// GetName is the mock implementation of GetName. It consumes the next
// expectation, which must be a GetName call. It returns that expectation's
// first output as a string together with the expectation's error. The test
// is aborted if the output is not a string.
func (m *MockDevice) GetName() (string, error) {
	m.t.Helper()

	exp, err := m.handleFunctionCall("GetName")
	first := exp.out[0]

	name, isString := first.(string)
	if !isString {
		m.t.Fatalf("expected string in GetName, got %T", first)
	}

	return name, err
}

// GetSerial is the mock implementation of GetSerial. It consumes the next
// expectation, which must be a GetSerial call. It returns that expectation's
// first output as a string together with the expectation's error. The test
// is aborted if the output is not a string.
func (m *MockDevice) GetSerial() (string, error) {
	m.t.Helper()

	exp, err := m.handleFunctionCall("GetSerial")
	first := exp.out[0]

	serial, isString := first.(string)
	if !isString {
		m.t.Fatalf("expected string in GetSerial, got %T", first)
	}

	return serial, err
}

// GetTemperature is the mock implementation of GetTemperature. It consumes
// the next expectation, which must be a GetTemperature call. It returns that
// expectation's first output, which must be an int, together with the
// expectation's error. The test is aborted if the output has another type.
func (m *MockDevice) GetTemperature() (int, error) {
	m.t.Helper()
	res, err := m.handleFunctionCall("GetTemperature")

	temperature, ok := res.out[0].(int)
	if !ok {
		m.t.Fatalf("expected int in GetTemperature, got %T", res.out[0])
	}

	return temperature, err
}

// GetFanSpeed is the mock implementation of GetFanSpeed. It consumes the
// next expectation, which must be a GetFanSpeed call. It returns that
// expectation's first output, which must be a uint, together with the
// expectation's error. The test is aborted if the output has another type.
func (m *MockDevice) GetFanSpeed() (uint, error) {
	m.t.Helper()
	res, err := m.handleFunctionCall("GetFanSpeed")

	fanSpeed, ok := res.out[0].(uint)
	if !ok {
		m.t.Fatalf("expected uint in GetFanSpeed, got %T", res.out[0])
	}

	return fanSpeed, err
}

// GetPerformanceState is the mock implementation of GetPerformanceState. It
// consumes the next expectation, which must be a GetPerformanceState call.
// It returns that expectation's first output, which must be a uint, together
// with the expectation's error. The test is aborted if the output has
// another type.
func (m *MockDevice) GetPerformanceState() (uint, error) {
	m.t.Helper()
	res, err := m.handleFunctionCall("GetPerformanceState")

	state, ok := res.out[0].(uint)
	if !ok {
		m.t.Fatalf("expected uint in GetPerformanceState, got %T", res.out[0])
	}

	return state, err
}

// GetPowerManagementLimit is the mock implementation of
// GetPowerManagementLimit. It consumes the next expectation, which must be a
// GetPowerManagementLimit call. It returns that expectation's first output,
// which must be a uint, together with the expectation's error. The test is
// aborted if the output has another type.
func (m *MockDevice) GetPowerManagementLimit() (uint, error) {
	m.t.Helper()
	res, err := m.handleFunctionCall("GetPowerManagementLimit")

	limit, ok := res.out[0].(uint)
	if !ok {
		m.t.Fatalf("expected uint in GetPowerManagementLimit, got %T", res.out[0])
	}

	return limit, err
}

// GetPowerUsage is the mock implementation of GetPowerUsage. It consumes the
// next expectation, which must be a GetPowerUsage call. It returns that
// expectation's first output, which must be a uint, together with the
// expectation's error. The test is aborted if the output has another type.
func (m *MockDevice) GetPowerUsage() (uint, error) {
	m.t.Helper()
	res, err := m.handleFunctionCall("GetPowerUsage")

	usage, ok := res.out[0].(uint)
	if !ok {
		m.t.Fatalf("expected uint in GetPowerUsage, got %T", res.out[0])
	}

	return usage, err
}

// GetClockInfo is the mock implementation of GetClockInfo. It consumes the
// next expectation, which must be a GetClockInfo call whose expected
// arguments equal [clock]. It returns that expectation's first output, which
// must be a uint, together with the expectation's error. The test is aborted
// if the output has another type.
func (m *MockDevice) GetClockInfo(clock nvml.ClockType) (uint, error) {
	m.t.Helper()
	res, err := m.handleFunctionCall("GetClockInfo", clock)

	c, ok := res.out[0].(uint)
	if !ok {
		m.t.Fatalf("expected uint in GetClockInfo, got %T", res.out[0])
	}

	return c, err
}

// GetTotalEnergyConsumption is the mock implementation of
// GetTotalEnergyConsumption. It consumes the next expectation, which must be
// a GetTotalEnergyConsumption call. It returns that expectation's first
// output, which must be a uint64, together with the expectation's error. The
// test is aborted if the output has another type.
func (m *MockDevice) GetTotalEnergyConsumption() (uint64, error) {
	m.t.Helper()
	res, err := m.handleFunctionCall("GetTotalEnergyConsumption")

	consumption, ok := res.out[0].(uint64)
	if !ok {
		m.t.Fatalf("expected uint64 in GetTotalEnergyConsumption, got %T", res.out[0])
	}

	return consumption, err
}

// GetUtilizationRates is the mock implementation of GetUtilizationRates. It
// consumes the next expectation, which must be a GetUtilizationRates call.
// It returns that expectation's first two outputs, both of which must be
// uint, together with the expectation's error. The test is aborted on a type
// mismatch.
func (m *MockDevice) GetUtilizationRates() (uint, uint, error) {
	m.t.Helper()

	exp, err := m.handleFunctionCall("GetUtilizationRates")
	outs := exp.out

	first, isUint := outs[0].(uint)
	if !isUint {
		m.t.Fatalf("expected uint for util1 in GetUtilizationRates, got %T", outs[0])
	}

	second, isUint := outs[1].(uint)
	if !isUint {
		m.t.Fatalf("expected uint for util2 in GetUtilizationRates, got %T", outs[1])
	}

	return first, second, err
}

// GetMemoryErrorCounter is the mock implementation of GetMemoryErrorCounter.
// It consumes the next expectation, which must be a GetMemoryErrorCounter
// call whose expected arguments equal [memoryType, memoryLocation,
// counterType]. It returns that expectation's first output, which must be a
// uint64, together with the expectation's error. The test is aborted if the
// output has another type.
func (m *MockDevice) GetMemoryErrorCounter(
	memoryType nvml.MemoryErrorType,
	memoryLocation nvml.MemoryLocation,
	counterType nvml.EccCounterType,
) (uint64, error) {
	m.t.Helper()

	args := []any{memoryType, memoryLocation, counterType}
	exp, err := m.handleFunctionCall("GetMemoryErrorCounter", args...)
	first := exp.out[0]

	count, isUint64 := first.(uint64)
	if !isUint64 {
		m.t.Fatalf("expected uint64 for rate in GetMemoryErrorCounter, got %T", first)
	}

	return count, err
}

// GetEncoderUtilization is the mock implementation of GetEncoderUtilization.
// It consumes the next expectation, which must be a GetEncoderUtilization
// call. It returns that expectation's first two outputs, both of which must
// be uint, together with the expectation's error. The test is aborted on a
// type mismatch.
func (m *MockDevice) GetEncoderUtilization() (uint, uint, error) {
	m.t.Helper()
	res, err := m.handleFunctionCall("GetEncoderUtilization")

	util1, ok := res.out[0].(uint)
	if !ok {
		m.t.Fatalf("expected uint for util1 in GetEncoderUtilization, got %T", res.out[0])
	}

	util2, ok := res.out[1].(uint)
	if !ok {
		m.t.Fatalf("expected uint for util2 in GetEncoderUtilization, got %T", res.out[1])
	}

	return util1, util2, err
}

// GetDecoderUtilization is the mock implementation of GetDecoderUtilization.
// It consumes the next expectation, which must be a GetDecoderUtilization
// call. It returns that expectation's first two outputs, both of which must
// be uint, together with the expectation's error. The test is aborted on a
// type mismatch.
func (m *MockDevice) GetDecoderUtilization() (uint, uint, error) {
	m.t.Helper()
	res, err := m.handleFunctionCall("GetDecoderUtilization")

	util1, ok := res.out[0].(uint)
	if !ok {
		m.t.Fatalf("expected uint for util1 in GetDecoderUtilization, got %T", res.out[0])
	}

	util2, ok := res.out[1].(uint)
	if !ok {
		m.t.Fatalf("expected uint for util2 in GetDecoderUtilization, got %T", res.out[1])
	}

	return util1, util2, err
}

// GetEccMode is the mock implementation of GetEccMode. It consumes the next
// expectation, which must be a GetEccMode call. It returns that
// expectation's first two outputs (current and pending mode), both of which
// must be bool, together with the expectation's error. The test is aborted
// on a type mismatch.
func (m *MockDevice) GetEccMode() (bool, bool, error) {
	m.t.Helper()

	exp, err := m.handleFunctionCall("GetEccMode")
	outs := exp.out

	current, isBool := outs[0].(bool)
	if !isBool {
		m.t.Fatalf("expected bool in GetEccMode, got %T", outs[0])
	}

	pending, isBool := outs[1].(bool)
	if !isBool {
		m.t.Fatalf("expected bool in GetEccMode, got %T", outs[1])
	}

	return current, pending, err
}

// GetPCIeThroughput is the mock implementation of GetPCIeThroughput. It
// consumes the next expectation, which must be a GetPCIeThroughput call
// whose expected arguments equal [pcie]. It returns that expectation's first
// output, which must be a uint, together with the expectation's error. The
// test is aborted if the output has another type.
func (m *MockDevice) GetPCIeThroughput(pcie nvml.PcieMetricType) (uint, error) {
	m.t.Helper()

	exp, err := m.handleFunctionCall("GetPCIeThroughput", pcie)
	first := exp.out[0]

	throughput, isUint := first.(uint)
	if !isUint {
		m.t.Fatalf("expected uint in GetPCIeThroughput, got %T", first)
	}

	return throughput, err
}

// GetMemoryInfoV2 is the mock implementation of GetMemoryInfoV2. It
// consumes the next expectation, which must be a GetMemoryInfoV2 call. If
// the expectation carries an error, it returns (nil, err) without looking at
// the outputs. Otherwise it returns the first output, which must be a
// *nvml.MemoryInfoV2. The test is aborted if the output has another type.
func (m *MockDevice) GetMemoryInfoV2() (*nvml.MemoryInfoV2, error) {
	m.t.Helper()

	exp, err := m.handleFunctionCall("GetMemoryInfoV2")
	if err != nil {
		return nil, err
	}

	first := exp.out[0]

	info, isInfo := first.(*nvml.MemoryInfoV2)
	if !isInfo {
		m.t.Fatalf("expected %T in GetMemoryInfoV2, got %T", info, first)
	}

	return info, nil
}

// GetMemoryInfo is the mock implementation of GetMemoryInfo. It consumes
// the next expectation, which must be a GetMemoryInfo call. If the
// expectation carries an error, it returns (nil, err) without looking at the
// outputs. Otherwise it returns the first output, which must be a
// *nvml.MemoryInfo. The test is aborted if the output has another type.
func (m *MockDevice) GetMemoryInfo() (*nvml.MemoryInfo, error) {
	m.t.Helper()

	exp, err := m.handleFunctionCall("GetMemoryInfo")
	if err != nil {
		return nil, err
	}

	first := exp.out[0]

	info, isInfo := first.(*nvml.MemoryInfo)
	if !isInfo {
		m.t.Fatalf("expected %T in GetMemoryInfo, got %T", info, first)
	}

	return info, nil
}

// GetBAR1MemoryInfo is the mock implementation of GetBAR1MemoryInfo. It
// consumes the next expectation, which must be a GetBAR1MemoryInfo call. If
// the expectation carries an error, it returns (nil, err) without looking at
// the outputs, matching GetMemoryInfo. Otherwise it returns the first
// output, which must be a *nvml.MemoryInfo. The test is aborted if the
// output has another type.
func (m *MockDevice) GetBAR1MemoryInfo() (*nvml.MemoryInfo, error) {
	m.t.Helper()
	res, err := m.handleFunctionCall("GetBAR1MemoryInfo")

	// Error expectations need not supply a typed output. Without this early
	// return, an empty or untyped-nil output would panic or fail the test.
	if err != nil {
		return nil, err
	}

	info, ok := res.out[0].(*nvml.MemoryInfo)
	if !ok {
		m.t.Fatalf("expected %T in GetBAR1MemoryInfo, got %T", info, res.out[0])
	}

	return info, nil
}

// GetEncoderStats is the mock implementation of GetEncoderStats. It
// consumes the next expectation, which must be a GetEncoderStats call. It
// returns that expectation's first three outputs (sessions, fps, latency),
// each of which must be a uint, together with the expectation's error. The
// test is aborted on a type mismatch.
func (m *MockDevice) GetEncoderStats() (uint, uint, uint, error) {
	m.t.Helper()

	exp, err := m.handleFunctionCall("GetEncoderStats")
	outs := exp.out

	sessions, isUint := outs[0].(uint)
	if !isUint {
		m.t.Fatalf("expected %T in GetEncoderStats, got %T", sessions, outs[0])
	}

	fps, isUint := outs[1].(uint)
	if !isUint {
		m.t.Fatalf("expected %T in GetEncoderStats, got %T", fps, outs[1])
	}

	latency, isUint := outs[2].(uint)
	if !isUint {
		m.t.Fatalf("expected %T in GetEncoderStats, got %T", latency, outs[2])
	}

	return sessions, fps, latency, err
}

// handleFunctionCall consumes the next expected call and validates it
// against the actual one.
//
// name is the mocked method being invoked. receivedArgs are the arguments it
// was called with. The test is aborted via Fatalf in three cases:
//   - no further calls are expected;
//   - the next expectation is for a different method;
//   - the arguments differ from the expected ones (compared with cmp.Diff).
//
// Otherwise it returns the matched expectation and the error configured on
// it.
func (m *MockDevice) handleFunctionCall(name string, receivedArgs ...any) (*Expectation, error) {
	m.t.Helper()

	if m.callIdx >= len(m.expectations) {
		m.t.Fatalf("no more calls expected but got call for %q", name)
	}

	expect := m.expectations[m.callIdx]
	m.callIdx++

	if expect.funcName != name {
		m.t.Fatalf("got call for %q while expected for %q", name, expect.funcName)
	}

	// A variadic parameter with no arguments is nil. Normalize it to an empty
	// slice, because cmp.Diff treats nil and empty slices as unequal.
	if receivedArgs == nil {
		receivedArgs = []any{}
	}

	if diff := cmp.Diff(expect.args, receivedArgs); diff != "" {
		// Use an interpreted (double-quoted) literal so the \n escapes print
		// as newlines. callIdx is already incremented here, so the call
		// number shown is 1-based.
		m.t.Fatalf("arguments mismatch in %s call %d:\nexpected: %v\nreceived: %v\ndiff: %s",
			name,
			m.callIdx,
			expect.args,
			receivedArgs,
			diff,
		)
	}

	return expect, expect.err
}