Fix: file preview URL in share page should not be accessed directly
This commit is contained in:
parent
79f898e0a9
commit
32c0232105
3 changed files with 65 additions and 1 deletions
|
@ -2,6 +2,7 @@ package middleware
|
|||
|
||||
import (
|
||||
"github.com/HFO4/cloudreve/pkg/conf"
|
||||
"github.com/HFO4/cloudreve/pkg/serializer"
|
||||
"github.com/HFO4/cloudreve/pkg/util"
|
||||
"github.com/gin-contrib/sessions"
|
||||
"github.com/gin-contrib/sessions/memstore"
|
||||
|
@ -32,3 +33,24 @@ func Session(secret string) gin.HandlerFunc {
|
|||
Store.Options(sessions.Options{HttpOnly: true, MaxAge: 7 * 86400, Path: "/"})
|
||||
return sessions.Sessions("cloudreve-session", Store)
|
||||
}
|
||||
|
||||
// CSRFInit 初始化CSRF标记
|
||||
func CSRFInit() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
util.SetSession(c, map[string]interface{}{"CSRF": true})
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// CSRFCheck 检查CSRF标记
|
||||
func CSRFCheck() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
if check, ok := util.GetSession(c, "CSRF").(bool); ok && check {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(200, serializer.Err(serializer.CodeNoPermissionErr, "来源非法", nil))
|
||||
c.Abort()
|
||||
}
|
||||
}
|
||||
|
|
|
@ -2,8 +2,11 @@ package middleware
|
|||
|
||||
import (
|
||||
"github.com/HFO4/cloudreve/pkg/conf"
|
||||
"github.com/HFO4/cloudreve/pkg/util"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
)
|
||||
|
||||
|
@ -28,3 +31,41 @@ func TestSession(t *testing.T) {
|
|||
func emptyFunc() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {}
|
||||
}
|
||||
|
||||
func TestCSRFInit(t *testing.T) {
|
||||
asserts := assert.New(t)
|
||||
rec := httptest.NewRecorder()
|
||||
sessionFunc := Session("233")
|
||||
{
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request, _ = http.NewRequest("GET", "/test", nil)
|
||||
sessionFunc(c)
|
||||
CSRFInit()(c)
|
||||
asserts.True(util.GetSession(c, "CSRF").(bool))
|
||||
}
|
||||
}
|
||||
|
||||
func TestCSRFCheck(t *testing.T) {
|
||||
asserts := assert.New(t)
|
||||
rec := httptest.NewRecorder()
|
||||
sessionFunc := Session("233")
|
||||
|
||||
// 通过检查
|
||||
{
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request, _ = http.NewRequest("GET", "/test", nil)
|
||||
sessionFunc(c)
|
||||
CSRFInit()(c)
|
||||
CSRFCheck()(c)
|
||||
asserts.False(c.IsAborted())
|
||||
}
|
||||
|
||||
// 未通过检查
|
||||
{
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request, _ = http.NewRequest("GET", "/test", nil)
|
||||
sessionFunc(c)
|
||||
CSRFCheck()(c)
|
||||
asserts.True(c.IsAborted())
|
||||
}
|
||||
}
|
||||
|
|
|
@ -107,7 +107,7 @@ func InitMasterRouter() *gin.Engine {
|
|||
// 验证码
|
||||
site.GET("captcha", controllers.Captcha)
|
||||
// 站点全局配置
|
||||
site.GET("config", controllers.SiteConfig)
|
||||
site.GET("config", middleware.CSRFInit(), controllers.SiteConfig)
|
||||
}
|
||||
|
||||
// 用户相关路由
|
||||
|
@ -231,6 +231,7 @@ func InitMasterRouter() *gin.Engine {
|
|||
)
|
||||
// 预览分享文件
|
||||
share.GET("preview/:id",
|
||||
middleware.CSRFCheck(),
|
||||
middleware.CheckShareUnlocked(),
|
||||
middleware.ShareCanPreview(),
|
||||
middleware.BeforeShareDownload(),
|
||||
|
|
Loading…
Add table
Reference in a new issue