/* ** Zabbix ** Copyright (C) 2001-2024 Zabbix SIA ** ** Licensed under the Apache License, Version 2.0 (the "License"); ** you may not use this file except in compliance with the License. ** You may obtain a copy of the License at ** ** http://www.apache.org/licenses/LICENSE-2.0 ** ** Unless required by applicable law or agreed to in writing, software ** distributed under the License is distributed on an "AS IS" BASIS, ** WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ** See the License for the specific language governing permissions and ** limitations under the License. **/ package nvmlmock import ( "testing" "github.com/google/go-cmp/cmp" "golang.zabbix.com/plugin/nvidia/pkg/nvml" ) var ( _ nvml.Runner = (*MockRunner)(nil) _ Mocker = (*MockRunner)(nil) ) // Mocker any mock should implement to collect information if all its and submock calls are done. type Mocker interface { ExpectedCallsDone() bool SubMocks() []Mocker } // MockRunner is mock for NVML Runner. type MockRunner struct { nvml.Runner expectations []*Expectation callIdx int t *testing.T } // NewMockRunner creates new mock runner. func NewMockRunner(t *testing.T) *MockRunner { t.Helper() return &MockRunner{ t: t, expectations: []*Expectation{}, } } // ExpectCalls sets calls that are expected by mock. func (m *MockRunner) ExpectCalls(expectations ...*Expectation) *MockRunner { m.expectations = expectations return m } // SubMocks returns submocks of the mock. func (m *MockRunner) SubMocks() []Mocker { var subMocks []Mocker for _, expect := range m.expectations { for _, out := range expect.out { subMock, ok := out.(Mocker) if !ok { continue } subMocks = append(subMocks, subMock) subMocks = append(subMocks, subMock.SubMocks()...) } } return subMocks } // ExpectedCallsDone checks if all expected calls of mock and it's submocks are done. func (m *MockRunner) ExpectedCallsDone() bool { done := true expected := len(m.expectations) received := m.callIdx if expected > received { m.t.Errorf("received %d out of %d expected calls", received, expected) for _, e := range m.expectations[received:] { m.t.Errorf("Not called %q", e.funcName) } done = false } for _, sub := range m.SubMocks() { if !sub.ExpectedCallsDone() { done = false } } return done } // Init is mock function. func (m *MockRunner) Init() error { m.t.Helper() _, err := m.handleFunctionCall("Init") return err } // InitV2 is mock function. func (m *MockRunner) InitV2() error { m.t.Helper() _, err := m.handleFunctionCall("InitV2") return err } // ShutdownNVML is mock function. func (m *MockRunner) ShutdownNVML() error { m.t.Helper() _, err := m.handleFunctionCall("ShutdownNVML") return err } // GetDriverVersion is mock function. func (m *MockRunner) GetDriverVersion() (string, error) { m.t.Helper() res, err := m.handleFunctionCall("GetDriverVersion") // Type assertion to ensure res.resultArgs[0] is a string version, ok := res.out[0].(string) if !ok { m.t.Fatalf("expected string in GetDriverVersion, got %T", res.out[0]) } return version, err } // GetNVMLVersion is mock function. func (m *MockRunner) GetNVMLVersion() (string, error) { m.t.Helper() res, err := m.handleFunctionCall("GetNVMLVersion") version, ok := res.out[0].(string) if !ok { m.t.Fatalf("expected string in GetNVMLVersion, got %T", res.out[0]) } return version, err } // GetDeviceCountV2 is mock function. func (m *MockRunner) GetDeviceCountV2() (uint, error) { m.t.Helper() res, err := m.handleFunctionCall("GetDeviceCountV2") count, ok := res.out[0].(uint) if !ok { m.t.Fatalf("expected uint in GetDeviceCountV2, got %T", res.out[0]) } return count, err } // GetDeviceByIndexV2 is mock function. // //nolint:ireturn,nolintlint func (m *MockRunner) GetDeviceByIndexV2(index uint) (nvml.Device, error) { m.t.Helper() res, err := m.handleFunctionCall("GetDeviceByIndexV2", index) if res.out[0] == nil { return nil, err } device, ok := res.out[0].(*MockDevice) if !ok { m.t.Fatalf("expected *MockRunner in GetDeviceByIndexV2, got %T", res.out[0]) } device.t = m.t return device, err } // GetDeviceByUUID is mock function. // //nolint:ireturn,nolintlint func (m *MockRunner) GetDeviceByUUID(uuid string) (nvml.Device, error) { m.t.Helper() res, err := m.handleFunctionCall("GetDeviceByUUID", uuid) if res.out[0] == nil { return nil, err } mockDevice, ok := res.out[0].(*MockDevice) if ok { return mockDevice, err } device, ok := res.out[0].(nvml.Device) if !ok { m.t.Fatalf("expected %T in GetDeviceByUUID, got %T", device, res.out[0]) } return device, err } // handleFunctionCall is handler for mock function calls. // Takes in function name and any number of arguments function received. func (m *MockRunner) handleFunctionCall(name string, receivedArgs ...any) (*Expectation, error) { m.t.Helper() if m.callIdx >= len(m.expectations) { m.t.Fatalf("no more calls expected but got call for %q", name) } expect := m.expectations[m.callIdx] m.callIdx++ if expect.funcName != name { m.t.Fatalf("got call for %q while expected for %q", name, expect.funcName) } if receivedArgs == nil { receivedArgs = []any{} } // Compare expectedArgs and receivedArgs using cmp if diff := cmp.Diff(expect.args, receivedArgs); diff != "" { m.t.Fatalf("arguments mismatch in %s call %d:\nexpected: %v\nreceived: %v\ndiff: %s", name, m.callIdx, expect.args, receivedArgs, diff, ) } return expect, expect.err }