From 93dc25aabb2bb47ca32317b1e8d34c8c22ad1d25 Mon Sep 17 00:00:00 2001 From: HFO4 <912394456@qq.com> Date: Sat, 4 Jan 2020 15:17:27 +0800 Subject: [PATCH] Test: remote handler get & request chan operations / Modify: GET request to remote server should return error http status code --- pkg/filesystem/archive.go | 14 ++-- pkg/filesystem/archive_test.go | 23 ++++++ pkg/filesystem/hooks_test.go | 6 +- pkg/filesystem/remote/handler.go | 13 ++- pkg/filesystem/remote/handler_test.go | 70 ++++++++++++++-- pkg/request/request.go | 38 ++++++--- pkg/request/request_test.go | 113 +++++++++++++++++++++++++- pkg/request/slave.go | 2 +- pkg/request/slave_test.go | 8 +- routers/controllers/slave.go | 6 +- 10 files changed, 249 insertions(+), 44 deletions(-) diff --git a/pkg/filesystem/archive.go b/pkg/filesystem/archive.go index 5604295..fe71993 100644 --- a/pkg/filesystem/archive.go +++ b/pkg/filesystem/archive.go @@ -3,7 +3,6 @@ package filesystem import ( "archive/zip" "context" - "errors" "fmt" model "github.com/HFO4/cloudreve/models" "github.com/HFO4/cloudreve/pkg/filesystem/fsctx" @@ -27,15 +26,18 @@ func (fs *FileSystem) Compress(ctx context.Context, folderIDs, fileIDs []uint) ( if err != nil && len(folders) != 0 { return "", ErrDBListObjects } + // 查找待压缩文件 files, err := model.GetFilesByIDs(fileIDs, fs.User.ID) if err != nil && len(files) != 0 { return "", ErrDBListObjects } + // 尝试获取请求上下文,以便于后续检查用户取消任务 + reqContext := ctx ginCtx, ok := ctx.Value(fsctx.GinCtx).(*gin.Context) - if !ok { - return "", errors.New("无法获取请求上下文") + if ok { + reqContext = ginCtx.Request.Context() } // 将顶级待处理对象的路径设为根路径 @@ -62,12 +64,12 @@ func (fs *FileSystem) Compress(ctx context.Context, folderIDs, fileIDs []uint) ( zipWriter := zip.NewWriter(zipFile) defer zipWriter.Close() - ctx = context.WithValue(ginCtx.Request.Context(), fsctx.UserCtx, *fs.User) + ctx = reqContext // 压缩各个目录及文件 for i := 0; i < len(folders); i++ { select { - case <-ginCtx.Request.Context().Done(): + case <-reqContext.Done(): // 取消压缩请求 fs.cancelCompress(ctx, zipWriter, zipFile, zipFilePath) return "", ErrClientCanceled @@ -78,7 +80,7 @@ func (fs *FileSystem) Compress(ctx context.Context, folderIDs, fileIDs []uint) ( } for i := 0; i < len(files); i++ { select { - case <-ginCtx.Request.Context().Done(): + case <-reqContext.Done(): // 取消压缩请求 fs.cancelCompress(ctx, zipWriter, zipFile, zipFilePath) return "", ErrClientCanceled diff --git a/pkg/filesystem/archive_test.go b/pkg/filesystem/archive_test.go index 82e5a2f..9547aa3 100644 --- a/pkg/filesystem/archive_test.go +++ b/pkg/filesystem/archive_test.go @@ -56,4 +56,27 @@ func TestFileSystem_Compress(t *testing.T) { asserts.Contains(zipFile, "archive_") asserts.Contains(zipFile, "tests") } + + // 上下文取消 + { + ctx, cancel := context.WithCancel(context.Background()) + cancel() + // 查找压缩父目录 + mock.ExpectQuery("SELECT(.+)folders(.+)"). + WithArgs(1, 1). + WillReturnRows(sqlmock.NewRows([]string{"id", "name"}).AddRow(1, "parent")) + // 查找顶级待压缩文件 + mock.ExpectQuery("SELECT(.+)files(.+)"). + WithArgs(1, 1). + WillReturnRows( + sqlmock.NewRows( + []string{"id", "name", "source_name", "policy_id"}). + AddRow(1, "1.txt", "tests/file1.txt", 1), + ) + asserts.NoError(cache.Set("setting_temp_path", "tests", -1)) + + zipFile, err := fs.Compress(ctx, []uint{1}, []uint{1}) + asserts.Error(err) + asserts.Empty(zipFile) + } } diff --git a/pkg/filesystem/hooks_test.go b/pkg/filesystem/hooks_test.go index 3533284..3d6d739 100644 --- a/pkg/filesystem/hooks_test.go +++ b/pkg/filesystem/hooks_test.go @@ -550,9 +550,9 @@ type ClientMock struct { testMock.Mock } -func (m ClientMock) Request(method, target string, body io.Reader, opts ...request.Option) request.Response { +func (m ClientMock) Request(method, target string, body io.Reader, opts ...request.Option) *request.Response { args := m.Called(method, target, body, opts) - return args.Get(0).(request.Response) + return args.Get(0).(*request.Response) } func TestSlaveAfterUpload(t *testing.T) { @@ -571,7 +571,7 @@ func TestSlaveAfterUpload(t *testing.T) { "http://test/callbakc", testMock.Anything, testMock.Anything, - ).Return(request.Response{ + ).Return(&request.Response{ Err: nil, Response: &http.Response{ StatusCode: 200, diff --git a/pkg/filesystem/remote/handler.go b/pkg/filesystem/remote/handler.go index 27e772b..e703c87 100644 --- a/pkg/filesystem/remote/handler.go +++ b/pkg/filesystem/remote/handler.go @@ -26,8 +26,8 @@ type Handler struct { AuthInstance auth.Auth } -// getAPI 获取接口请求地址 -func (handler Handler) getAPI(scope string) string { +// getAPIUrl 获取接口请求地址 +func (handler Handler) getAPIUrl(scope string) string { serverURL, err := url.Parse(handler.Policy.Server) if err != nil { return "" @@ -45,7 +45,6 @@ func (handler Handler) getAPI(scope string) string { } // Get 获取文件内容 -// TODO 测试 func (handler Handler) Get(ctx context.Context, path string) (response.RSCloser, error) { // 尝试获取速度限制 TODO 是否需要在这里限制? speedLimit := 0 @@ -65,7 +64,7 @@ func (handler Handler) Get(ctx context.Context, path string) (response.RSCloser, downloadURL, nil, request.WithContext(ctx), - ).GetRSCloser() + ).CheckHTTPResponse(200).GetRSCloser() if err != nil { return nil, err @@ -96,10 +95,10 @@ func (handler Handler) Delete(ctx context.Context, files []string) ([]string, er signTTL := model.GetIntSetting("slave_api_timeout", 60) resp, err := handler.Client.Request( "POST", - handler.getAPI("delete"), + handler.getAPIUrl("delete"), bodyReader, request.WithCredential(handler.AuthInstance, int64(signTTL)), - ).GetResponse(200) + ).CheckHTTPResponse(200).GetResponse() if err != nil { return files, err } @@ -127,7 +126,7 @@ func (handler Handler) Delete(ctx context.Context, files []string) ([]string, er // Thumb 获取文件缩略图 func (handler Handler) Thumb(ctx context.Context, path string) (*response.ContentResponse, error) { sourcePath := base64.RawURLEncoding.EncodeToString([]byte(path)) - thumbURL := handler.getAPI("thumb") + "/" + sourcePath + thumbURL := handler.getAPIUrl("thumb") + "/" + sourcePath ttl := model.GetIntSetting("slave_api_timeout", 60) signedThumbURL, err := auth.SignURI(handler.AuthInstance, thumbURL, int64(ttl)) if err != nil { diff --git a/pkg/filesystem/remote/handler_test.go b/pkg/filesystem/remote/handler_test.go index 0ee7e45..1f09954 100644 --- a/pkg/filesystem/remote/handler_test.go +++ b/pkg/filesystem/remote/handler_test.go @@ -102,9 +102,9 @@ type ClientMock struct { testMock.Mock } -func (m ClientMock) Request(method, target string, body io.Reader, opts ...request.Option) request.Response { +func (m ClientMock) Request(method, target string, body io.Reader, opts ...request.Option) *request.Response { args := m.Called(method, target, body, opts) - return args.Get(0).(request.Response) + return args.Get(0).(*request.Response) } func TestHandler_Delete(t *testing.T) { @@ -128,7 +128,7 @@ func TestHandler_Delete(t *testing.T) { "http://test.com/api/v3/slave/delete", testMock.Anything, testMock.Anything, - ).Return(request.Response{ + ).Return(&request.Response{ Err: nil, Response: &http.Response{ StatusCode: 200, @@ -152,7 +152,7 @@ func TestHandler_Delete(t *testing.T) { "http://test.com/api/v3/slave/delete", testMock.Anything, testMock.Anything, - ).Return(request.Response{ + ).Return(&request.Response{ Err: nil, Response: &http.Response{ StatusCode: 200, @@ -175,7 +175,7 @@ func TestHandler_Delete(t *testing.T) { "http://test.com/api/v3/slave/delete", testMock.Anything, testMock.Anything, - ).Return(request.Response{ + ).Return(&request.Response{ Err: nil, Response: &http.Response{ StatusCode: 200, @@ -189,3 +189,63 @@ func TestHandler_Delete(t *testing.T) { asserts.Len(failed, 1) } } + +func TestHandler_Get(t *testing.T) { + asserts := assert.New(t) + handler := Handler{ + Policy: &model.Policy{ + SecretKey: "test", + Server: "http://test.com", + }, + AuthInstance: auth.HMACAuth{}, + } + ctx := context.Background() + + // 成功 + { + ctx = context.WithValue(ctx, fsctx.UserCtx, model.User{}) + clientMock := ClientMock{} + clientMock.On( + "Request", + "GET", + testMock.Anything, + nil, + testMock.Anything, + ).Return(&request.Response{ + Err: nil, + Response: &http.Response{ + StatusCode: 200, + Body: ioutil.NopCloser(strings.NewReader(`{"code":0}`)), + }, + }) + handler.Client = clientMock + resp, err := handler.Get(ctx, "/test.txt") + clientMock.AssertExpectations(t) + asserts.NotNil(resp) + asserts.NoError(err) + } + + // 请求失败 + { + ctx = context.WithValue(ctx, fsctx.UserCtx, model.User{}) + clientMock := ClientMock{} + clientMock.On( + "Request", + "GET", + testMock.Anything, + nil, + testMock.Anything, + ).Return(&request.Response{ + Err: nil, + Response: &http.Response{ + StatusCode: 404, + Body: ioutil.NopCloser(strings.NewReader(`{"code":0}`)), + }, + }) + handler.Client = clientMock + resp, err := handler.Get(ctx, "/test.txt") + clientMock.AssertExpectations(t) + asserts.Nil(resp) + asserts.Error(err) + } +} diff --git a/pkg/request/request.go b/pkg/request/request.go index ba3b810..df95547 100644 --- a/pkg/request/request.go +++ b/pkg/request/request.go @@ -12,6 +12,7 @@ import ( "time" ) +// GeneralClient 通用 HTTP Client var GeneralClient Client = HTTPClient{} // Response 请求的响应或错误信息 @@ -22,7 +23,7 @@ type Response struct { // Client 请求客户端 type Client interface { - Request(method, target string, body io.Reader, opts ...Option) Response + Request(method, target string, body io.Reader, opts ...Option) *Response } // HTTPClient 实现 Client 接口 @@ -63,7 +64,6 @@ func WithTimeout(t time.Duration) Option { } // WithContext 设置请求上下文 -// TODO 测试 func WithContext(c context.Context) Option { return optionFunc(func(o *options) { o.ctx = c @@ -86,7 +86,7 @@ func WithHeader(header http.Header) Option { } // Request 发送HTTP请求 -func (c HTTPClient) Request(method, target string, body io.Reader, opts ...Option) Response { +func (c HTTPClient) Request(method, target string, body io.Reader, opts ...Option) *Response { // 应用额外设置 options := newDefaultOption() for _, o := range opts { @@ -107,7 +107,7 @@ func (c HTTPClient) Request(method, target string, body io.Reader, opts ...Optio req, err = http.NewRequest(method, target, body) } if err != nil { - return Response{Err: err} + return &Response{Err: err} } // 添加请求header @@ -121,32 +121,44 @@ func (c HTTPClient) Request(method, target string, body io.Reader, opts ...Optio // 发送请求 resp, err := client.Do(req) if err != nil { - return Response{Err: err} + return &Response{Err: err} } - return Response{Err: nil, Response: resp} + return &Response{Err: nil, Response: resp} } // GetResponse 检查响应并获取响应正文 -// todo 测试 -func (resp Response) GetResponse(expectStatus int) (string, error) { +func (resp *Response) GetResponse() (string, error) { if resp.Err != nil { return "", resp.Err } respBody, err := ioutil.ReadAll(resp.Response.Body) - if resp.Response.StatusCode != expectStatus { - return string(respBody), - fmt.Errorf("服务器返回非正常HTTP状态%d", resp.Response.StatusCode) - } return string(respBody), err } +// CheckHTTPResponse 检查请求响应HTTP状态码 +func (resp *Response) CheckHTTPResponse(status int) *Response { + if resp.Err != nil { + return resp + } + + // 检查HTTP状态码 + if resp.Response.StatusCode != status { + resp.Err = fmt.Errorf("服务器返回非正常HTTP状态%d", resp.Response.StatusCode) + } + return resp +} + type nopRSCloser struct { body io.ReadCloser } // GetRSCloser 返回带有空seeker的body reader -func (resp Response) GetRSCloser() (response.RSCloser, error) { +func (resp *Response) GetRSCloser() (response.RSCloser, error) { + if resp.Err != nil { + return nil, resp.Err + } + return nopRSCloser{ body: resp.Response.Body, }, resp.Err diff --git a/pkg/request/request_test.go b/pkg/request/request_test.go index 3055172..af8157d 100644 --- a/pkg/request/request_test.go +++ b/pkg/request/request_test.go @@ -1,10 +1,13 @@ package request import ( + "context" + "errors" "github.com/HFO4/cloudreve/pkg/auth" "github.com/stretchr/testify/assert" testMock "github.com/stretchr/testify/mock" "io" + "io/ioutil" "net/http" "strings" "testing" @@ -15,9 +18,9 @@ type ClientMock struct { testMock.Mock } -func (m ClientMock) Request(method, target string, body io.Reader, opts ...Option) Response { +func (m ClientMock) Request(method, target string, body io.Reader, opts ...Option) *Response { args := m.Called(method, target, body, opts) - return args.Get(0).(Response) + return args.Get(0).(*Response) } func TestWithTimeout(t *testing.T) { @@ -42,6 +45,13 @@ func TestWithCredential(t *testing.T) { asserts.EqualValues(10, options.signTTL) } +func TestWithContext(t *testing.T) { + asserts := assert.New(t) + options := newDefaultOption() + WithContext(context.Background()).apply(options) + asserts.NotNil(options.ctx) +} + func TestHTTPClient_Request(t *testing.T) { asserts := assert.New(t) client := HTTPClient{} @@ -59,4 +69,103 @@ func TestHTTPClient_Request(t *testing.T) { asserts.Nil(resp.Response) } + // 正常 带有ctx + { + resp := client.Request( + "GET", + "http://cloudreveisnotexist.com", + strings.NewReader(""), + WithTimeout(time.Duration(1)*time.Microsecond), + WithCredential(auth.HMACAuth{SecretKey: []byte("123")}, 10), + WithContext(context.Background()), + ) + asserts.Error(resp.Err) + asserts.Nil(resp.Response) + } + +} + +func TestResponse_GetResponse(t *testing.T) { + asserts := assert.New(t) + + // 直接返回错误 + { + resp := Response{ + Err: errors.New("error"), + } + content, err := resp.GetResponse() + asserts.Empty(content) + asserts.Error(err) + } + + // 正常 + { + resp := Response{ + Response: &http.Response{Body: ioutil.NopCloser(strings.NewReader("123"))}, + } + content, err := resp.GetResponse() + asserts.Equal("123", content) + asserts.NoError(err) + } +} + +func TestResponse_CheckHTTPResponse(t *testing.T) { + asserts := assert.New(t) + + // 直接返回错误 + { + resp := Response{ + Err: errors.New("error"), + } + res := resp.CheckHTTPResponse(200) + asserts.Error(res.Err) + } + + // 404错误 + { + resp := Response{ + Response: &http.Response{StatusCode: 404}, + } + res := resp.CheckHTTPResponse(200) + asserts.Error(res.Err) + } + + // 通过 + { + resp := Response{ + Response: &http.Response{StatusCode: 200}, + } + res := resp.CheckHTTPResponse(200) + asserts.NoError(res.Err) + } +} + +func TestResponse_GetRSCloser(t *testing.T) { + asserts := assert.New(t) + + // 直接返回错误 + { + resp := Response{ + Err: errors.New("error"), + } + res, err := resp.GetRSCloser() + asserts.Error(err) + asserts.Nil(res) + } + + // 正常 + { + resp := Response{ + Response: &http.Response{Body: ioutil.NopCloser(strings.NewReader("123"))}, + } + res, err := resp.GetRSCloser() + asserts.NoError(err) + content, err := ioutil.ReadAll(res) + asserts.NoError(err) + asserts.Equal("123", string(content)) + _, err = res.Seek(0, 0) + asserts.Error(err) + asserts.NoError(res.Close()) + } + } diff --git a/pkg/request/slave.go b/pkg/request/slave.go index e18440c..261639d 100644 --- a/pkg/request/slave.go +++ b/pkg/request/slave.go @@ -35,7 +35,7 @@ func RemoteCallback(url string, body serializer.RemoteUploadCallback) error { } // 检查返回HTTP状态码 - rawResp, err := resp.GetResponse(200) + rawResp, err := resp.CheckHTTPResponse(200).GetResponse() if err != nil { return serializer.NewError(serializer.CodeCallbackError, "服务器返回异常响应", err) } diff --git a/pkg/request/slave_test.go b/pkg/request/slave_test.go index 957a47b..b77816b 100644 --- a/pkg/request/slave_test.go +++ b/pkg/request/slave_test.go @@ -51,7 +51,7 @@ func TestRemoteCallback(t *testing.T) { "http://test/test/url", testMock.Anything, testMock.Anything, - ).Return(Response{ + ).Return(&Response{ Err: nil, Response: &http.Response{ StatusCode: 200, @@ -75,7 +75,7 @@ func TestRemoteCallback(t *testing.T) { "http://test/test/url", testMock.Anything, testMock.Anything, - ).Return(Response{ + ).Return(&Response{ Err: nil, Response: &http.Response{ StatusCode: 200, @@ -99,7 +99,7 @@ func TestRemoteCallback(t *testing.T) { "http://test/test/url", testMock.Anything, testMock.Anything, - ).Return(Response{ + ).Return(&Response{ Err: nil, Response: &http.Response{ StatusCode: 404, @@ -123,7 +123,7 @@ func TestRemoteCallback(t *testing.T) { "http://test/test/url", testMock.Anything, testMock.Anything, - ).Return(Response{ + ).Return(&Response{ Err: errors.New("error"), }) GeneralClient = clientMock diff --git a/routers/controllers/slave.go b/routers/controllers/slave.go index 5f2490b..6792a81 100644 --- a/routers/controllers/slave.go +++ b/routers/controllers/slave.go @@ -81,7 +81,7 @@ func SlaveUpload(c *gin.Context) { }) } -// SlaveDownload 从机文件下载 +// SlaveDownload 从机文件下载,此请求返回的HTTP状态码不全为200 func SlaveDownload(c *gin.Context) { // 创建上下文 ctx, cancel := context.WithCancel(context.Background()) @@ -91,10 +91,10 @@ func SlaveDownload(c *gin.Context) { if err := c.ShouldBindUri(&service); err == nil { res := service.ServeFile(ctx, c, true) if res.Code != 0 { - c.JSON(200, res) + c.JSON(400, res) } } else { - c.JSON(200, ErrorResponse(err)) + c.JSON(400, ErrorResponse(err)) } }