Cloudreve/models/share.go

247 lines
6.3 KiB
Go
Raw Normal View History

2020-01-26 00:07:05 -05:00
package model
import (
"errors"
"fmt"
2020-02-13 00:17:09 -05:00
"strings"
2020-01-26 00:07:05 -05:00
"time"
"github.com/cloudreve/Cloudreve/v3/pkg/cache"
"github.com/cloudreve/Cloudreve/v3/pkg/hashid"
"github.com/cloudreve/Cloudreve/v3/pkg/util"
"github.com/gin-gonic/gin"
"github.com/jinzhu/gorm"
2020-01-26 00:07:05 -05:00
)
// Share 分享模型
type Share struct {
gorm.Model
Password string // 分享密码,空值为非加密分享
IsDir bool // 原始资源是否为目录
UserID uint // 创建用户ID
SourceID uint // 原始资源ID
Views int // 浏览数
Downloads int // 下载数
RemainDownloads int // 剩余下载配额,负值标识无限制
Expires *time.Time // 过期时间,空值表示无过期时间
PreviewEnabled bool // 是否允许直接预览
2020-02-12 22:53:24 -05:00
SourceName string `gorm:"index:source"` // 用于搜索的字段
2020-01-27 00:34:39 -05:00
// 数据库忽略字段
User User `gorm:"PRELOAD:false,association_autoupdate:false"`
File File `gorm:"PRELOAD:false,association_autoupdate:false"`
Folder Folder `gorm:"PRELOAD:false,association_autoupdate:false"`
2020-01-26 00:07:05 -05:00
}
// Create inserts the share into the database and returns its newly assigned ID.
// If the insert fails, a warning is logged and (0, err) is returned.
func (share *Share) Create() (uint, error) {
	err := DB.Create(share).Error
	if err == nil {
		return share.ID, nil
	}
	util.Log().Warning("无法插入数据库记录, %s", err)
	return 0, err
}
2020-01-26 01:57:07 -05:00
// GetShareByHashID decodes hashID into a share ID and loads that share.
// It returns nil if the hash cannot be decoded or no matching record is found.
func GetShareByHashID(hashID string) *Share {
	id, err := hashid.DecodeHashID(hashID, hashid.ShareID)
	if err != nil {
		return nil
	}
	share := &Share{}
	if DB.First(share, id).Error != nil {
		return nil
	}
	return share
}
2020-01-27 00:34:39 -05:00
// IsAvailable 返回此分享是否可用(是否过期)
func (share *Share) IsAvailable() bool {
if share.RemainDownloads == 0 {
return false
}
if share.Expires != nil && time.Now().After(*share.Expires) {
return false
}
2020-02-17 01:04:48 -05:00
// 检查创建者状态
if share.Creator().Status != Active {
return false
}
2020-01-27 00:34:39 -05:00
// 检查源对象是否存在
var sourceID uint
if share.IsDir {
2020-02-01 00:14:50 -05:00
folder := share.SourceFolder()
2020-01-27 00:34:39 -05:00
sourceID = folder.ID
} else {
2020-02-01 00:14:50 -05:00
file := share.SourceFile()
2020-01-27 00:34:39 -05:00
sourceID = file.ID
}
if sourceID == 0 {
2020-02-01 00:14:50 -05:00
// TODO 是否要在这里删除这个无效分享?
2020-01-27 00:34:39 -05:00
return false
}
return true
}
2020-02-01 00:14:50 -05:00
// Creator 获取分享的创建者
func (share *Share) Creator() *User {
2020-01-27 00:34:39 -05:00
if share.User.ID == 0 {
share.User, _ = GetUserByID(share.UserID)
}
return &share.User
}
2020-02-01 00:14:50 -05:00
// Source 返回源对象
func (share *Share) Source() interface{} {
if share.IsDir {
2020-02-01 00:14:50 -05:00
return share.SourceFolder()
}
2020-02-01 00:14:50 -05:00
return share.SourceFile()
}
2020-02-01 00:14:50 -05:00
// SourceFolder 获取源目录
func (share *Share) SourceFolder() *Folder {
2020-01-27 00:34:39 -05:00
if share.Folder.ID == 0 {
folders, _ := GetFoldersByIDs([]uint{share.SourceID}, share.UserID)
if len(folders) > 0 {
share.Folder = folders[0]
}
}
return &share.Folder
}
2020-02-01 00:14:50 -05:00
// SourceFile 获取源文件
func (share *Share) SourceFile() *File {
2020-01-27 00:34:39 -05:00
if share.File.ID == 0 {
files, _ := GetFilesByIDs([]uint{share.SourceID}, share.UserID)
if len(files) > 0 {
share.File = files[0]
}
}
return &share.File
}
// CanBeDownloadBy 返回此分享是否可以被给定用户下载
func (share *Share) CanBeDownloadBy(user *User) error {
// 用户组权限
2020-02-02 01:40:07 -05:00
if !user.Group.OptionsSerialized.ShareDownload {
if user.IsAnonymous() {
return errors.New("未登录用户无法下载")
}
return errors.New("您当前的用户组无权下载")
}
return nil
}
// WasDownloadedBy reports whether user has already downloaded this share.
// Anonymous users are tracked in the request session; other users are
// tracked in the cache. Both use the key "share_<shareID>_<userID>".
func (share *Share) WasDownloadedBy(user *User, c *gin.Context) (exist bool) {
	key := fmt.Sprintf("share_%d_%d", share.ID, user.ID)
	if !user.IsAnonymous() {
		_, exist = cache.Get(key)
		return exist
	}
	return util.GetSession(c, key) != nil
}
2020-03-11 02:22:21 -05:00
// DownloadBy counts a download of the share by user, at most once per user
// as determined by WasDownloadedBy.
//
// On the first download it calls Downloaded and then records the download
// under the key "share_<shareID>_<userID>":
//   - anonymous users: in the request session, not in the cache;
//   - other users: in the cache, with a TTL taken from the
//     "share_download_session_timeout" setting (default 2073600).
//
// The current implementation always returns nil.
func (share *Share) DownloadBy(user *User, c *gin.Context) error {
	if share.WasDownloadedBy(user, c) {
		return nil
	}

	share.Downloaded()

	key := fmt.Sprintf("share_%d_%d", share.ID, user.ID)
	if user.IsAnonymous() {
		util.SetSession(c, map[string]interface{}{key: true})
	} else {
		cache.Set(key, true, GetIntSetting("share_download_session_timeout", 2073600))
	}
	return nil
}
// Viewed records one view of the share. The in-memory counter is bumped and
// the database column is incremented with an SQL expression
// ("views + 1"), so the stored value is computed from the row itself rather
// than written back from memory.
func (share *Share) Viewed() {
	share.Views++
	increment := gorm.Expr("views + ?", 1)
	DB.Model(share).UpdateColumn("views", increment)
}
// Downloaded records one download of the share.
//
// The in-memory counters are updated for the caller's benefit. The database
// itself is updated with SQL expressions evaluated against the stored row, as
// Viewed does, instead of writing back the in-memory values.
//
// Why: with absolute values, two concurrent downloads of the same share that
// started from the same loaded state overwrite each other. One increment of
// downloads and one decrement of remain_downloads is lost, which can let a
// quota-limited share be downloaded more times than allowed.
//
// remain_downloads is only decremented while it is positive, mirroring the
// in-memory rule (a negative quota means unlimited).
func (share *Share) Downloaded() {
	share.Downloads++
	if share.RemainDownloads > 0 {
		share.RemainDownloads--
	}
	DB.Model(share).Updates(map[string]interface{}{
		"downloads":        gorm.Expr("downloads + ?", 1),
		"remain_downloads": gorm.Expr("CASE WHEN remain_downloads > 0 THEN remain_downloads - 1 ELSE remain_downloads END"),
	})
}
2020-02-12 22:53:24 -05:00
// Update writes the given column/value pairs to this share's database record
// and returns any database error.
func (share *Share) Update(props map[string]interface{}) error {
	result := DB.Model(share).Updates(props)
	return result.Error
}
// Delete removes this share's record from the database and returns any
// database error.
func (share *Share) Delete() error {
	result := DB.Model(share).Delete(share)
	return result.Error
}
// DeleteShareBySourceIDs deletes every share whose source_id is in sources
// and whose is_dir flag equals isDir. It deletes share records, not the
// shared files or folders themselves.
func DeleteShareBySourceIDs(sources []uint, isDir bool) error {
	condition := DB.Where("source_id in (?) and is_dir = ?", sources, isDir)
	return condition.Delete(&Share{}).Error
}
2020-02-12 22:53:24 -05:00
// ListShares 列出UID下的分享
func ListShares(uid uint, page, pageSize int, order string, publicOnly bool) ([]Share, int) {
var (
shares []Share
total int
)
dbChain := DB
dbChain = dbChain.Where("user_id = ?", uid)
if publicOnly {
2020-02-18 02:34:40 -05:00
dbChain = dbChain.Where("password = ?", "")
2020-02-12 22:53:24 -05:00
}
// 计算总数用于分页
dbChain.Model(&Share{}).Count(&total)
// 查询记录
dbChain.Limit(pageSize).Offset((page - 1) * pageSize).Order(order).Find(&shares)
return shares, total
}
// SearchShares 根据关键字搜索分享
func SearchShares(page, pageSize int, order, keywords string) ([]Share, int) {
var (
shares []Share
total int
)
2020-02-13 00:17:09 -05:00
keywordList := strings.Split(keywords, " ")
availableList := make([]string, 0, len(keywordList))
for i := 0; i < len(keywordList); i++ {
if len(keywordList[i]) > 0 {
availableList = append(availableList, keywordList[i])
}
}
if len(availableList) == 0 {
return shares, 0
}
2020-02-12 22:53:24 -05:00
dbChain := DB
2020-02-13 00:17:09 -05:00
dbChain = dbChain.Where("password = ? and remain_downloads <> 0 and (expires is NULL or expires > ?) and source_name like ?", "", time.Now(), "%"+strings.Join(availableList, "%")+"%")
2020-02-12 22:53:24 -05:00
// 计算总数用于分页
dbChain.Model(&Share{}).Count(&total)
// 查询记录
dbChain.Limit(pageSize).Offset((page - 1) * pageSize).Order(order).Find(&shares)
return shares, total
}