From 9a942f8b487aba878ea4e0f97cc399af07642eff Mon Sep 17 00:00:00 2001 From: HFO4 <912394456@qq.com> Date: Mon, 16 Dec 2019 12:52:35 +0800 Subject: [PATCH] Modify: decoupling getSignedURL modules --- pkg/filesystem/file.go | 73 +++++++++++++++------------ pkg/filesystem/file_test.go | 11 ++-- pkg/filesystem/filesystem.go | 12 +++-- pkg/filesystem/local/handler.go | 69 +++++++++++++------------ pkg/filesystem/local/handller_test.go | 8 +-- pkg/filesystem/upload_test.go | 8 +-- service/explorer/file.go | 2 +- 7 files changed, 98 insertions(+), 85 deletions(-) diff --git a/pkg/filesystem/file.go b/pkg/filesystem/file.go index 1eb33f7..7ce2d26 100644 --- a/pkg/filesystem/file.go +++ b/pkg/filesystem/file.go @@ -175,37 +175,36 @@ func (fs *FileSystem) GroupFileByPolicy(ctx context.Context, files []model.File) return policyGroup } -// GetDownloadURL 创建文件下载链接 -func (fs *FileSystem) GetDownloadURL(ctx context.Context, path string) (string, error) { +// GetDownloadURL 创建文件下载链接, timeout 为数据库中存储过期时间的字段 +func (fs *FileSystem) GetDownloadURL(ctx context.Context, path string, timeout string) (string, error) { + var fileTarget *model.File // 找到文件 if len(fs.FileTarget) == 0 { exist, file := fs.IsFileExist(path) if !exist { return "", ErrObjectNotExist } - fs.FileTarget = []model.File{*file} - } - - ctx = context.WithValue(ctx, fsctx.FileModelCtx, fs.FileTarget[0]) - - // 将当前存储策略重设为文件使用的 - fs.Policy = fs.FileTarget[0].GetPolicy() - err := fs.dispatchHandler() - if err != nil { - return "", err + fileTarget = file + } else { + fileTarget = &fs.FileTarget[0] } // 生成下載地址 - siteURL := model.GetSiteURL() - ttl, err := strconv.ParseInt(model.GetSettingByName("download_timeout"), 10, 64) + ttl, err := strconv.ParseInt(model.GetSettingByName(timeout), 10, 64) if err != nil { - return "", serializer.NewError(serializer.CodeInternalSetting, "无法获取下载地址有效期", err) + return "", + serializer.NewError( + serializer.CodeInternalSetting, + "无法获取下载地址有效期", + err, + ) } - source, err := fs.Handler.GetDownloadURL( + + source, err := fs.signURL( ctx, - fs.FileTarget[0].SourceName, - *siteURL, + fileTarget, ttl, + true, ) if err != nil { return "", err @@ -222,18 +221,8 @@ func (fs *FileSystem) GetSource(ctx context.Context, fileID uint) (string, error return "", ErrObjectNotExist.WithError(err) } - fs.FileTarget = []model.File{fileObject[0]} - ctx = context.WithValue(ctx, fsctx.FileModelCtx, fileObject[0]) - - // 将当前存储策略重设为文件使用的 - fs.Policy = fileObject[0].GetPolicy() - err = fs.dispatchHandler() - if err != nil { - return "", err - } - // 检查存储策略是否可以获得外链 - if !fs.Policy.IsOriginLinkEnable { + if !fileObject[0].GetPolicy().IsOriginLinkEnable { return "", serializer.NewError( serializer.CodePolicyNotAllowed, "当前存储策略无法获得外链", @@ -241,9 +230,29 @@ func (fs *FileSystem) GetSource(ctx context.Context, fileID uint) (string, error ) } - // 生成外链地址 - siteURL := model.GetSiteURL() - source, err := fs.Handler.Source(ctx, fileObject[0].SourceName, *siteURL, 0) + source, err := fs.signURL(ctx, &fileObject[0], 0, false) + if err != nil { + return "", serializer.NewError(serializer.CodeNotSet, "无法获取外链", err) + } + + return source, nil +} + +func (fs *FileSystem) signURL(ctx context.Context, file *model.File, ttl int64, isDownload bool) (string, error) { + fs.FileTarget = []model.File{*file} + ctx = context.WithValue(ctx, fsctx.FileModelCtx, *file) + + // 将当前存储策略重设为文件使用的 + fs.Policy = file.GetPolicy() + err := fs.dispatchHandler() + if err != nil { + return "", err + } + + // 签名最终URL + // 生成外链地址 + siteURL := model.GetSiteURL() + source, err := fs.Handler.Source(ctx, fs.FileTarget[0].SourceName, *siteURL, ttl, isDownload) if err != nil { return "", serializer.NewError(serializer.CodeNotSet, "无法获取外链", err) } diff --git a/pkg/filesystem/file_test.go b/pkg/filesystem/file_test.go index d03b624..0c63c45 100644 --- a/pkg/filesystem/file_test.go +++ b/pkg/filesystem/file_test.go @@ -403,17 +403,16 @@ func TestFileSystem_GetDownloadURL(t *testing.T) { WithArgs(1). WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1)) mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id", "name", "policy_id"}).AddRow(1, "1.txt", 1)) + // 相关设置 + mock.ExpectQuery("SELECT(.+)").WithArgs("download_timeout").WillReturnRows(sqlmock.NewRows([]string{"id", "value"}).AddRow(1, "20")) // 查找上传策略 mock.ExpectQuery("SELECT(.+)"). WillReturnRows( sqlmock.NewRows([]string{"id", "type", "is_origin_link_enable"}). AddRow(35, "local", true), ) - // 相关设置 mock.ExpectQuery("SELECT(.+)").WithArgs("siteURL").WillReturnRows(sqlmock.NewRows([]string{"id", "value"}).AddRow(1, "https://cloudreve.org")) - mock.ExpectQuery("SELECT(.+)").WithArgs("download_timeout").WillReturnRows(sqlmock.NewRows([]string{"id", "value"}).AddRow(1, "20")) - - downloadURL, err := fs.GetDownloadURL(ctx, "/1.txt") + downloadURL, err := fs.GetDownloadURL(ctx, "/1.txt", "download_timeout") asserts.NoError(mock.ExpectationsWereMet()) asserts.NoError(err) asserts.NotEmpty(downloadURL) @@ -432,7 +431,7 @@ func TestFileSystem_GetDownloadURL(t *testing.T) { WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1)) mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id", "name", "policy_id"})) - downloadURL, err := fs.GetDownloadURL(ctx, "/1.txt") + downloadURL, err := fs.GetDownloadURL(ctx, "/1.txt", "download_timeout") asserts.NoError(mock.ExpectationsWereMet()) asserts.Error(err) asserts.Empty(downloadURL) @@ -457,7 +456,7 @@ func TestFileSystem_GetDownloadURL(t *testing.T) { AddRow(35, "unknown", true), ) - downloadURL, err := fs.GetDownloadURL(ctx, "/1.txt") + downloadURL, err := fs.GetDownloadURL(ctx, "/1.txt", "download_timeout") asserts.NoError(mock.ExpectationsWereMet()) asserts.Error(err) asserts.Empty(downloadURL) diff --git a/pkg/filesystem/filesystem.go b/pkg/filesystem/filesystem.go index 08a8490..d6d0aa1 100644 --- a/pkg/filesystem/filesystem.go +++ b/pkg/filesystem/filesystem.go @@ -24,16 +24,20 @@ type FileHeader interface { type Handler interface { // 上传文件 Put(ctx context.Context, file io.ReadCloser, dst string, size uint64) error + // 删除一个或多个文件 Delete(ctx context.Context, files []string) ([]string, error) + // 获取文件 Get(ctx context.Context, path string) (response.RSCloser, error) + // 获取缩略图 Thumb(ctx context.Context, path string) (*response.ContentResponse, error) - // 获取外链地址,url - Source(ctx context.Context, path string, url url.URL, expires int64) (string, error) - //获取下载地址 - GetDownloadURL(ctx context.Context, path string, url url.URL, expires int64) (string, error) + + // 获取外链/下载地址, + // url - 站点本身地址, + // isDownload - 是否直接下载 + Source(ctx context.Context, path string, url url.URL, ttl int64, isDownload bool) (string, error) } // FileSystem 管理文件的文件系统 diff --git a/pkg/filesystem/local/handler.go b/pkg/filesystem/local/handler.go index db5cdd5..a0cb6d9 100644 --- a/pkg/filesystem/local/handler.go +++ b/pkg/filesystem/local/handler.go @@ -111,47 +111,52 @@ func (handler Handler) Thumb(ctx context.Context, path string) (*response.Conten } // Source 获取外链URL -func (handler Handler) Source(ctx context.Context, path string, url url.URL, expires int64) (string, error) { +func (handler Handler) Source( + ctx context.Context, + path string, + baseURL url.URL, + ttl int64, + isDownload bool, +) (string, error) { file, ok := ctx.Value(fsctx.FileModelCtx).(model.File) if !ok { return "", errors.New("无法获取文件记录上下文") } - // 签名生成文件记录 - signedURI, err := auth.SignURI( - fmt.Sprintf("/api/v3/file/get/%d/%s", file.ID, file.Name), - 0, + var expires int64 + if ttl > 0 { + expires = time.Now().Unix() + ttl + } + + var ( + signedURI *url.URL + err error ) + if isDownload { + // 创建下载会话,将文件信息写入缓存 + downloadSessionID := util.RandStringRunes(16) + err = cache.Set("download_"+downloadSessionID, file, int(ttl)) + if err != nil { + return "", serializer.NewError(serializer.CodeCacheOperation, "无法创建下載会话", err) + } + + // 签名生成文件记录 + signedURI, err = auth.SignURI( + fmt.Sprintf("/api/v3/file/download/%s", downloadSessionID), + expires, + ) + } else { + // 签名生成文件记录 + signedURI, err = auth.SignURI( + fmt.Sprintf("/api/v3/file/get/%d/%s", file.ID, file.Name), + expires, + ) + } + if err != nil { return "", serializer.NewError(serializer.CodeEncryptError, "无法对URL进行签名", err) } - finalURL := url.ResolveReference(signedURI).String() - return finalURL, nil -} - -func (handler Handler) GetDownloadURL(ctx context.Context, path string, url url.URL, ttl int64) (string, error) { - file, ok := ctx.Value(fsctx.FileModelCtx).(model.File) - if !ok { - return "", errors.New("无法获取文件记录上下文") - } - - // 创建下载会话,将文件信息写入缓存 - downloadSessionID := util.RandStringRunes(16) - err := cache.Set("download_"+downloadSessionID, file, int(ttl)) - if err != nil { - return "", serializer.NewError(serializer.CodeCacheOperation, "无法创建下載会话", err) - } - - // 签名生成文件记录 - signedURI, err := auth.SignURI( - fmt.Sprintf("/api/v3/file/download/%s", downloadSessionID), - time.Now().Unix()+ttl, - ) - if err != nil { - return "", serializer.NewError(serializer.CodeEncryptError, "无法对URL进行签名", err) - } - - finalURL := url.ResolveReference(signedURI).String() + finalURL := baseURL.ResolveReference(signedURI).String() return finalURL, nil } diff --git a/pkg/filesystem/local/handller_test.go b/pkg/filesystem/local/handller_test.go index 6f4ddab..ea9ff6a 100644 --- a/pkg/filesystem/local/handller_test.go +++ b/pkg/filesystem/local/handller_test.go @@ -135,7 +135,7 @@ func TestHandler_Source(t *testing.T) { ctx := context.WithValue(ctx, fsctx.FileModelCtx, file) baseURL, err := url.Parse("https://cloudreve.org") asserts.NoError(err) - sourceURL, err := handler.Source(ctx, "", *baseURL, 0) + sourceURL, err := handler.Source(ctx, "", *baseURL, 0, false) asserts.NoError(err) asserts.NotEmpty(sourceURL) asserts.Contains(sourceURL, "sign=") @@ -146,7 +146,7 @@ func TestHandler_Source(t *testing.T) { { baseURL, err := url.Parse("https://cloudreve.org") asserts.NoError(err) - sourceURL, err := handler.Source(ctx, "", *baseURL, 0) + sourceURL, err := handler.Source(ctx, "", *baseURL, 0, false) asserts.Error(err) asserts.Empty(sourceURL) } @@ -169,7 +169,7 @@ func TestHandler_GetDownloadURL(t *testing.T) { ctx := context.WithValue(ctx, fsctx.FileModelCtx, file) baseURL, err := url.Parse("https://cloudreve.org") asserts.NoError(err) - downloadURL, err := handler.GetDownloadURL(ctx, "", *baseURL, 10) + downloadURL, err := handler.Source(ctx, "", *baseURL, 10, true) asserts.NoError(err) asserts.Contains(downloadURL, "sign=") asserts.Contains(downloadURL, "https://cloudreve.org") @@ -179,7 +179,7 @@ func TestHandler_GetDownloadURL(t *testing.T) { { baseURL, err := url.Parse("https://cloudreve.org") asserts.NoError(err) - downloadURL, err := handler.GetDownloadURL(ctx, "", *baseURL, 10) + downloadURL, err := handler.Source(ctx, "", *baseURL, 10, true) asserts.Error(err) asserts.Empty(downloadURL) } diff --git a/pkg/filesystem/upload_test.go b/pkg/filesystem/upload_test.go index b90a483..aa53d0c 100644 --- a/pkg/filesystem/upload_test.go +++ b/pkg/filesystem/upload_test.go @@ -42,12 +42,8 @@ func (m FileHeaderMock) Thumb(ctx context.Context, files string) (*response.Cont return args.Get(0).(*response.ContentResponse), args.Error(1) } -func (m FileHeaderMock) Source(ctx context.Context, path string, url url.URL, expires int64) (string, error) { - args := m.Called(ctx, path, url, expires) - return args.Get(0).(string), args.Error(1) -} -func (m FileHeaderMock) GetDownloadURL(ctx context.Context, path string, url url.URL, expires int64) (string, error) { - args := m.Called(ctx, path, url, expires) +func (m FileHeaderMock) Source(ctx context.Context, path string, url url.URL, expires int64, isDownload bool) (string, error) { + args := m.Called(ctx, path, url, expires, isDownload) return args.Get(0).(string), args.Error(1) } diff --git a/service/explorer/file.go b/service/explorer/file.go index 6d1448d..c7aecb9 100644 --- a/service/explorer/file.go +++ b/service/explorer/file.go @@ -104,7 +104,7 @@ func (service *SingleFileService) CreateDownloadSession(ctx context.Context, c * } // 获取下载地址 - downloadURL, err := fs.GetDownloadURL(ctx, service.Path) + downloadURL, err := fs.GetDownloadURL(ctx, service.Path, "download_timeout") if err != nil { return serializer.Err(serializer.CodeNotSet, err.Error(), err) }