diff --git a/middleware/auth.go b/middleware/auth.go index 2f09cf3..82533a3 100644 --- a/middleware/auth.go +++ b/middleware/auth.go @@ -61,7 +61,6 @@ func AuthRequired() gin.HandlerFunc { } // WebDAVAuth 验证WebDAV登录及权限 -// TODO 测试 func WebDAVAuth() gin.HandlerFunc { return func(c *gin.Context) { // OPTIONS 请求不需要鉴权,否则Windows10下无法保存文档 diff --git a/pkg/auth/auth.go b/pkg/auth/auth.go index c5ac39b..1dc83b4 100644 --- a/pkg/auth/auth.go +++ b/pkg/auth/auth.go @@ -103,7 +103,7 @@ func Init() { if conf.SystemConfig.Mode == "master" { secretKey = model.GetSettingByName("secret_key") } else { - secretKey = conf.SystemConfig.SlaveSecret + secretKey = conf.SlaveConfig.Secret if secretKey == "" { util.Log().Panic("未指定 SlaveSecret,请前往配置文件中指定") } diff --git a/pkg/conf/conf.go b/pkg/conf/conf.go index 50ac8ce..9a77174 100644 --- a/pkg/conf/conf.go +++ b/pkg/conf/conf.go @@ -22,7 +22,13 @@ type system struct { Listen string `validate:"required"` Debug bool SessionSecret string - SlaveSecret string `validate:"omitempty,gte=64"` +} + +// slave 作为slave存储端配置 +type slave struct { + Secret string `validate:"omitempty,gte=64"` + CallbackTimeout int `validate:"omitempty,gte=1"` + SignatureTTL int `validate:"omitempty,gte=1"` } // captcha 验证码配置 @@ -82,6 +88,7 @@ func Init(path string) { "Redis": RedisConfig, "Thumbnail": ThumbConfig, "CORS": CORSConfig, + "Slave": SlaveConfig, } for sectionName, sectionStruct := range sections { err = mapSection(sectionName, sectionStruct) diff --git a/pkg/conf/defaults.go b/pkg/conf/defaults.go index bc27782..e4639ed 100644 --- a/pkg/conf/defaults.go +++ b/pkg/conf/defaults.go @@ -40,13 +40,20 @@ var CaptchaConfig = &captcha{ var CORSConfig = &cors{ AllowOrigins: []string{"UNSET"}, AllowMethods: []string{"PUT", "POST", "GET", "OPTIONS"}, - AllowHeaders: []string{"Cookie", "Content-Length", "Content-Type", "X-Path", "X-FileName"}, + AllowHeaders: []string{"Cookie", "X-Policy", "Authorization", "Content-Length", "Content-Type", "X-Path", "X-FileName"}, AllowCredentials: false, ExposeHeaders: nil, } +// ThumbConfig 缩略图配置 var ThumbConfig = &thumb{ MaxWidth: 400, MaxHeight: 300, FileSuffix: "._thumb", } + +// SlaveConfig 从机配置 +var SlaveConfig = &slave{ + CallbackTimeout: 20, + SignatureTTL: 60, +} diff --git a/pkg/filesystem/hooks.go b/pkg/filesystem/hooks.go index c0e9d53..e666471 100644 --- a/pkg/filesystem/hooks.go +++ b/pkg/filesystem/hooks.go @@ -6,6 +6,7 @@ import ( model "github.com/HFO4/cloudreve/models" "github.com/HFO4/cloudreve/pkg/conf" "github.com/HFO4/cloudreve/pkg/filesystem/fsctx" + "github.com/HFO4/cloudreve/pkg/request" "github.com/HFO4/cloudreve/pkg/serializer" "github.com/HFO4/cloudreve/pkg/util" "io/ioutil" @@ -212,6 +213,8 @@ func GenericAfterUpdate(ctx context.Context, fs *FileSystem) error { // TODO 测试 func SlaveAfterUpload(ctx context.Context, fs *FileSystem) error { fileHeader := ctx.Value(fsctx.FileHeaderCtx).(FileHeader) + policy := ctx.Value(fsctx.UploadPolicyCtx).(serializer.UploadPolicy) + // 构造一个model.File,用于生成缩略图 file := model.File{ Name: fileHeader.GetFileName(), @@ -219,9 +222,13 @@ func SlaveAfterUpload(ctx context.Context, fs *FileSystem) error { } fs.GenerateThumbnail(ctx, &file) - // TODO 发送回调请求 - - return nil + // 发送回调请求 + callbackBody := serializer.UploadCallback{ + Name: file.Name, + SourceName: file.SourceName, + PicInfo: file.PicInfo, + } + return request.RemoteCallback(policy.CallbackURL, callbackBody) } // GenericAfterUpload 文件上传完成后,包含数据库操作 diff --git a/pkg/request/callback.go b/pkg/request/callback.go new file mode 100644 index 0000000..24ac229 --- /dev/null +++ b/pkg/request/callback.go @@ -0,0 +1,59 @@ +package request + +import ( + "bytes" + "encoding/json" + "errors" + "github.com/HFO4/cloudreve/pkg/auth" + "github.com/HFO4/cloudreve/pkg/conf" + "github.com/HFO4/cloudreve/pkg/serializer" + "github.com/HFO4/cloudreve/pkg/util" + "io/ioutil" + "time" +) + +// RemoteCallback 发送远程存储策略上传回调请求 +func RemoteCallback(url string, body serializer.UploadCallback) error { + callbackBody, err := json.Marshal(body) + if err != nil { + return serializer.NewError(serializer.CodeCallbackError, "无法编码回调正文", err) + } + + resp := generalClient.Request( + "POST", + url, + bytes.NewReader(callbackBody), + WithTimeout(time.Duration(conf.SlaveConfig.CallbackTimeout)*time.Second), + WithCredential(auth.General, int64(conf.SlaveConfig.SignatureTTL)), + ) + + if resp.Err != nil { + return serializer.NewError(serializer.CodeCallbackError, "无法发起回调请求", resp.Err) + } + + // 检查返回HTTP状态码 + if resp.Response.StatusCode != 200 { + util.Log().Debug("服务端返回非正常状态码:%d", resp.Response.StatusCode) + return serializer.NewError(serializer.CodeCallbackError, "服务端返回非正常状态码", nil) + } + + // 检查返回API状态码 + var response serializer.Response + rawResp, err := ioutil.ReadAll(resp.Response.Body) + if err != nil { + return serializer.NewError(serializer.CodeCallbackError, "无法读取响应正文", err) + } + + // 解析回调服务端响应 + err = json.Unmarshal(rawResp, &response) + if err != nil { + util.Log().Debug("无法解析回调服务端响应:%s", string(rawResp)) + return serializer.NewError(serializer.CodeCallbackError, "无法解析服务端返回的响应", err) + } + + if response.Code != 0 { + return serializer.NewError(response.Code, response.Msg, errors.New(response.Error)) + } + + return nil +} diff --git a/pkg/request/callback_test.go b/pkg/request/callback_test.go new file mode 100644 index 0000000..b8e527c --- /dev/null +++ b/pkg/request/callback_test.go @@ -0,0 +1,136 @@ +package request + +import ( + "bytes" + "encoding/json" + "errors" + "github.com/HFO4/cloudreve/pkg/serializer" + "github.com/stretchr/testify/assert" + testMock "github.com/stretchr/testify/mock" + "io/ioutil" + "net/http" + "strings" + "testing" +) + +func TestRemoteCallback(t *testing.T) { + asserts := assert.New(t) + + // 回调成功 + { + clientMock := ClientMock{} + mockResp, _ := json.Marshal(serializer.Response{Code: 0}) + clientMock.On( + "Request", + "POST", + "http://test/test/url", + testMock.Anything, + testMock.Anything, + ).Return(Response{ + Err: nil, + Response: &http.Response{ + StatusCode: 200, + Body: ioutil.NopCloser(bytes.NewReader(mockResp)), + }, + }) + generalClient = clientMock + resp := RemoteCallback("http://test/test/url", serializer.UploadCallback{ + SourceName: "source", + }) + asserts.NoError(resp) + clientMock.AssertExpectations(t) + } + + // 服务端返回业务错误 + { + clientMock := ClientMock{} + mockResp, _ := json.Marshal(serializer.Response{Code: 401}) + clientMock.On( + "Request", + "POST", + "http://test/test/url", + testMock.Anything, + testMock.Anything, + ).Return(Response{ + Err: nil, + Response: &http.Response{ + StatusCode: 200, + Body: ioutil.NopCloser(bytes.NewReader(mockResp)), + }, + }) + generalClient = clientMock + resp := RemoteCallback("http://test/test/url", serializer.UploadCallback{ + SourceName: "source", + }) + asserts.EqualValues(401, resp.(serializer.AppError).Code) + clientMock.AssertExpectations(t) + } + + // 无法解析回调响应 + { + clientMock := ClientMock{} + clientMock.On( + "Request", + "POST", + "http://test/test/url", + testMock.Anything, + testMock.Anything, + ).Return(Response{ + Err: nil, + Response: &http.Response{ + StatusCode: 200, + Body: ioutil.NopCloser(strings.NewReader("mockResp")), + }, + }) + generalClient = clientMock + resp := RemoteCallback("http://test/test/url", serializer.UploadCallback{ + SourceName: "source", + }) + asserts.Error(resp) + clientMock.AssertExpectations(t) + } + + // HTTP状态码非200 + { + clientMock := ClientMock{} + clientMock.On( + "Request", + "POST", + "http://test/test/url", + testMock.Anything, + testMock.Anything, + ).Return(Response{ + Err: nil, + Response: &http.Response{ + StatusCode: 404, + Body: ioutil.NopCloser(strings.NewReader("mockResp")), + }, + }) + generalClient = clientMock + resp := RemoteCallback("http://test/test/url", serializer.UploadCallback{ + SourceName: "source", + }) + asserts.Error(resp) + clientMock.AssertExpectations(t) + } + + // 无法发起回调 + { + clientMock := ClientMock{} + clientMock.On( + "Request", + "POST", + "http://test/test/url", + testMock.Anything, + testMock.Anything, + ).Return(Response{ + Err: errors.New("error"), + }) + generalClient = clientMock + resp := RemoteCallback("http://test/test/url", serializer.UploadCallback{ + SourceName: "source", + }) + asserts.Error(resp) + clientMock.AssertExpectations(t) + } +} diff --git a/pkg/request/request.go b/pkg/request/request.go new file mode 100644 index 0000000..fc4e9e1 --- /dev/null +++ b/pkg/request/request.go @@ -0,0 +1,106 @@ +package request + +import ( + "github.com/HFO4/cloudreve/pkg/auth" + "io" + "net/http" + "time" +) + +var generalClient Client = HTTPClient{} + +// Response 请求的响应或错误信息 +type Response struct { + Err error + Response *http.Response +} + +// Client 请求客户端 +type Client interface { + Request(method, target string, body io.Reader, opts ...Option) Response +} + +// HTTPClient 实现 Client 接口 +type HTTPClient struct { +} + +// Option 发送请求的额外设置 +type Option interface { + apply(*options) +} + +type options struct { + timeout time.Duration + header http.Header + sign auth.Auth + signTTL int64 +} + +type optionFunc func(*options) + +func (f optionFunc) apply(o *options) { + f(o) +} + +func newDefaultOption() *options { + return &options{ + header: http.Header{}, + timeout: time.Duration(30) * time.Second, + } +} + +// WithTimeout 设置请求超时 +func WithTimeout(t time.Duration) Option { + return optionFunc(func(o *options) { + o.timeout = t + }) +} + +// WithCredential 对请求进行签名 +func WithCredential(instance auth.Auth, ttl int64) Option { + return optionFunc(func(o *options) { + o.sign = instance + o.signTTL = ttl + }) +} + +// WithHeader 设置请求Header +func WithHeader(header http.Header) Option { + return optionFunc(func(o *options) { + o.header = header + }) +} + +// Request 发送HTTP请求 +func (c HTTPClient) Request(method, target string, body io.Reader, opts ...Option) Response { + // 应用额外设置 + options := newDefaultOption() + for _, o := range opts { + o.apply(options) + } + + // 创建请求客户端 + client := &http.Client{Timeout: options.timeout} + + // 创建请求 + req, err := http.NewRequest(method, target, body) + if err != nil { + return Response{Err: err} + } + + // 添加请求header + req.Header = options.header + + // 签名请求 + if options.sign != nil { + auth.SignRequest(options.sign, req, time.Now().Unix()+options.signTTL) + } + + // 发送请求 + resp, err := client.Do(req) + if err != nil { + return Response{Err: err} + } + + return Response{Err: nil, Response: resp} +} diff --git a/pkg/request/request_test.go b/pkg/request/request_test.go new file mode 100644 index 0000000..3055172 --- /dev/null +++ b/pkg/request/request_test.go @@ -0,0 +1,62 @@ +package request + +import ( + "github.com/HFO4/cloudreve/pkg/auth" + "github.com/stretchr/testify/assert" + testMock "github.com/stretchr/testify/mock" + "io" + "net/http" + "strings" + "testing" + "time" +) + +type ClientMock struct { + testMock.Mock +} + +func (m ClientMock) Request(method, target string, body io.Reader, opts ...Option) Response { + args := m.Called(method, target, body, opts) + return args.Get(0).(Response) +} + +func TestWithTimeout(t *testing.T) { + asserts := assert.New(t) + options := newDefaultOption() + WithTimeout(time.Duration(5) * time.Second).apply(options) + asserts.Equal(time.Duration(5)*time.Second, options.timeout) +} + +func TestWithHeader(t *testing.T) { + asserts := assert.New(t) + options := newDefaultOption() + WithHeader(map[string][]string{"Origin": []string{"123"}}).apply(options) + asserts.Equal(http.Header{"Origin": []string{"123"}}, options.header) +} + +func TestWithCredential(t *testing.T) { + asserts := assert.New(t) + options := newDefaultOption() + WithCredential(auth.HMACAuth{SecretKey: []byte("123")}, 10).apply(options) + asserts.Equal(auth.HMACAuth{SecretKey: []byte("123")}, options.sign) + asserts.EqualValues(10, options.signTTL) +} + +func TestHTTPClient_Request(t *testing.T) { + asserts := assert.New(t) + client := HTTPClient{} + + // 正常 + { + resp := client.Request( + "GET", + "http://cloudreveisnotexist.com", + strings.NewReader(""), + WithTimeout(time.Duration(1)*time.Microsecond), + WithCredential(auth.HMACAuth{SecretKey: []byte("123")}, 10), + ) + asserts.Error(resp.Err) + asserts.Nil(resp.Response) + } + +} diff --git a/pkg/serializer/error.go b/pkg/serializer/error.go index f23d294..484dcdf 100644 --- a/pkg/serializer/error.go +++ b/pkg/serializer/error.go @@ -70,6 +70,8 @@ const ( CodeInternalSetting = 50005 // CodeCacheOperation 缓存操作失败 CodeCacheOperation = 50006 + // CodeCallbackError 回调失败 + CodeCallbackError = 50007 //CodeParamErr 各种奇奇怪怪的参数错误 CodeParamErr = 40001 // CodeNotSet 未定错误,后续尝试从error中获取 @@ -94,13 +96,13 @@ func ParamErr(msg string, err error) Response { // Err 通用错误处理 func Err(errCode int, msg string, err error) Response { - // 如果错误code未定,则尝试从AppError中获取 - if errCode == CodeNotSet { - if appError, ok := err.(AppError); ok { - errCode = appError.Code - err = appError.RawError - } + // 底层错误是AppError,则尝试从AppError中获取详细信息 + if appError, ok := err.(AppError); ok { + errCode = appError.Code + err = appError.RawError + msg = appError.Msg } + res := Response{ Code: errCode, Msg: msg, diff --git a/pkg/serializer/upload.go b/pkg/serializer/upload.go index 9416a13..b56dc8e 100644 --- a/pkg/serializer/upload.go +++ b/pkg/serializer/upload.go @@ -28,6 +28,13 @@ type UploadSession struct { VirtualPath string } +// UploadCallback 远程存储策略上传回调正文 +type UploadCallback struct { + Name string `json:"name"` + SourceName string `json:"source_name"` + PicInfo string `json:"pic_info"` +} + func init() { gob.Register(UploadSession{}) }