/*
** Copyright (C) 2001-2024 Zabbix SIA
**
** Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated
** documentation files (the "Software"), to deal in the Software without restriction, including without limitation the
** rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to
** permit persons to whom the Software is furnished to do so, subject to the following conditions:
**
** The above copyright notice and this permission notice shall be included in all copies or substantial portions
** of the Software.
**
** THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE
** WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
** COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
** TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
** SOFTWARE.
**/

package handlers

import (
	"bytes"
	"context"
	"errors"
	"io"
	"net/http"
	"testing"

	"github.com/google/go-cmp/cmp"
	"golang.zabbix.com/plugin/example/plugin/params"
	"golang.zabbix.com/sdk/errs"
)

// Compile-time check that sysCallsMock can be used wherever a systemCalls
// value is expected; the build fails here if the mock falls out of sync.
var _ systemCalls = sysCallsMock{}

// errReader is a reader stub used to simulate a body that cannot be read.
// Only its own Read method is used by the tests in this file; the
// bytes.Reader underlying type contributes no behavior.
type errReader bytes.Reader

// sysCallsMock is a stateless test double for systemCalls that serves a
// fixed set of environment variables instead of the process environment.
type sysCallsMock struct{}

// mockTransport is a canned HTTP transport for tests: its RoundTrip method
// answers every request with the preset response or error below.
type mockTransport struct {
	// resp is the response handed back when err is nil.
	resp *http.Response
	// err, when non-nil, is returned instead of resp.
	err error
}

// Read implements io.Reader. It never writes to p and always reports zero
// bytes read together with a "reader_fail" error.
func (errReader) Read(p []byte) (int, error) {
	failure := errors.New("reader_fail")

	return 0, failure
}

// environ returns the mock environment as three "key=value" entries, in the
// order foo, bar, abc. A fresh slice is built on every call.
func (sysCallsMock) environ() []string {
	env := make([]string, 0, 3)
	env = append(env, "foo=bar", "bar=foo", "abc=def")

	return env
}

// lookupEnv resolves key against the mock environment (foo=bar, bar=foo,
// abc=def). It returns the value and true for a known key, and an empty
// string and false for anything else.
func (sysCallsMock) lookupEnv(key string) (string, bool) {
	known := map[string]string{
		"foo": "bar",
		"bar": "foo",
		"abc": "def",
	}

	value, found := known[key]

	return value, found
}

// RoundTrip implements http.RoundTripper without touching the network. The
// request is ignored: when the transport was configured with an error, that
// error is returned with a nil response; otherwise the preset response is
// returned.
func (t *mockTransport) RoundTrip(*http.Request) (*http.Response, error) {
	if t.err == nil {
		return t.resp, nil
	}

	return nil, t.err
}

// TestHandler_MyIP exercises Handler.MyIP with an http.Client whose transport
// is a mockTransport, so no real network request is made.
//
// Cases:
//   - "+valid": the transport returns a body of "127.0.0.1" and the same
//     string is expected as the result.
//   - "-clientErr": the transport itself fails and an error is expected.
//   - "-bodyReadErr": the transport succeeds but hands back an errReader
//     body, and an error is expected (presumably raised when MyIP reads the
//     body; confirm against the handler).
func TestHandler_MyIP(t *testing.T) {
	t.Parallel()

	type args struct {
		// clientBody becomes the Body of the mocked response.
		clientBody io.ReadCloser
		// clientErr, when non-nil, is returned by the mocked transport.
		clientErr error
	}

	tests := []struct {
		name    string
		args    args
		want    any
		wantErr bool
	}{
		{
			"+valid",
			args{
				clientBody: io.NopCloser(
					bytes.NewReader(
						[]byte("127.0.0.1"),
					),
				),
				clientErr: nil,
			},
			"127.0.0.1",
			false,
		},
		{
			"-clientErr",
			args{
				clientErr: errs.New("fail"),
			},
			nil,
			true,
		},
		{
			"-bodyReadErr",
			args{
				clientBody: io.NopCloser(errReader{}),
				// clientErr must stay nil here: mockTransport.RoundTrip
				// returns early on a non-nil error, so the failing body
				// would never reach the handler and this case would only
				// repeat "-clientErr".
				clientErr: nil,
			},
			nil,
			true,
		},
	}
	for _, tt := range tests {
		tt := tt
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()

			c := &http.Client{
				Transport: &mockTransport{
					resp: &http.Response{
						Body: tt.args.clientBody,
					},
					err: tt.args.clientErr,
				},
			}

			h := &Handler{
				client: c,
			}

			got, err := h.MyIP(context.Background(), nil)
			if (err != nil) != tt.wantErr {
				t.Fatalf("Handler.MyIP() error = %v, wantErr %v", err, tt.wantErr)
			}

			if diff := cmp.Diff(tt.want, got); diff != "" {
				t.Fatalf("Handler.MyIP() = %s", diff)
			}
		})
	}
}

// TestHandler_GoEnvironment runs Handler.GoEnvironment with sysCallsMock as
// its system-call layer. It covers a single requested variable, several
// requested variables, no requested variables (all three mock variables are
// expected), and a variable the mock does not know (an error is expected).
func TestHandler_GoEnvironment(t *testing.T) {
	t.Parallel()

	type args struct {
		extraParams []string
	}

	cases := []struct {
		name    string
		args    args
		want    any
		wantErr bool
	}{
		{
			name: "+valid",
			args: args{extraParams: []string{"foo"}},
			want: map[string]string{"foo": "bar"},
		},
		{
			name: "+multiple",
			args: args{extraParams: []string{"foo", "abc"}},
			want: map[string]string{"foo": "bar", "abc": "def"},
		},
		{
			name: "+all",
			args: args{extraParams: nil},
			want: map[string]string{"abc": "def", "bar": "foo", "foo": "bar"},
		},
		{
			name:    "-notFound",
			args:    args{extraParams: []string{"test"}},
			want:    nil,
			wantErr: true,
		},
	}

	for _, tc := range cases {
		tc := tc
		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()

			handler := &Handler{sysCalls: sysCallsMock{}}

			got, err := handler.GoEnvironment(context.Background(), nil, tc.args.extraParams...)
			if gotErr := err != nil; gotErr != tc.wantErr {
				t.Fatalf("Handler.GoEnvironment() error = %v, wantErr %v", err, tc.wantErr)
			}

			diff := cmp.Diff(tc.want, got)
			if diff != "" {
				t.Fatalf("Handler.GoEnvironment() = %s", diff)
			}
		})
	}
}

// TestNew checks that New returns a Handler whose client is
// http.DefaultClient and whose sysCalls is an osWrapper.
func TestNew(t *testing.T) {
	t.Parallel()

	cases := []struct {
		name string
		want *Handler
	}{
		{
			name: "+valid",
			want: &Handler{
				client:   http.DefaultClient,
				sysCalls: osWrapper{},
			},
		},
	}

	for _, tc := range cases {
		tc := tc
		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()

			got := New()

			// Handler's fields are unexported, so cmp has to be told it may
			// look inside the struct.
			diff := cmp.Diff(tc.want, got, cmp.AllowUnexported(Handler{}))
			if diff != "" {
				t.Fatalf("New() = %s", diff)
			}
		})
	}
}

func TestWithJSONResponse(t *testing.T) {
	t.Parallel()

	type args struct {
		value  any
		gotErr bool
	}

	tests := []struct {
		name    string
		args    args
		want    any
		wantErr bool
	}{
		{
			"+valid",
			args{
				value: "foobar",
			},
			`"foobar"`,
			false,
		},
		{
			"+jsonObject",
			args{
				value: map[string]string{
					"foo":  "bar",
					"test": "true",
				},
			},
			`{"foo":"bar","test":"true"}`,
			false,
		},
		{
			"-jsonMarshalErr",
			args{
				value: map[struct{ test string }]string{
					{test: "1"}: "bar",
					{test: "2"}: "foo",
				},
			},
			nil,
			true,
		},
		{
			"-handlerErr",
			args{gotErr: true},
			nil,
			true,
		},
	}

	for _, tt := range tests {
		tt := tt
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()

			got, err := WithJSONResponse(
				func(ctx context.Context, metricParams map[string]string, extraParams ...string) (any, error) {
					if tt.args.gotErr {
						return nil, errs.New("fail")
					}

					return tt.args.value, nil
				},
			)(context.Background(), nil)

			if (err != nil) != tt.wantErr {
				t.Fatalf("WithJSONResponse() error = %v, wantErr %v", err, tt.wantErr)
			}

			if diff := cmp.Diff(tt.want, got); diff != "" {
				t.Fatalf("WithJSONResponse() = %s", diff)
			}
		})
	}
}

// TestWithCredentialValidation wraps a stub handler with
// WithCredentialValidation and calls it with different credential
// parameters. Username "Zabbix" without a password and username "Admin" with
// password "Foo" are expected to pass the stub's value through; username
// "FooBar" and username "Admin" with an empty password are expected to fail.
func TestWithCredentialValidation(t *testing.T) {
	t.Parallel()

	type args struct {
		// value is what the stub handler returns.
		value any
		// metricParams is passed to the wrapped handler.
		metricParams map[string]string
	}

	cases := []struct {
		name    string
		args    args
		want    any
		wantErr bool
	}{
		{
			name: "+valid",
			args: args{
				value: "foobar",
				metricParams: map[string]string{
					params.UsernameParameterName: "Zabbix",
				},
			},
			want: "foobar",
		},
		{
			name: "+definedPassword",
			args: args{
				value: "foobar",
				metricParams: map[string]string{
					params.UsernameParameterName: "Admin",
					params.PasswordParameterName: "Foo",
				},
			},
			want: "foobar",
		},
		{
			name: "-invalidUsername",
			args: args{
				value: "foobar",
				metricParams: map[string]string{
					params.UsernameParameterName: "FooBar",
				},
			},
			want:    nil,
			wantErr: true,
		},
		{
			name: "-invalidPassword",
			args: args{
				value: "foobar",
				metricParams: map[string]string{
					params.UsernameParameterName: "Admin",
					params.PasswordParameterName: "",
				},
			},
			want:    nil,
			wantErr: true,
		},
	}

	for _, tc := range cases {
		tc := tc
		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()

			// inner always succeeds, so any error comes from the wrapper.
			inner := func(ctx context.Context, metricParams map[string]string, extraParams ...string) (any, error) {
				return tc.args.value, nil
			}

			got, err := WithCredentialValidation(inner)(context.Background(), tc.args.metricParams)
			if gotErr := err != nil; gotErr != tc.wantErr {
				t.Fatalf("WithCredentialValidation() error = %v, wantErr %v", err, tc.wantErr)
			}

			diff := cmp.Diff(tc.want, got)
			if diff != "" {
				t.Fatalf("WithCredentialValidation() = %s", diff)
			}
		})
	}
}

// Test_getAll feeds getAll lists of environment entries. "key=value" entries
// are expected to come back as a map, while an entry without "=" ("foo:bar")
// is expected to yield an error and a nil map.
func Test_getAll(t *testing.T) {
	t.Parallel()

	type args struct {
		env []string
	}

	cases := []struct {
		name    string
		args    args
		want    map[string]string
		wantErr bool
	}{
		{
			name: "+valid",
			args: args{env: []string{"foo=bar", "abc=def"}},
			want: map[string]string{"foo": "bar", "abc": "def"},
		},
		{
			name: "+single",
			args: args{env: []string{"foo=bar"}},
			want: map[string]string{"foo": "bar"},
		},
		{
			name:    "-invalidVar",
			args:    args{env: []string{"foo:bar"}},
			want:    nil,
			wantErr: true,
		},
	}

	for _, tc := range cases {
		tc := tc
		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()

			got, err := getAll(tc.args.env)
			if gotErr := err != nil; gotErr != tc.wantErr {
				t.Fatalf("getAll() error = %v, wantErr %v", err, tc.wantErr)
			}

			diff := cmp.Diff(tc.want, got)
			if diff != "" {
				t.Fatalf("getAll() = %s", diff)
			}
		})
	}
}