Feat: validate / cancel task while downloading file in aria2
This commit is contained in:
parent
8c7e3883ee
commit
3ed84ad5ec
7 changed files with 119 additions and 4 deletions
|
@ -22,6 +22,9 @@ type Download struct {
|
||||||
Dst string `gorm:"type:text"` // 用户文件系统存储父目录路径
|
Dst string `gorm:"type:text"` // 用户文件系统存储父目录路径
|
||||||
UserID uint // 发起者UID
|
UserID uint // 发起者UID
|
||||||
TaskID uint // 对应的转存任务ID
|
TaskID uint // 对应的转存任务ID
|
||||||
|
|
||||||
|
// 关联模型
|
||||||
|
User *User `gorm:"PRELOAD:false,association_autoupdate:false"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create 创建离线下载记录
|
// Create 创建离线下载记录
|
||||||
|
@ -48,3 +51,13 @@ func GetDownloadsByStatus(status ...int) []Download {
|
||||||
DB.Where("status in (?)", status).Find(&tasks)
|
DB.Where("status in (?)", status).Find(&tasks)
|
||||||
return tasks
|
return tasks
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetOwner 获取下载任务所属用户
|
||||||
|
func (task *Download) GetOwner() *User {
|
||||||
|
if task.User == nil {
|
||||||
|
if user, err := GetUserByID(task.UserID); err == nil {
|
||||||
|
return &user
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return task.User
|
||||||
|
}
|
||||||
|
|
|
@ -12,7 +12,7 @@ type Task struct {
|
||||||
Type int // 任务类型
|
Type int // 任务类型
|
||||||
UserID uint // 发起者UID,0表示为系统发起
|
UserID uint // 发起者UID,0表示为系统发起
|
||||||
Progress int // 进度
|
Progress int // 进度
|
||||||
Error string // 错误信息
|
Error string `gorm:"type:text"` // 错误信息
|
||||||
Props string `gorm:"type:text"` // 任务属性
|
Props string `gorm:"type:text"` // 任务属性
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -1,9 +1,13 @@
|
||||||
package aria2
|
package aria2
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
model "github.com/HFO4/cloudreve/models"
|
model "github.com/HFO4/cloudreve/models"
|
||||||
|
"github.com/HFO4/cloudreve/pkg/filesystem"
|
||||||
|
"github.com/HFO4/cloudreve/pkg/filesystem/driver/local"
|
||||||
|
"github.com/HFO4/cloudreve/pkg/filesystem/fsctx"
|
||||||
"github.com/HFO4/cloudreve/pkg/task"
|
"github.com/HFO4/cloudreve/pkg/task"
|
||||||
"github.com/HFO4/cloudreve/pkg/util"
|
"github.com/HFO4/cloudreve/pkg/util"
|
||||||
"github.com/zyxar/argo/rpc"
|
"github.com/zyxar/argo/rpc"
|
||||||
|
@ -71,9 +75,18 @@ func (monitor *Monitor) Update() bool {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 磁力链下载需要跟随
|
||||||
|
if len(status.FollowedBy) > 0 {
|
||||||
|
util.Log().Debug("离线下载[%s]重定向至[%s]", monitor.Task.GID, status.FollowedBy[0])
|
||||||
|
monitor.Task.GID = status.FollowedBy[0]
|
||||||
|
monitor.Task.Save()
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
// 更新任务信息
|
// 更新任务信息
|
||||||
if err := monitor.UpdateTaskInfo(status); err != nil {
|
if err := monitor.UpdateTaskInfo(status); err != nil {
|
||||||
util.Log().Warning("无法更新下载任务[%s]的任务信息[%s],", monitor.Task.GID, err)
|
util.Log().Warning("无法更新下载任务[%s]的任务信息[%s],", monitor.Task.GID, err)
|
||||||
|
monitor.setErrorStatus(err)
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -96,6 +109,9 @@ func (monitor *Monitor) Update() bool {
|
||||||
|
|
||||||
// UpdateTaskInfo 更新数据库中的任务信息
|
// UpdateTaskInfo 更新数据库中的任务信息
|
||||||
func (monitor *Monitor) UpdateTaskInfo(status rpc.StatusInfo) error {
|
func (monitor *Monitor) UpdateTaskInfo(status rpc.StatusInfo) error {
|
||||||
|
originSize := monitor.Task.TotalSize
|
||||||
|
originPath := monitor.Task.Path
|
||||||
|
|
||||||
monitor.Task.GID = status.Gid
|
monitor.Task.GID = status.Gid
|
||||||
monitor.Task.Status = getStatus(status.Status)
|
monitor.Task.Status = getStatus(status.Status)
|
||||||
|
|
||||||
|
@ -126,7 +142,68 @@ func (monitor *Monitor) UpdateTaskInfo(status rpc.StatusInfo) error {
|
||||||
attrs, _ := json.Marshal(status)
|
attrs, _ := json.Marshal(status)
|
||||||
monitor.Task.Attrs = string(attrs)
|
monitor.Task.Attrs = string(attrs)
|
||||||
|
|
||||||
return monitor.Task.Save()
|
if err := monitor.Task.Save(); err != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if originSize != monitor.Task.TotalSize || originPath != monitor.Task.Path {
|
||||||
|
// 大小、文件名更新后,对文件限制等进行校验
|
||||||
|
if err := monitor.ValidateFile(); err != nil {
|
||||||
|
// 验证失败时取消任务
|
||||||
|
monitor.Cancel()
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Cancel 取消上传并尝试删除临时文件
|
||||||
|
func (monitor *Monitor) Cancel() {
|
||||||
|
if err := Instance.Cancel(monitor.Task); err != nil {
|
||||||
|
util.Log().Warning("无法取消离线下载任务[%s], %s", monitor.Task.GID, err)
|
||||||
|
}
|
||||||
|
util.Log().Debug("离线下载任务[%s]已取消,1 分钟后删除临时文件", monitor.Task.GID)
|
||||||
|
go func(monitor *Monitor) {
|
||||||
|
select {
|
||||||
|
case <-time.After(time.Duration(60) * time.Second):
|
||||||
|
monitor.RemoveTempFolder()
|
||||||
|
}
|
||||||
|
}(monitor)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ValidateFile 上传过程中校验文件大小、文件名
|
||||||
|
func (monitor *Monitor) ValidateFile() error {
|
||||||
|
// 找到任务创建者
|
||||||
|
user := monitor.Task.GetOwner()
|
||||||
|
if user == nil {
|
||||||
|
return ErrUserNotFound
|
||||||
|
}
|
||||||
|
|
||||||
|
// 创建文件系统
|
||||||
|
fs, err := filesystem.NewFileSystem(user)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer fs.Recycle()
|
||||||
|
|
||||||
|
// 创建上下文环境
|
||||||
|
ctx := context.WithValue(context.Background(), fsctx.FileHeaderCtx, local.FileStream{
|
||||||
|
Size: monitor.Task.TotalSize,
|
||||||
|
Name: filepath.Base(monitor.Task.Path),
|
||||||
|
})
|
||||||
|
|
||||||
|
// 验证文件
|
||||||
|
if err := filesystem.HookValidateFile(ctx, fs); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// 验证用户容量
|
||||||
|
if err := filesystem.HookValidateCapacityWithoutIncrease(ctx, fs); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Error 任务下载出错处理,返回是否中断监控
|
// Error 任务下载出错处理,返回是否中断监控
|
||||||
|
|
|
@ -20,6 +20,8 @@ type Aria2 interface {
|
||||||
CreateTask(task *model.Download) error
|
CreateTask(task *model.Download) error
|
||||||
// 返回状态信息
|
// 返回状态信息
|
||||||
Status(task *model.Download) (rpc.StatusInfo, error)
|
Status(task *model.Download) (rpc.StatusInfo, error)
|
||||||
|
// 取消任务
|
||||||
|
Cancel(task *model.Download) error
|
||||||
}
|
}
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
@ -48,7 +50,8 @@ const (
|
||||||
|
|
||||||
var (
|
var (
|
||||||
// ErrNotEnabled 功能未开启错误
|
// ErrNotEnabled 功能未开启错误
|
||||||
ErrNotEnabled = serializer.NewError(serializer.CodeNoPermissionErr, "离线下载功能未开启", nil)
|
ErrNotEnabled = serializer.NewError(serializer.CodeNoPermissionErr, "离线下载功能未开启", nil)
|
||||||
|
ErrUserNotFound = serializer.NewError(serializer.CodeNotFound, "无法找到任务创建者", nil)
|
||||||
)
|
)
|
||||||
|
|
||||||
// DummyAria2 未开启Aria2功能时使用的默认处理器
|
// DummyAria2 未开启Aria2功能时使用的默认处理器
|
||||||
|
@ -65,6 +68,11 @@ func (instance *DummyAria2) Status(task *model.Download) (rpc.StatusInfo, error)
|
||||||
return rpc.StatusInfo{}, ErrNotEnabled
|
return rpc.StatusInfo{}, ErrNotEnabled
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Cancel 返回未开启错误
|
||||||
|
func (instance *DummyAria2) Cancel(task *model.Download) error {
|
||||||
|
return ErrNotEnabled
|
||||||
|
}
|
||||||
|
|
||||||
// Init 初始化
|
// Init 初始化
|
||||||
func Init() {
|
func Init() {
|
||||||
options := model.GetSettingByNames("aria2_rpcurl", "aria2_token", "aria2_options")
|
options := model.GetSettingByNames("aria2_rpcurl", "aria2_token", "aria2_options")
|
||||||
|
|
|
@ -40,6 +40,12 @@ func (client *RPCService) Status(task *model.Download) (rpc.StatusInfo, error) {
|
||||||
return client.caller.TellStatus(task.GID)
|
return client.caller.TellStatus(task.GID)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Cancel 取消下载
|
||||||
|
func (client *RPCService) Cancel(task *model.Download) error {
|
||||||
|
_, err := client.caller.Remove(task.GID)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
// CreateTask 创建新任务
|
// CreateTask 创建新任务
|
||||||
func (client *RPCService) CreateTask(task *model.Download) error {
|
func (client *RPCService) CreateTask(task *model.Download) error {
|
||||||
// 生成存储路径
|
// 生成存储路径
|
||||||
|
|
|
@ -153,7 +153,7 @@ func (client *Client) UploadChunk(ctx context.Context, uploadURL string, chunk *
|
||||||
// 如果重试次数小于限制,5秒后重试
|
// 如果重试次数小于限制,5秒后重试
|
||||||
if chunk.Retried < model.GetIntSetting("onedrive_chunk_retries", 1) {
|
if chunk.Retried < model.GetIntSetting("onedrive_chunk_retries", 1) {
|
||||||
chunk.Retried++
|
chunk.Retried++
|
||||||
util.Log().Debug("分片偏移%d上传失败,5秒钟后重试", chunk.Offset)
|
util.Log().Debug("分片偏移%d上传失败[%s],5秒钟后重试", chunk.Offset, err)
|
||||||
time.Sleep(time.Duration(5) * time.Second)
|
time.Sleep(time.Duration(5) * time.Second)
|
||||||
return client.UploadChunk(ctx, uploadURL, chunk)
|
return client.UploadChunk(ctx, uploadURL, chunk)
|
||||||
}
|
}
|
||||||
|
@ -518,6 +518,7 @@ func (client *Client) request(ctx context.Context, method string, url string, bo
|
||||||
if res.Response.StatusCode < 200 || res.Response.StatusCode >= 300 {
|
if res.Response.StatusCode < 200 || res.Response.StatusCode >= 300 {
|
||||||
decodeErr = json.Unmarshal([]byte(respBody), &errResp)
|
decodeErr = json.Unmarshal([]byte(respBody), &errResp)
|
||||||
if decodeErr != nil {
|
if decodeErr != nil {
|
||||||
|
util.Log().Debug("Onedrive返回未知响应[%s]", respBody)
|
||||||
return "", sysError(decodeErr)
|
return "", sysError(decodeErr)
|
||||||
}
|
}
|
||||||
return "", &errResp
|
return "", &errResp
|
||||||
|
|
|
@ -128,6 +128,16 @@ func HookValidateCapacity(ctx context.Context, fs *FileSystem) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// HookValidateCapacityWithoutIncrease 验证用户容量,不扣除
|
||||||
|
func HookValidateCapacityWithoutIncrease(ctx context.Context, fs *FileSystem) error {
|
||||||
|
file := ctx.Value(fsctx.FileHeaderCtx).(FileHeader)
|
||||||
|
// 验证并扣除容量
|
||||||
|
if fs.User.GetRemainingCapacity() < file.GetSize() {
|
||||||
|
return ErrInsufficientCapacity
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// HookChangeCapacity 根据原有文件和新文件的大小更新用户容量
|
// HookChangeCapacity 根据原有文件和新文件的大小更新用户容量
|
||||||
func HookChangeCapacity(ctx context.Context, fs *FileSystem) error {
|
func HookChangeCapacity(ctx context.Context, fs *FileSystem) error {
|
||||||
newFile := ctx.Value(fsctx.FileHeaderCtx).(FileHeader)
|
newFile := ctx.Value(fsctx.FileHeaderCtx).(FileHeader)
|
||||||
|
|
Loading…
Add table
Reference in a new issue