From e2f6dab70c6769ef4001958325b432025f798855 Mon Sep 17 00:00:00 2001
From: HFO4 <912394456@qq.com>
Date: Wed, 29 Jan 2020 13:45:27 +0800
Subject: [PATCH] Feat: save re-save single shared file

---
 middleware/share.go          |  96 +++++++++++++++++++++++++++++++++
 models/folder.go             |  40 ++++++++++++--
 models/folder_test.go        |   1 -
 models/init.go               |   2 +-
 models/user.go               |   2 +-
 pkg/filesystem/manage.go     |  33 +++++++++++-
 routers/controllers/share.go |  11 ++++
 routers/router.go            |  29 ++++++++--
 service/share/manage.go      |  13 +----
 service/share/visit.go       | 102 +++++++++++++----------------------
 10 files changed, 241 insertions(+), 88 deletions(-)
 create mode 100644 middleware/share.go

diff --git a/middleware/share.go b/middleware/share.go
new file mode 100644
index 0000000..2154bf0
--- /dev/null
+++ b/middleware/share.go
@@ -0,0 +1,96 @@
+package middleware
+
+import (
+	"fmt"
+	model "github.com/HFO4/cloudreve/models"
+	"github.com/HFO4/cloudreve/pkg/serializer"
+	"github.com/HFO4/cloudreve/pkg/util"
+	"github.com/gin-gonic/gin"
+)
+
+// ShareAvailable 检查分享是否可用
+func ShareAvailable() gin.HandlerFunc {
+	return func(c *gin.Context) {
+		var user *model.User
+		if userCtx, ok := c.Get("user"); ok {
+			user = userCtx.(*model.User)
+		} else {
+			user = model.NewAnonymousUser()
+		}
+
+		share := model.GetShareByHashID(c.Param("id"))
+
+		if share == nil || !share.IsAvailable() {
+			c.JSON(200, serializer.Err(serializer.CodeNotFound, "分享不存在或已被取消", nil))
+			c.Abort()
+			return
+		}
+
+		c.Set("user", user)
+		c.Set("share", share)
+		c.Next()
+	}
+}
+
+// ShareCanPreview 检查分享是否可被预览
+func ShareCanPreview() gin.HandlerFunc {
+	return func(c *gin.Context) {
+		if share, ok := c.Get("share"); ok {
+			if share.(*model.Share).PreviewEnabled {
+				c.Next()
+				return
+			}
+			c.JSON(200, serializer.Err(serializer.CodeNoPermissionErr, "此分享无法预览",
+				nil))
+			c.Abort()
+			return
+		}
+		c.Abort()
+	}
+}
+
+// BeforeShareDownload 分享被下载前的检查
+func BeforeShareDownload() gin.HandlerFunc {
+	return func(c *gin.Context) {
+		if shareCtx, ok := c.Get("share"); ok {
+			if userCtx, ok := c.Get("user"); ok {
+				share := shareCtx.(*model.Share)
+				user := userCtx.(*model.User)
+
+				// 检查用户是否可以下载此分享的文件
+				err := share.CanBeDownloadBy(user)
+				if err != nil {
+					c.JSON(200, serializer.Err(serializer.CodeNoPermissionErr, err.Error(),
+						nil))
+					c.Abort()
+					return
+				}
+
+				// 分享是否已解锁
+				if share.Password != "" {
+					sessionKey := fmt.Sprintf("share_unlock_%d", share.ID)
+					unlocked := util.GetSession(c, sessionKey) != nil
+					if !unlocked {
+						c.JSON(200, serializer.Err(serializer.CodeNoPermissionErr,
+							"无权访问此分享", nil))
+						c.Abort()
+						return
+					}
+				}
+
+				// 对积分、下载次数进行更新
+				err = share.DownloadBy(user, c)
+				if err != nil {
+					c.JSON(200, serializer.Err(serializer.CodeNoPermissionErr, err.Error(),
+						nil))
+					c.Abort()
+					return
+				}
+
+				c.Next()
+				return
+			}
+		}
+		c.Abort()
+	}
+}
diff --git a/models/folder.go b/models/folder.go
index efc3dd0..96cf516 100644
--- a/models/folder.go
+++ b/models/folder.go
@@ -13,7 +13,7 @@ type Folder struct {
 	// 表字段
 	gorm.Model
 	Name     string `gorm:"unique_index:idx_only_one_name"`
-	ParentID uint   `gorm:"index:parent_id;unique_index:idx_only_one_name"`
+	ParentID *uint  `gorm:"index:parent_id;unique_index:idx_only_one_name"`
 	OwnerID  uint   `gorm:"index:owner_id"`
 
 	// 数据库忽略字段
@@ -192,7 +192,7 @@ func (folder *Folder) CopyFolderTo(folderID uint, dstFolder *Folder) (size uint6
 		// 顶级目录直接指向新的目的目录
 		if folder.ID == folderID {
 			newID = dstFolder.ID
-		} else if IDCache, ok := newIDCache[folder.ParentID]; ok {
+		} else if IDCache, ok := newIDCache[*folder.ParentID]; ok {
 			newID = IDCache
 		} else {
 			util.Log().Warning("无法取得新的父目录:%d", folder.ParentID)
@@ -202,7 +202,7 @@ func (folder *Folder) CopyFolderTo(folderID uint, dstFolder *Folder) (size uint6
 		// 插入新的目录记录
 		oldID := folder.ID
 		folder.Model = gorm.Model{}
-		folder.ParentID = newID
+		folder.ParentID = &newID
 		if err = DB.Create(&folder).Error; err != nil {
 			return size, err
 		}
@@ -262,6 +262,40 @@ func (folder *Folder) Rename(new string) error {
 	return nil
 }
 
+// CopyChildFrom 将给定文件和拷贝至自身,并更改所有者ID
+func (folder *Folder) CopyChildFrom(folders []Folder, files []File) error {
+	// 开启事务
+	tx := DB.Begin()
+	defer func() {
+		if r := recover(); r != nil {
+			tx.Rollback()
+		}
+	}()
+
+	// 记录文件父目录对应复制的新目录ID
+	var newParent = make(map[uint]uint, len(folders))
+
+	// TODO 复制目录结构
+
+	// 复制子文件
+	for _, file := range files {
+		file.ID = 0
+		file.UserID = folder.OwnerID
+		if newParentID, ok := newParent[file.FolderID]; ok {
+			file.FolderID = newParentID
+		} else {
+			file.FolderID = folder.ID
+		}
+		if err := tx.Create(&file).Error; err != nil {
+			tx.Rollback()
+			return err
+		}
+	}
+
+	return tx.Commit().Error
+
+}
+
 /*
 	实现 FileInfo.FileInfo 接口
 	TODO 测试
diff --git a/models/folder_test.go b/models/folder_test.go
index d41a8ba..1996c6d 100644
--- a/models/folder_test.go
+++ b/models/folder_test.go
@@ -515,7 +515,6 @@ func TestFolder_FileInfoInterface(t *testing.T) {
 			UpdatedAt: time.Date(2019, 12, 21, 12, 40, 0, 0, time.UTC),
 		},
 		Name:     "test_name",
-		ParentID: 0,
 		OwnerID:  0,
 		Position: "/test",
 	}
diff --git a/models/init.go b/models/init.go
index bafc977..c782745 100644
--- a/models/init.go
+++ b/models/init.go
@@ -47,7 +47,7 @@ func Init() {
 
 	// Debug模式下,输出所有 SQL 日志
 	if conf.SystemConfig.Debug {
-		db.LogMode(false)
+		db.LogMode(true)
 	}
 
 	//db.SetLogger(util.Log())
diff --git a/models/user.go b/models/user.go
index 1b0299c..8f0947f 100644
--- a/models/user.go
+++ b/models/user.go
@@ -60,7 +60,7 @@ type UserOption struct {
 // Root 获取用户的根目录
 func (user *User) Root() (*Folder, error) {
 	var folder Folder
-	err := DB.Where("parent_id = 0 AND owner_id = ?", user.ID).First(&folder).Error
+	err := DB.Where("parent_id is NULL AND owner_id = ?", user.ID).First(&folder).Error
 	return &folder, err
 }
 
diff --git a/pkg/filesystem/manage.go b/pkg/filesystem/manage.go
index 9fef4c3..5c3bba6 100644
--- a/pkg/filesystem/manage.go
+++ b/pkg/filesystem/manage.go
@@ -345,7 +345,7 @@ func (fs *FileSystem) CreateDirectory(ctx context.Context, fullPath string) erro
 	// 创建目录
 	newFolder := model.Folder{
 		Name:     dir,
-		ParentID: parent.ID,
+		ParentID: &parent.ID,
 		OwnerID:  fs.User.ID,
 	}
 	_, err := newFolder.Create()
@@ -355,3 +355,34 @@ func (fs *FileSystem) CreateDirectory(ctx context.Context, fullPath string) erro
 	}
 	return nil
 }
+
+// SaveTo 将别人分享的文件转存到目标路径下
+// TODO 测试
+func (fs *FileSystem) SaveTo(ctx context.Context, path string) error {
+	// 获取父目录
+	isExist, folder := fs.IsPathExist(path)
+	if !isExist {
+		return ErrPathNotExist
+	}
+
+	// TODO 列目录
+
+	// 计算要复制的总大小
+	var totalSize uint64
+	for _, file := range fs.FileTarget {
+		totalSize += file.Size
+	}
+
+	// 扣除用户容量
+	if !fs.User.IncreaseStorage(totalSize) {
+		return ErrInsufficientCapacity
+	}
+
+	err := folder.CopyChildFrom(fs.DirTarget, fs.FileTarget)
+	if err != nil {
+		fs.User.DeductionStorage(totalSize)
+		return ErrFileExisted.WithError(err)
+	}
+
+	return nil
+}
diff --git a/routers/controllers/share.go b/routers/controllers/share.go
index 25199f0..eb326a2 100644
--- a/routers/controllers/share.go
+++ b/routers/controllers/share.go
@@ -90,3 +90,14 @@ func GetShareDocPreview(c *gin.Context) {
 		c.JSON(200, ErrorResponse(err))
 	}
 }
+
+// SaveShare 转存他人分享
+func SaveShare(c *gin.Context) {
+	var service share.SingleFileService
+	if err := c.ShouldBindJSON(&service); err == nil {
+		res := service.SaveToMyFile(c)
+		c.JSON(200, res)
+	} else {
+		c.JSON(200, ErrorResponse(err))
+	}
+}
diff --git a/routers/router.go b/routers/router.go
index 00668f7..db4c59f 100644
--- a/routers/router.go
+++ b/routers/router.go
@@ -168,18 +168,31 @@ func InitMasterRouter() *gin.Engine {
 		}
 
 		// 分享相关
-		share := v3.Group("share")
+		share := v3.Group("share", middleware.ShareAvailable())
 		{
 			// 获取分享
 			share.GET("info/:id", controllers.GetShare)
 			// 创建文件下载会话
-			share.POST("download/:id", controllers.GetShareDownload)
+			share.POST("download/:id",
+				middleware.BeforeShareDownload(),
+				controllers.GetShareDownload,
+			)
 			// 预览分享文件
-			share.GET("preview/:id", controllers.PreviewShare)
+			share.GET("preview/:id",
+				middleware.ShareCanPreview(),
+				middleware.BeforeShareDownload(),
+				controllers.PreviewShare,
+			)
 			// 取得Office文档预览地址
-			share.GET("doc/:id", controllers.GetShareDocPreview)
+			share.GET("doc/:id", middleware.ShareCanPreview(),
+				middleware.BeforeShareDownload(),
+				controllers.GetShareDocPreview,
+			)
 			// 获取文本文件内容
-			share.GET("content/:id", controllers.PreviewShareText)
+			share.GET("content/:id",
+				middleware.BeforeShareDownload(),
+				controllers.PreviewShareText,
+			)
 		}
 
 		// 需要登录保护的
@@ -256,6 +269,12 @@ func InitMasterRouter() *gin.Engine {
 			{
 				// 创建新分享
 				share.POST("", controllers.CreateShare)
+				// 转存他人分享
+				share.POST("save/:id",
+					middleware.ShareAvailable(),
+					middleware.BeforeShareDownload(),
+					controllers.SaveShare,
+				)
 			}
 
 		}
diff --git a/service/share/manage.go b/service/share/manage.go
index 2247ca8..65c01a2 100644
--- a/service/share/manage.go
+++ b/service/share/manage.go
@@ -22,7 +22,8 @@ type ShareCreateService struct {
 
 // Create 创建新分享
 func (service *ShareCreateService) Create(c *gin.Context) serializer.Response {
-	user := currentUser(c)
+	userCtx, _ := c.Get("user")
+	user := userCtx.(*model.User)
 
 	// 是否拥有权限
 	if !user.Group.ShareEnabled {
@@ -82,13 +83,3 @@ func (service *ShareCreateService) Create(c *gin.Context) serializer.Response {
 	}
 
 }
-
-func currentUser(c *gin.Context) *model.User {
-	var user *model.User
-	if userCtx, ok := c.Get("user"); ok {
-		user = userCtx.(*model.User)
-	} else {
-		user = model.NewAnonymousUser()
-	}
-	return user
-}
diff --git a/service/share/visit.go b/service/share/visit.go
index 9278612..ee941a5 100644
--- a/service/share/visit.go
+++ b/service/share/visit.go
@@ -2,7 +2,6 @@ package share
 
 import (
 	"context"
-	"errors"
 	"fmt"
 	model "github.com/HFO4/cloudreve/models"
 	"github.com/HFO4/cloudreve/pkg/filesystem"
@@ -25,11 +24,10 @@ type SingleFileService struct {
 
 // Get 获取分享内容
 func (service *ShareGetService) Get(c *gin.Context) serializer.Response {
-	user := currentUser(c)
-	share := model.GetShareByHashID(c.Param("id"))
-	if share == nil || !share.IsAvailable() {
-		return serializer.Err(serializer.CodeNotFound, "分享不存在或已被取消", nil)
-	}
+	shareCtx, _ := c.Get("share")
+	share := shareCtx.(*model.Share)
+	userCtx, _ := c.Get("user")
+	user := userCtx.(*model.User)
 
 	// 是否已解锁
 	unlocked := true
@@ -62,17 +60,10 @@ func (service *ShareGetService) Get(c *gin.Context) serializer.Response {
 
 // CreateDownloadSession 创建下载会话
 func (service *SingleFileService) CreateDownloadSession(c *gin.Context) serializer.Response {
-	user := currentUser(c)
-	share := model.GetShareByHashID(c.Param("id"))
-	if share == nil || !share.IsAvailable() {
-		return serializer.Err(serializer.CodeNotFound, "分享不存在或已被取消", nil)
-	}
-
-	// 检查用户是否可以下载此分享的文件
-	err := CheckBeforeGetShare(share, user, c)
-	if err != nil {
-		return serializer.Err(serializer.CodeNoPermissionErr, err.Error(), nil)
-	}
+	shareCtx, _ := c.Get("share")
+	share := shareCtx.(*model.Share)
+	userCtx, _ := c.Get("user")
+	user := userCtx.(*model.User)
 
 	// 创建文件系统
 	fs, err := filesystem.NewFileSystem(user)
@@ -102,21 +93,8 @@ func (service *SingleFileService) CreateDownloadSession(c *gin.Context) serializ
 // PreviewContent 预览文件,需要登录会话, isText - 是否为文本文件,文本文件会
 // 强制经由服务端中转
 func (service *SingleFileService) PreviewContent(ctx context.Context, c *gin.Context, isText bool) serializer.Response {
-	user := currentUser(c)
-	share := model.GetShareByHashID(c.Param("id"))
-	if share == nil || !share.IsAvailable() {
-		return serializer.Err(serializer.CodeNotFound, "分享不存在或已被取消", nil)
-	}
-
-	if !share.PreviewEnabled {
-		return serializer.Err(serializer.CodeNoPermissionErr, "此分享无法预览", nil)
-	}
-
-	// 检查用户是否可以下载此分享的文件
-	err := CheckBeforeGetShare(share, user, c)
-	if err != nil {
-		return serializer.Err(serializer.CodeNoPermissionErr, err.Error(), nil)
-	}
+	shareCtx, _ := c.Get("share")
+	share := shareCtx.(*model.Share)
 
 	// 用于调下层service
 	ctx = context.WithValue(ctx, fsctx.FileModelCtx, share.GetSource())
@@ -129,21 +107,8 @@ func (service *SingleFileService) PreviewContent(ctx context.Context, c *gin.Con
 
 // CreateDocPreviewSession 创建Office预览会话,返回预览地址
 func (service *SingleFileService) CreateDocPreviewSession(c *gin.Context) serializer.Response {
-	user := currentUser(c)
-	share := model.GetShareByHashID(c.Param("id"))
-	if share == nil || !share.IsAvailable() {
-		return serializer.Err(serializer.CodeNotFound, "分享不存在或已被取消", nil)
-	}
-
-	if !share.PreviewEnabled {
-		return serializer.Err(serializer.CodeNoPermissionErr, "此分享无法预览", nil)
-	}
-
-	// 检查用户是否可以下载此分享的文件
-	err := CheckBeforeGetShare(share, user, c)
-	if err != nil {
-		return serializer.Err(serializer.CodeNoPermissionErr, err.Error(), nil)
-	}
+	shareCtx, _ := c.Get("share")
+	share := shareCtx.(*model.Share)
 
 	// 用于调下层service
 	ctx := context.WithValue(context.Background(), fsctx.FileModelCtx, share.GetSource())
@@ -154,28 +119,35 @@ func (service *SingleFileService) CreateDocPreviewSession(c *gin.Context) serial
 	return subService.CreateDocPreviewSession(ctx, c)
 }
 
-// CheckBeforeGetShare 获取分享内容/下载前进行的一系列检查
-func CheckBeforeGetShare(share *model.Share, user *model.User, c *gin.Context) error {
-	// 检查用户是否可以下载此分享的文件
-	err := share.CanBeDownloadBy(user)
+// SaveToMyFile 将此分享转存到自己的网盘
+func (service *SingleFileService) SaveToMyFile(c *gin.Context) serializer.Response {
+	shareCtx, _ := c.Get("share")
+	share := shareCtx.(*model.Share)
+	userCtx, _ := c.Get("user")
+	user := userCtx.(*model.User)
+
+	// 不能转存自己的文件
+	if share.UserID == user.ID {
+		return serializer.Err(serializer.CodePolicyNotAllowed, "不能转存自己的分享", nil)
+	}
+
+	// 创建文件系统
+	fs, err := filesystem.NewFileSystem(user)
 	if err != nil {
-		return err
+		return serializer.Err(serializer.CodePolicyNotAllowed, err.Error(), err)
 	}
+	defer fs.Recycle()
 
-	// 分享是否已解锁
-	if share.Password != "" {
-		sessionKey := fmt.Sprintf("share_unlock_%d", share.ID)
-		unlocked := util.GetSession(c, sessionKey) != nil
-		if !unlocked {
-			return errors.New("无权访问此分享")
-		}
-	}
-
-	// 对积分、下载次数进行更新
-	err = share.DownloadBy(user, c)
+	// 重设文件系统处理目标为源文件
+	err = fs.SetTargetByInterface(share.GetSource())
 	if err != nil {
-		return err
+		return serializer.Err(serializer.CodePolicyNotAllowed, "源文件不存在", err)
 	}
 
-	return nil
+	err = fs.SaveTo(context.Background(), service.Path)
+	if err != nil {
+		return serializer.Err(serializer.CodeNotSet, err.Error(), err)
+	}
+
+	return serializer.Response{}
 }