package errors

import (
	"bytes"
	"errors"
	"fmt"
	"log"
	"net/http"
	"net/http/httptest"
	"os"
	"path/filepath"
	"strconv"
	"strings"
	"testing"

	"github.com/mholt/caddy/caddyhttp/httpserver"
)

// TestErrors drives ErrorHandler.ServeHTTP with a table of downstream
// handlers and checks, for each one, the returned status code and error, the
// body written to the response recorder, and the text written to the
// handler's logger.
//
// The handler under test is configured with a custom page for 404 (a real
// temporary file) and a custom page for 403 that points at a file which does
// not exist.
func TestErrors(t *testing.T) {
	// Custom error page written to disk for the duration of the test.
	const pageBody = "This is a error page"

	pagePath, err := createErrorPageFile("errors_test.html", pageBody)
	if err != nil {
		t.Fatal(err)
	}
	defer os.Remove(pagePath)

	var logBuf bytes.Buffer
	handler := ErrorHandler{
		ErrorPages: map[int]string{
			http.StatusNotFound:  pagePath,
			http.StatusForbidden: "not_exist_file",
		},
		Log: log.New(&logBuf, "", 0),
	}

	// Opening the missing page ourselves yields the error value whose text
	// the 403 case expects to find in the log.
	_, missingPageErr := os.Open("not_exist_file")

	downstreamErr := errors.New("test error")

	// defaultBody renders the "<code> <status text>\n" body that the cases
	// without a usable custom page expect.
	defaultBody := func(status int) string {
		return fmt.Sprintf("%d %s\n", status, http.StatusText(status))
	}

	redirectLog := fmt.Sprintf("[ERROR %d %s] %v\n",
		http.StatusMovedPermanently, "/", downstreamErr)
	forbiddenLog := fmt.Sprintf("[NOTICE %d /] could not load error page: %v\n",
		http.StatusForbidden, missingPageErr)

	type testCase struct {
		next         httpserver.Handler
		expectedCode int
		expectedBody string
		expectedLog  string
		expectedErr  error
	}

	// Fields left out of a case keep their zero value: empty body, empty
	// expected log, nil expected error.
	cases := []testCase{
		{
			next:         genErrorHandler(http.StatusOK, nil, "normal"),
			expectedCode: http.StatusOK,
			expectedBody: "normal",
		},
		{
			next:         genErrorHandler(http.StatusMovedPermanently, downstreamErr, ""),
			expectedCode: http.StatusMovedPermanently,
			expectedLog:  redirectLog,
			expectedErr:  downstreamErr,
		},
		{
			next:         genErrorHandler(http.StatusBadRequest, nil, ""),
			expectedCode: 0,
			expectedBody: defaultBody(http.StatusBadRequest),
		},
		{
			next:         genErrorHandler(http.StatusNotFound, nil, ""),
			expectedCode: 0,
			expectedBody: pageBody,
		},
		{
			next:         genErrorHandler(http.StatusForbidden, nil, ""),
			expectedCode: 0,
			expectedBody: defaultBody(http.StatusForbidden),
			expectedLog:  forbiddenLog,
		},
	}

	req, err := http.NewRequest("GET", "/", nil)
	if err != nil {
		t.Fatal(err)
	}

	for i, tc := range cases {
		// Each case gets a fresh recorder and an emptied log buffer so that
		// output from earlier cases cannot leak into its assertions.
		handler.Next = tc.next
		logBuf.Reset()
		rec := httptest.NewRecorder()

		gotCode, gotErr := handler.ServeHTTP(rec, req)

		if gotErr != tc.expectedErr {
			t.Errorf("Test %d: Expected error %v, but got %v",
				i, tc.expectedErr, gotErr)
		}
		if gotCode != tc.expectedCode {
			t.Errorf("Test %d: Expected status code %d, but got %d",
				i, tc.expectedCode, gotCode)
		}

		gotBody := rec.Body.String()
		if gotBody != tc.expectedBody {
			t.Errorf("Test %d: Expected body %q, but got %q",
				i, tc.expectedBody, gotBody)
		}

		// Substring match: the log line only has to contain the expected text.
		logged := logBuf.String()
		if !strings.Contains(logged, tc.expectedLog) {
			t.Errorf("Test %d: Expected log %q, but got %q",
				i, tc.expectedLog, logged)
		}
	}
}

// TestVisibleErrorWithPanic checks how an ErrorHandler with Debug enabled
// reports a panic raised by the next handler: ServeHTTP is expected to return
// (0, nil), and the response body is expected to contain a "[PANIC /]" line
// naming this test file, the panic message itself, and enough additional
// text (at least 500 bytes in total) to plausibly include a stack trace.
func TestVisibleErrorWithPanic(t *testing.T) {
	const panicMsg = "I'm a panic"

	req, err := http.NewRequest("GET", "/", nil)
	if err != nil {
		t.Fatal(err)
	}

	handler := ErrorHandler{
		ErrorPages: make(map[int]string),
		Debug:      true,
		Next: httpserver.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) (int, error) {
			// The downstream handler does nothing but blow up.
			panic(panicMsg)
		}),
	}

	rec := httptest.NewRecorder()
	status, serveErr := handler.ServeHTTP(rec, req)

	if status != 0 {
		t.Errorf("Expected error handler to return 0 (it should write to response), got status %d", status)
	}
	if serveErr != nil {
		t.Errorf("Expected error handler to return nil error (it should panic!), but got '%v'", serveErr)
	}

	got := rec.Body.String()

	if !strings.Contains(got, "[PANIC /] caddyhttp/errors/errors_test.go") {
		t.Errorf("Expected response body to contain error log line, but it didn't:\n%s", got)
	}
	if !strings.Contains(got, panicMsg) {
		t.Errorf("Expected response body to contain panic message, but it didn't:\n%s", got)
	}
	// Length is used as a rough proxy for "a stack trace was included".
	if size := len(got); size < 500 {
		t.Errorf("Expected response body to contain stack trace, but it was too short: len=%d", size)
	}
}

// TestGenericErrorPage checks which page an ErrorHandler serves when it has
// both a GenericErrorPage and a status-specific page configured: a 404 from
// the next handler is expected to yield the 404 page's content, while a 500
// (which has no dedicated page) is expected to yield the generic page's
// content. In both cases ServeHTTP is expected to return (0, nil).
func TestGenericErrorPage(t *testing.T) {
	const (
		genericBody  = "This is a generic error page"
		notFoundBody = "This is a error page"
	)

	// Temporary generic error page.
	genericPath, err := createErrorPageFile("generic_error_test.html", genericBody)
	if err != nil {
		t.Fatal(err)
	}
	defer os.Remove(genericPath)

	// Temporary page dedicated to 404 responses.
	notFoundPath, err := createErrorPageFile("not_found.html", notFoundBody)
	if err != nil {
		t.Fatal(err)
	}
	defer os.Remove(notFoundPath)

	var logBuf bytes.Buffer
	handler := ErrorHandler{
		GenericErrorPage: genericPath,
		ErrorPages: map[int]string{
			http.StatusNotFound: notFoundPath,
		},
		Log: log.New(&logBuf, "", 0),
	}

	type testCase struct {
		next         httpserver.Handler
		expectedCode int
		expectedBody string
		expectedLog  string
		expectedErr  error
	}

	// Fields left out of a case keep their zero value: empty expected log and
	// nil expected error.
	cases := []testCase{
		{
			next:         genErrorHandler(http.StatusNotFound, nil, ""),
			expectedCode: 0,
			expectedBody: notFoundBody,
		},
		{
			next:         genErrorHandler(http.StatusInternalServerError, nil, ""),
			expectedCode: 0,
			expectedBody: genericBody,
		},
	}

	req, err := http.NewRequest("GET", "/", nil)
	if err != nil {
		t.Fatal(err)
	}

	for i, tc := range cases {
		// Fresh recorder and emptied log buffer per case.
		handler.Next = tc.next
		logBuf.Reset()
		rec := httptest.NewRecorder()

		gotCode, gotErr := handler.ServeHTTP(rec, req)

		if gotErr != tc.expectedErr {
			t.Errorf("Test %d: Expected error %v, but got %v",
				i, tc.expectedErr, gotErr)
		}
		if gotCode != tc.expectedCode {
			t.Errorf("Test %d: Expected status code %d, but got %d",
				i, tc.expectedCode, gotCode)
		}

		gotBody := rec.Body.String()
		if gotBody != tc.expectedBody {
			t.Errorf("Test %d: Expected body %q, but got %q",
				i, tc.expectedBody, gotBody)
		}

		// Substring match: the log only has to contain the expected text.
		logged := logBuf.String()
		if !strings.Contains(logged, tc.expectedLog) {
			t.Errorf("Test %d: Expected log %q, but got %q",
				i, tc.expectedLog, logged)
		}
	}
}

// genErrorHandler builds a stub httpserver.Handler for use as the "next"
// handler in tests. When invoked, the stub returns status and err unchanged.
// If body is non-empty it first sets a Content-Length header equal to the
// byte length of body and writes body to the response; with an empty body
// nothing is written and no header is set.
func genErrorHandler(status int, err error, body string) httpserver.Handler {
	stub := func(w http.ResponseWriter, _ *http.Request) (int, error) {
		if body == "" {
			return status, err
		}
		w.Header().Set("Content-Length", strconv.Itoa(len(body)))
		fmt.Fprint(w, body)
		return status, err
	}
	return httpserver.HandlerFunc(stub)
}

// createErrorPageFile writes content to a file called name inside the
// directory reported by os.TempDir and returns the full path of that file.
// The file is opened with os.Create, so an existing file of the same name is
// truncated. Removing the file afterwards is the caller's responsibility.
//
// The file handle is closed on every path. If creating, writing or closing
// the file fails, the returned path is empty and the error is returned.
func createErrorPageFile(name string, content string) (string, error) {
	errorPageFilePath := filepath.Join(os.TempDir(), name)
	f, err := os.Create(errorPageFilePath)
	if err != nil {
		return "", err
	}

	if _, err := f.WriteString(content); err != nil {
		// Release the descriptor even though the write failed; the write
		// error is the one worth reporting, so a Close error is ignored here.
		f.Close()
		return "", err
	}

	// Close may report an error of its own; do not claim success if it does.
	if err := f.Close(); err != nil {
		return "", err
	}

	return errorPageFilePath, nil
}