2019-11-12 15:34:54 +08:00
|
|
|
package middleware
|
|
|
|
|
|
|
|
import (
|
2019-12-04 13:49:28 +08:00
|
|
|
"github.com/HFO4/cloudreve/pkg/conf"
|
2020-03-17 15:57:38 +08:00
|
|
|
"github.com/HFO4/cloudreve/pkg/util"
|
2019-11-12 15:34:54 +08:00
|
|
|
"github.com/gin-gonic/gin"
|
|
|
|
"github.com/stretchr/testify/assert"
|
2020-03-17 15:57:38 +08:00
|
|
|
"net/http"
|
|
|
|
"net/http/httptest"
|
2019-11-12 15:34:54 +08:00
|
|
|
"testing"
|
|
|
|
)
|
|
|
|
|
|
|
|
func TestSession(t *testing.T) {
|
|
|
|
asserts := assert.New(t)
|
|
|
|
|
2019-12-04 13:49:28 +08:00
|
|
|
{
|
|
|
|
handler := Session("2333")
|
|
|
|
asserts.NotNil(handler)
|
|
|
|
asserts.NotNil(Store)
|
|
|
|
asserts.IsType(emptyFunc(), handler)
|
|
|
|
}
|
|
|
|
{
|
|
|
|
conf.RedisConfig.Server = "123"
|
|
|
|
asserts.Panics(func() {
|
|
|
|
Session("2333")
|
|
|
|
})
|
2020-03-18 09:47:06 +08:00
|
|
|
conf.RedisConfig.Server = ""
|
2019-12-04 13:49:28 +08:00
|
|
|
}
|
|
|
|
|
2019-11-12 15:34:54 +08:00
|
|
|
}
|
|
|
|
|
|
|
|
func emptyFunc() gin.HandlerFunc {
|
|
|
|
return func(c *gin.Context) {}
|
|
|
|
}
|
2020-03-17 15:57:38 +08:00
|
|
|
|
|
|
|
// TestCSRFInit checks that, after the session middleware and CSRFInit have run
// for a request, the session holds the key "CSRF" with the value true.
func TestCSRFInit(t *testing.T) {
	asserts := assert.New(t)
	rec := httptest.NewRecorder()
	sessionFunc := Session("233")

	{
		c, _ := gin.CreateTestContext(rec)
		// httptest.NewRequest panics on a malformed request instead of
		// returning an error that would otherwise have to be checked.
		c.Request = httptest.NewRequest(http.MethodGet, "/test", nil)

		// Apply the session middleware first, as the router would.
		sessionFunc(c)
		CSRFInit()(c)

		// Compare rather than type-assert: a missing or non-bool value now
		// fails the assertion instead of panicking the test binary.
		asserts.Equal(true, util.GetSession(c, "CSRF"))
	}
}
|
|
|
|
|
|
|
|
// TestCSRFCheck checks that CSRFCheck lets a request through when CSRFInit has
// run on its session, and aborts the request when it has not.
func TestCSRFCheck(t *testing.T) {
	asserts := assert.New(t)
	sessionFunc := Session("233")

	// newContext builds a gin context for GET /test with the session
	// middleware already applied. Every case gets its own recorder so the
	// response written by one case cannot bleed into the next.
	newContext := func() *gin.Context {
		c, _ := gin.CreateTestContext(httptest.NewRecorder())
		c.Request = httptest.NewRequest(http.MethodGet, "/test", nil)
		sessionFunc(c)
		return c
	}

	// Check passes: CSRFInit ran on this session before CSRFCheck.
	{
		c := newContext()
		CSRFInit()(c)
		CSRFCheck()(c)
		asserts.False(c.IsAborted())
	}

	// Check fails: the session was never initialised by CSRFInit.
	{
		c := newContext()
		CSRFCheck()(c)
		asserts.True(c.IsAborted())
	}
}
|