/*
** Zabbix
** Copyright (C) 2001-2024 Zabbix SIA
** Licensed under the Apache License, Version 2.0 (the "License");
** you may not use this file except in compliance with the License.
** You may obtain a copy of the License at
**     http://www.apache.org/licenses/LICENSE-2.0
** Unless required by applicable law or agreed to in writing, software
** distributed under the License is distributed on an "AS IS" BASIS,
** WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
** See the License for the specific language governing permissions and
** limitations under the License.
**/

package nvml

/*
#cgo CFLAGS: -I${SRCDIR}/nvml-sdk/include

#include "nvml.h"
*/
import "C" //nolint:gocritic,gci

import (
	"errors"
	"sync"
	"syscall"
	"unsafe" //nolint:gocritic

	"golang.org/x/sys/windows"
	"golang.zabbix.com/sdk/errs"
)


var (
	_ Device = (*NVMLDevice)(nil)
	_ Runner = (*NVMLRunner)(nil)

// NVMLRunner manages the loading and retrieval of functions from a Windows DLL.
// It is responsible for ensuring thread-safe access to the list of loaded
// procedures and facilitates dynamic function calls from the DLL.
type NVMLRunner struct {
	dll         *windows.DLL
	procListMux *sync.Mutex
	procList    map[string]*windows.Proc

// NVMLDevice represents an NVML GPU device, identified by a unique handle.
type NVMLDevice struct {
	handle uintptr
	runner *NVMLRunner // Reference to the Runner (formerly NVMLRunner)

// NewNVMLRunner creates a new NVML Runner instance, loading the NVML library.
func NewNVMLRunner() (*NVMLRunner, error) {
	dll, err := windows.LoadDLL("nvml.dll")
	if err != nil {
		return nil, errs.WrapConst(err, ErrLibraryNotFound) //nolint:wrapcheck

	runner := &NVMLRunner{
		dll:         dll,
		procList:    make(map[string]*windows.Proc),
		procListMux: &sync.Mutex{},

	return runner, nil

// Init initializes the NVML library using the older NVML interface.
func (runner *NVMLRunner) Init() error {
	err := runner.callProc("nvmlInit")
	if err != nil {
		return errs.Wrap(err, "failed while calling procedure")

	return nil

// InitV2 initializes the NVML library using the NVML v2 interface.
func (runner *NVMLRunner) InitV2() error {
	err := runner.callProc("nvmlInit_v2")
	if err != nil {
		return errs.Wrap(err, "failed while calling procedure")

	return nil

// GetNVMLVersion retrieves the version of the NVML library currently in use.
func (runner *NVMLRunner) GetNVMLVersion() (string, error) {
	var version [systemNVMLVersionBufferSize]byte

	err := runner.callProc("nvmlSystemGetNVMLVersion",
	if err != nil {
		return "", errs.Wrap(err, "failed while calling procedure")

	return windows.ByteSliceToString(version[:]), nil

// GetDriverVersion retrieves the version of the NVIDIA driver currently in use.
func (runner *NVMLRunner) GetDriverVersion() (string, error) {
	var version [systemDriverVersionBufferSize]byte

	err := runner.callProc("nvmlSystemGetDriverVersion",
	if err != nil {
		return "", errs.Wrap(err, "failed while calling procedure")

	return windows.ByteSliceToString(version[:]), nil

// GetDeviceCountV2 retrieves the number of NVIDIA devices using the NVML v2 interface.
func (runner *NVMLRunner) GetDeviceCountV2() (uint, error) {
	var deviceCount C.uint

	err := runner.callProc("nvmlDeviceGetCount_v2", uintptr(unsafe.Pointer(&deviceCount)))
	if err != nil {
		return 0, errs.Wrap(err, "failed while calling procedure")

	return uint(deviceCount), nil

// GetDeviceCount retrieves the number of NVIDIA devices using the standard NVML interface.
func (runner *NVMLRunner) GetDeviceCount() (uint, error) {
	var deviceCount C.uint

	err := runner.callProc("nvmlDeviceGetCount", uintptr(unsafe.Pointer(&deviceCount)))
	if err != nil {
		return 0, errs.Wrap(err, "failed while calling procedure")

	return uint(deviceCount), nil

// GetDeviceByIndexV2 retrieves a handle to an NVIDIA device by its index using the NVML v2 interface.
func (runner *NVMLRunner) GetDeviceByIndexV2(index uint) (Device, error) {
	var deviceHandle uintptr

	err := runner.callProc("nvmlDeviceGetHandleByIndex_v2",
	if err != nil {
		return nil, errs.Wrap(err, "failed while calling procedure")

	return &NVMLDevice{
		handle: deviceHandle,
		runner: runner,
	}, nil

// GetDeviceByUUID retrieves a handle to an NVIDIA device by its UUID.
func (runner *NVMLRunner) GetDeviceByUUID(uuid string) (Device, error) {
	var deviceHandle uintptr

	cUUID, err := windows.ByteSliceFromString(uuid)
	if err != nil {
		return nil, errs.Wrap(err, "uuid contains terminator sign")

	if len(cUUID) > deviceUUIDBufferSize {
		return nil, errs.New("uuid string too long")

	err = runner.callProc("nvmlDeviceGetHandleByUUID",
	if err != nil {
		return nil, errs.Wrap(err, "failed while calling procedure")

	return &NVMLDevice{
		handle: deviceHandle,
		runner: runner,
	}, nil

// ShutdownNVML is a wrapper function to cleanly shut down NVML.
func (runner *NVMLRunner) ShutdownNVML() error {
	err := runner.callProc("nvmlShutdown")
	if err != nil {
		return errs.Wrap(err, "failed while calling procedure")

	return nil

// Close releases the resources associated with the loaded DLL in the Runner.
func (runner *NVMLRunner) Close() error {
	callErr := runner.dll.Release()

	err := checkCallError(callErr)
	if err != nil {
		return err

	return nil

// GetTemperature retrieves the temperature of the NVIDIA device using the default sensor.
func (device *NVMLDevice) GetTemperature() (int, error) {
	var temperature C.uint

	err := device.runner.callProc("nvmlDeviceGetTemperature",
	if err != nil {
		return 0, errs.Wrap(err, "failed while calling procedure")

	return int(temperature), nil

// GetName retrieves the name of the NVIDIA device.
func (device *NVMLDevice) GetName() (string, error) {
	var name [deviceNameBufferSize]byte

	err := device.runner.callProc("nvmlDeviceGetName",
	if err != nil {
		return "", errs.Wrap(err, "failed while calling procedure")

	return windows.ByteSliceToString(name[:]), nil

// GetMemoryInfoV2 retrieves detailed memory information for the NVIDIA device using the NVML v2 interface.
func (device *NVMLDevice) GetMemoryInfoV2() (*MemoryInfoV2, error) {
	var nvmlMemInfo C.nvmlMemory_v2_t

	// there is a macto in nvml.h called NVML_STRUCT_VERSION,
	// it does not work for CGO, so doing manually.
	nvmlMemInfo.version = C.uint(unsafe.Sizeof(nvmlMemInfo)) | (2 << 24)

	err := device.runner.callProc("nvmlDeviceGetMemoryInfo_v2",
	if err != nil {
		return nil, errs.Wrap(err, "failed while calling procedure")

	return &MemoryInfoV2{
		Total:    uint64(nvmlMemInfo.total),
		Reserved: uint64(nvmlMemInfo.reserved),
		Free:     uint64(nvmlMemInfo.free),
		Used:     uint64(nvmlMemInfo.used),
	}, nil

// GetMemoryInfo retrieves memory information for the NVIDIA device.
func (device *NVMLDevice) GetMemoryInfo() (*MemoryInfo, error) {
	var memInfo C.nvmlMemory_t

	err := device.runner.callProc("nvmlDeviceGetMemoryInfo",
	if err != nil {
		return nil, errs.Wrap(err, "failed while calling procedure")

	return &MemoryInfo{
		Total: uint64(memInfo.total),
		Free:  uint64(memInfo.free),
		Used:  uint64(memInfo.used),
	}, nil

// GetBAR1MemoryInfo retrieves BAR1 memory information for the NVIDIA device.
func (device *NVMLDevice) GetBAR1MemoryInfo() (*MemoryInfo, error) {
	var memInfo C.nvmlBAR1Memory_t

	err := device.runner.callProc("nvmlDeviceGetBAR1MemoryInfo",
	if err != nil {
		return nil, errs.Wrap(err, "failed while calling procedure")

	return &MemoryInfo{
		Total: uint64(memInfo.bar1Total),
		Free:  uint64(memInfo.bar1Free),
		Used:  uint64(memInfo.bar1Used),
	}, nil

// GetFanSpeed retrieves the current fan speed of the NVIDIA device as a percentage of its maximum speed.
func (device *NVMLDevice) GetFanSpeed() (uint, error) {
	var speed C.uint

	err := device.runner.callProc("nvmlDeviceGetFanSpeed",
	if err != nil {
		return 0, errs.Wrap(err, "failed while calling procedure")

	return uint(speed), nil

// GetPCIeThroughput retrieves the PCIe throughput for the NVIDIA device, based on the specified metric type.
func (device *NVMLDevice) GetPCIeThroughput(metricType PcieMetricType) (uint, error) {
	var throughput C.uint

	err := device.runner.callProc("nvmlDeviceGetPcieThroughput",
	if err != nil {
		return 0, errs.Wrap(err, "failed while calling procedure")

	return uint(throughput), nil

// GetUUID retrieves the UUID of the NVIDIA device.
func (device *NVMLDevice) GetUUID() (string, error) {
	var uuid [deviceUUIDBufferSize]byte

	err := device.runner.callProc("nvmlDeviceGetUUID",
	if err != nil {
		return "", errs.Wrap(err, "failed while calling procedure")

	return windows.ByteSliceToString(uuid[:]), nil

// GetSerial retrieves the serial number of the NVIDIA device.
func (device *NVMLDevice) GetSerial() (string, error) {
	var serial [deviceSerialBufferSize]byte

	err := device.runner.callProc("nvmlDeviceGetSerial",
	if err != nil {
		return "", errs.Wrap(err, "failed while calling procedure")

	return windows.ByteSliceToString(serial[:]), nil

// GetEncoderUtilization retrieves the encoder utilization statistics for the device.
// It returns the following values:
//   - utilization: the percentage of time over the past sampling period during which the encoder was active.
//   - samplingPeriodUs: the sampling period duration in microseconds,
//     indicating how long the utilization metric was measured.
func (device *NVMLDevice) GetEncoderUtilization() (uint, uint, error) {
	var (
		utilization      C.uint
		samplingPeriodUs C.uint

	err := device.runner.callProc("nvmlDeviceGetEncoderUtilization",
	if err != nil {
		return 0, 0, errs.Wrap(err, "failed while calling procedure")

	return uint(utilization), uint(samplingPeriodUs), nil

// GetDecoderUtilization retrieves the decoder utilization statistics for the device.
// It returns the following values:
//   - utilization: the percentage of time over the past sampling period during which the decoder was active.
//   - samplingPeriodUs: the sampling period duration in microseconds, indicating how long the utilization
//     metric was measured.
func (device *NVMLDevice) GetDecoderUtilization() (uint, uint, error) {
	var (
		utilization      C.uint
		samplingPeriodUs C.uint

	err := device.runner.callProc("nvmlDeviceGetDecoderUtilization",
	if err != nil {
		return 0, 0, errs.Wrap(err, "failed while calling procedure")

	return uint(utilization), uint(samplingPeriodUs), nil

// GetMemoryErrorCounter retrieves the ECC memory error count for the specified error type,
// memory location, and counter type.
func (device *NVMLDevice) GetMemoryErrorCounter(
	errorType MemoryErrorType,
	memoryLocation MemoryLocation,
	counterType EccCounterType) (uint64, error) {
	var errorCount C.ulonglong

	err := device.runner.callProc("nvmlDeviceGetMemoryErrorCounter",
	if err != nil {
		return 0, errs.Wrap(err, "failed while calling procedure")

	return uint64(errorCount), nil

// GetTotalEnergyConsumption retrieves the total energy consumption of the NVIDIA device in millijoules.
func (device *NVMLDevice) GetTotalEnergyConsumption() (uint64, error) {
	var energy C.ulonglong

	err := device.runner.callProc("nvmlDeviceGetTotalEnergyConsumption",
	if err != nil {
		return 0, errs.Wrap(err, "failed while calling procedure")

	return uint64(energy), nil

// GetPerformanceState retrieves the performance state (P-state) of the NVIDIA device.
func (device *NVMLDevice) GetPerformanceState() (uint, error) {
	var perfState C.uint

	err := device.runner.callProc("nvmlDeviceGetPerformanceState",
	if err != nil {
		return 0, errs.Wrap(err, "failed while calling procedure")

	return uint(perfState), nil

// GetClockInfo retrieves the clock rate for the specified clock type of the NVIDIA device.
func (device *NVMLDevice) GetClockInfo(clockType ClockType) (uint, error) {
	var clockRate C.uint

	err := device.runner.callProc("nvmlDeviceGetClockInfo",
	if err != nil {
		return 0, errs.Wrap(err, "failed while calling procedure")

	return uint(clockRate), nil

// GetPowerUsage retrieves the power usage of the NVIDIA device in milliwatts.
func (device *NVMLDevice) GetPowerUsage() (uint, error) {
	var power C.uint

	err := device.runner.callProc("nvmlDeviceGetPowerUsage",
	if err != nil {
		return 0, errs.Wrap(err, "failed while calling procedure")

	return uint(power), nil

// GetEncoderStats retrieves statistics related to the encoder activity on the device.
// It returns the following statistics:
//   - sessionCount: the number of active encoder sessions.
//   - averageFps: the average frames per second across all active encoder sessions.
//   - averageLatency: the average latency (in milliseconds) across all active encoder sessions.
func (device *NVMLDevice) GetEncoderStats() (uint, uint, uint, error) {
	var (
		sessionCount   C.uint
		averageFps     C.uint
		averageLatency C.uint

	err := device.runner.callProc("nvmlDeviceGetEncoderStats",
	if err != nil {
		return 0, 0, 0, errs.Wrap(err, "failed while calling procedure")

	return uint(sessionCount), uint(averageFps), uint(averageLatency), nil

// GetPowerManagementLimit retrieves the power management limit of the NVIDIA device in milliwatts.
func (device *NVMLDevice) GetPowerManagementLimit() (uint, error) {
	var powerLimit C.uint

	err := device.runner.callProc("nvmlDeviceGetPowerManagementLimit",
	if err != nil {
		return 0, errs.Wrap(err, "failed while calling procedure")

	return uint(powerLimit), nil

// GetEccMode retrieves the current and pending ECC (Error Correction Code) modes for the device.
// ECC mode indicates whether error correction is enabled or disabled on the device.
// Returns:
//   - `currentEnabled` (bool): `true` if ECC is currently enabled, `false` if disabled.
//   - `pendingEnabled` (bool): `true` if ECC will be enabled on the next reboot, `false` if it will be disabled.
//   - `error` (error): An error if the function fails to retrieve the ECC mode, otherwise `nil`.
func (device *NVMLDevice) GetEccMode() (bool, bool, error) {
	var (
		currentMode C.uint
		pendingMode C.uint

	err := device.runner.callProc("nvmlDeviceGetEccMode",
	if err != nil {
		return false, false, errs.Wrap(err, "failed while calling procedure")

	currentEnabled := currentMode == C.NVML_FEATURE_ENABLED
	pendingEnabled := pendingMode == C.NVML_FEATURE_ENABLED

	return currentEnabled, pendingEnabled, nil

// GetUtilizationRates retrieves the GPU and memory utilization rates for the device.
// The GPU utilization represents the percentage of time over the past sampling period
// that the GPU was actively processing, while the memory utilization indicates the
// percentage of time the memory was being accessed.
// Returns:
//   - gpuUtilization (uint): The GPU utilization rate as a percentage (0-100).
//   - memoryUtilization (uint): The memory utilization rate as a percentage (0-100).
//   - error: An error object if there is a failure in verifying the NVML symbol existence
//     or in retrieving the utilization rates from NVML.
func (device *NVMLDevice) GetUtilizationRates() (uint, uint, error) {
	var util C.nvmlUtilization_t

	err := device.runner.callProc("nvmlDeviceGetUtilizationRates",
	if err != nil {
		return 0, 0, errs.Wrap(err, "failed while calling procedure")

	return uint(util.gpu), uint(util.memory), nil

// getProc retrieves the specified procedure from the Windows DLL.
// If the procedure is already cached in procList, it returns the cached
// *windows.Proc. Otherwise, it finds the procedure using dll.FindProc, caches it,
// and returns it.
func (runner *NVMLRunner) getProc(procName string) (*windows.Proc, error) {
	defer runner.procListMux.Unlock()

	proc, ok := runner.procList[procName]
	if ok {
		return proc, nil

	proc, err := runner.dll.FindProc(procName)
	if err != nil {
		return nil, errs.Wrapf(ErrFunctionNotFound, "failed to get procedure %q", procName)

	runner.procList[procName] = proc

	return proc, nil

func (runner *NVMLRunner) callProc(procName string, args ...uintptr) error {
	proc, err := runner.getProc(procName)
	if err != nil {
		return errs.Wrap(err, "failed getting procedure")

	result, _, callErr := proc.Call(args...)
	err = checkCallError(callErr)
	if err != nil {
		return errs.Wrap(err, "failed making syscall")

	err = mapNVMLResultToError(int(result))
	if err != nil {
		return errs.Wrap(err, "NVML returned error")

	return nil

// checkCallError checks for and interprets errors returned from system calls.
// Parameters:
// - callErr: The error returned from a system call, which may be of type syscall.Errno.
func checkCallError(callErr error) error {
	if callErr == nil {
		return nil

	var errno syscall.Errno
	if errors.As(callErr, &errno) {
		if errno != windows.ERROR_SUCCESS {
			return errs.Errorf("failed with error code %d: %v", errno, callErr)

		return nil

	return callErr