Feat(remotearia2): add task

This commit is contained in:
Cian John 2021-03-20 16:30:48 +08:00
parent 6fb419d998
commit fa56d81381
13 changed files with 294 additions and 31 deletions

View file

@ -28,6 +28,11 @@ func Init(path string) {
email.Init()
crontab.Init()
InitStatic()
} else {
if conf.SlaveConfig.Aria2 {
model.Init()
aria2.Init(false)
}
}
auth.Init()
}

View file

@ -100,6 +100,13 @@ func GetDownloadByGid(gid string, uid uint) (*Download, error) {
return download, result.Error
}
// GetDownloadById 根据ID查找下载
func GetDownloadById(id uint) (*Download, error) {
var download Download
result := DB.First(&download, id)
return &download, result.Error
}
// GetOwner 获取下载任务所属用户
func (task *Download) GetOwner() *User {
if task.User == nil {

View file

@ -136,6 +136,8 @@ Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; verti
{Name: "aria2_temp_path", Value: ``, Type: "aria2"},
{Name: "aria2_options", Value: `{}`, Type: "aria2"},
{Name: "aria2_interval", Value: `60`, Type: "aria2"},
{Name: "aria2_remote_enabled", Value: `0`, Type: "aria2"},
{Name: "aria2_remote_id", Value: `0`, Type: "aria2"},
{Name: "max_worker_num", Value: `10`, Type: "task"},
{Name: "max_parallel_transfer", Value: `4`, Type: "task"},
{Name: "secret_key", Value: util.RandStringRunes(256), Type: "auth"},

View file

@ -2,6 +2,7 @@ package aria2
import (
"encoding/json"
"github.com/cloudreve/Cloudreve/v3/pkg/conf"
"net/url"
"sync"
@ -89,6 +90,21 @@ func (instance *DummyAria2) Select(task *model.Download, files []int) error {
// Init 初始化
func Init(isReload bool) {
if conf.SystemConfig.Mode == "master" {
MasterInit(isReload)
} else {
SlaveInit(isReload)
}
}
// SlaveInit 从机初始化
func SlaveInit(isReload bool) {
if !model.IsTrueVal(model.GetSettingByName("aria2_remote_enabled")) {
return
}
if conf.SlaveConfig.SlaveId == 0 || model.GetIntSetting("aria2_remote_id", 0) != int(conf.SlaveConfig.SlaveId) {
return
}
Lock.Lock()
defer Lock.Unlock()
@ -136,16 +152,82 @@ func Init(isReload bool) {
Instance = client
if !isReload {
// 从数据库中读取未完成任务,创建监控
unfinished := model.GetDownloadsByStatus(Ready, Paused, Downloading)
// monitor
}
for i := 0; i < len(unfinished); i++ {
// 创建任务监控
NewMonitor(&unfinished[i])
// MasterInit 主机初始化
func MasterInit(isReload bool) {
Lock.Lock()
defer Lock.Unlock()
if !model.IsTrueVal(model.GetSettingByName("aria2_remote_enabled")) {
// 关闭上个初始连接
if previousClient, ok := Instance.(*RPCService); ok {
if previousClient.Caller != nil {
util.Log().Debug("关闭上个 aria2 连接")
previousClient.Caller.Close()
}
}
}
options := model.GetSettingByNames("aria2_rpcurl", "aria2_token", "aria2_options")
timeout := model.GetIntSetting("aria2_call_timeout", 5)
if options["aria2_rpcurl"] == "" {
Instance = &DummyAria2{}
return
}
util.Log().Info("初始化 aria2 RPC 服务[%s]", options["aria2_rpcurl"])
client := &RPCService{}
// 解析RPC服务地址
server, err := url.Parse(options["aria2_rpcurl"])
if err != nil {
util.Log().Warning("无法解析 aria2 RPC 服务地址,%s", err)
Instance = &DummyAria2{}
return
}
server.Path = "/jsonrpc"
// 加载自定义下载配置
var globalOptions map[string]interface{}
err = json.Unmarshal([]byte(options["aria2_options"]), &globalOptions)
if err != nil {
util.Log().Warning("无法解析 aria2 全局配置,%s", err)
Instance = &DummyAria2{}
return
}
if err := client.Init(server.String(), options["aria2_token"], timeout, globalOptions); err != nil {
util.Log().Warning("初始化 aria2 RPC 服务失败,%s", err)
Instance = &DummyAria2{}
return
}
Instance = client
if !isReload {
// 从数据库中读取未完成任务,创建监控
unfinished := model.GetDownloadsByStatus(Ready, Paused, Downloading)
for i := 0; i < len(unfinished); i++ {
// 创建任务监控
NewMonitor(&unfinished[i])
}
}
} else {
util.Log().Info("初始化 从机 aria2 RPC 服务")
remote, err := model.GetPolicyByID(uint(model.GetIntSetting("aria2_remote_id", 0)))
if err != nil {
util.Log().Warning("初始化 从机 aria2 RPC 服务失败,%s", err)
Instance = &DummyAria2{}
return
}
client := &RemoteService{}
client.Init(&remote)
Instance = client
}
}
// getStatus 将给定的状态字符串转换为状态标识数字

View file

@ -111,7 +111,7 @@ func (client *RPCService) CreateTask(task *model.Download, groupOptions map[stri
// 保存到数据库
task.GID = gid
_, err = task.Create()
err = task.Save()
if err != nil {
return err
}

103
pkg/aria2/remote_caller.go Normal file
View file

@ -0,0 +1,103 @@
package aria2
import (
"encoding/json"
"errors"
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc"
"github.com/cloudreve/Cloudreve/v3/pkg/auth"
"github.com/cloudreve/Cloudreve/v3/pkg/request"
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
"net/url"
"path"
"strings"
)
// RemoteService 通过从机RPC服务的Aria2任务管理器
type RemoteService struct {
Policy *model.Policy
Client request.Client
AuthInstance auth.Auth
}
func (client *RemoteService) Init(policy *model.Policy) {
client.Policy = policy
client.Client = request.HTTPClient{}
client.AuthInstance = auth.HMACAuth{SecretKey: []byte(client.Policy.SecretKey)}
}
func (client *RemoteService) CreateTask(task *model.Download, options map[string]interface{}) error {
reqBody := serializer.RemoteAria2AddRequest{
TaskId: task.ID,
Options: options,
}
reqBodyEncoded, err := json.Marshal(reqBody)
if err != nil {
return err
}
// 发送列表请求
bodyReader := strings.NewReader(string(reqBodyEncoded))
signTTL := model.GetIntSetting("slave_api_timeout", 60)
resp, err := client.Client.Request(
"POST",
client.getAPIUrl("add"),
bodyReader,
request.WithCredential(client.AuthInstance, int64(signTTL)),
).CheckHTTPResponse(200).DecodeResponse()
if err != nil {
return err
}
// 处理列取结果
if resp.Code != 0 {
return errors.New(resp.Error)
}
if resStr, ok := resp.Data.(string); ok {
var res serializer.Response
err = json.Unmarshal([]byte(resStr), &res)
if err != nil {
return err
}
if res.Code != 0 {
return errors.New(res.Msg)
}
}
return nil
}
func (client *RemoteService) Status(task *model.Download) (rpc.StatusInfo, error) {
panic("implement me")
}
func (client *RemoteService) Cancel(task *model.Download) error {
panic("implement me")
}
func (client *RemoteService) Select(task *model.Download, files []int) error {
panic("implement me")
}
// getAPIUrl 获取接口请求地址
func (client *RemoteService) getAPIUrl(scope string, routes ...string) string {
serverURL, err := url.Parse(client.Policy.Server)
if err != nil {
return ""
}
var controller *url.URL
switch scope {
case "add":
controller, _ = url.Parse("/api/v3/slave/aria2/add")
default:
controller = serverURL
}
for _, r := range routes {
controller.Path = path.Join(controller.Path, r)
}
return serverURL.ResolveReference(controller).String()
}

View file

@ -42,6 +42,8 @@ type slave struct {
Secret string `validate:"omitempty,gte=64"`
CallbackTimeout int `validate:"omitempty,gte=1"`
SignatureTTL int `validate:"omitempty,gte=1"`
SlaveId uint `validate:"omitempty"`
Aria2 bool `validate:"omitempty"`
}
// captcha 验证码配置

View file

@ -10,3 +10,8 @@ type ListRequest struct {
Path string `json:"path"`
Recursive bool `json:"recursive"`
}
type RemoteAria2AddRequest struct {
TaskId uint `json:"task_id"`
Options map[string]interface{} `json:"options"`
}

View file

@ -2,6 +2,7 @@ package controllers
import (
"context"
"github.com/cloudreve/Cloudreve/v3/service/slave"
"net/url"
"strconv"
@ -175,3 +176,14 @@ func SlaveList(c *gin.Context) {
c.JSON(200, ErrorResponse(err))
}
}
// SlaveAria2Add 从机创建远程下载任务
func SlaveAria2Add(c *gin.Context) {
var service slave.Aria2AddService
if err := c.ShouldBindJSON(&service); err == nil {
res := service.Add()
c.JSON(200, res)
} else {
c.JSON(200, ErrorResponse(err))
}
}

View file

@ -50,6 +50,12 @@ func InitSlaveRouter() *gin.Engine {
// 列出文件
v3.POST("list", controllers.SlaveList)
}
aria2 := v3.Group("aria2")
aria2.Use(middleware.SignRequired())
{
aria2.POST("add", controllers.SlaveAria2Add)
}
return r
}

View file

@ -1,6 +1,7 @@
package admin
import (
model "github.com/cloudreve/Cloudreve/v3/models"
"net/url"
"github.com/cloudreve/Cloudreve/v3/pkg/aria2"
@ -15,29 +16,34 @@ type Aria2TestService struct {
// Test 测试aria2连接
func (service *Aria2TestService) Test() serializer.Response {
testRPC := aria2.RPCService{}
if !model.IsTrueVal(model.GetSettingByName("aria2_remote_enabled")) {
testRPC := aria2.RPCService{}
// 解析RPC服务地址
server, err := url.Parse(service.Server)
if err != nil {
return serializer.ParamErr("无法解析 aria2 RPC 服务地址, "+err.Error(), nil)
// 解析RPC服务地址
server, err := url.Parse(service.Server)
if err != nil {
return serializer.ParamErr("无法解析 aria2 RPC 服务地址, "+err.Error(), nil)
}
server.Path = "/jsonrpc"
if err := testRPC.Init(server.String(), service.Token, 5, map[string]interface{}{}); err != nil {
return serializer.ParamErr("无法初始化连接, "+err.Error(), nil)
}
defer testRPC.Caller.Close()
info, err := testRPC.Caller.GetVersion()
if err != nil {
return serializer.ParamErr("无法请求 RPC 服务, "+err.Error(), nil)
}
if info.Version == "" {
return serializer.ParamErr("RPC 服务返回非预期响应", nil)
}
return serializer.Response{Data: info.Version}
} else {
// TODO
return serializer.Response{Data: "TODO"}
}
server.Path = "/jsonrpc"
if err := testRPC.Init(server.String(), service.Token, 5, map[string]interface{}{}); err != nil {
return serializer.ParamErr("无法初始化连接, "+err.Error(), nil)
}
defer testRPC.Caller.Close()
info, err := testRPC.Caller.GetVersion()
if err != nil {
return serializer.ParamErr("无法请求 RPC 服务, "+err.Error(), nil)
}
if info.Version == "" {
return serializer.ParamErr("RPC 服务返回非预期响应", nil)
}
return serializer.Response{Data: info.Version}
}

View file

@ -42,6 +42,11 @@ func (service *AddURLService) Add(c *gin.Context, taskType int) serializer.Respo
Source: service.URL,
}
_, err = task.Create()
if err != nil {
return serializer.Err(serializer.CodeNotSet, "任务创建失败", err)
}
aria2.Lock.RLock()
if err := aria2.Instance.CreateTask(task, fs.User.Group.OptionsSerialized.Aria2Options); err != nil {
aria2.Lock.RUnlock()

28
service/slave/aria2.go Normal file
View file

@ -0,0 +1,28 @@
package slave
import (
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/aria2"
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
"github.com/cloudreve/Cloudreve/v3/pkg/util"
)
type Aria2AddService struct {
TaskId uint `json:"task_id"`
Options map[string]interface{} `json:"options"`
}
func (service *Aria2AddService) Add() serializer.Response {
task, err := model.GetDownloadById(service.TaskId)
if err != nil {
util.Log().Warning("无法获取记录, %s", err)
return serializer.Err(serializer.CodeNotSet, "任务创建失败, 无法获取记录", err)
}
aria2.Lock.RLock()
if err := aria2.Instance.CreateTask(task, service.Options); err != nil {
aria2.Lock.RUnlock()
return serializer.Err(serializer.CodeNotSet, "任务创建失败", err)
}
aria2.Lock.RUnlock()
return serializer.Response{}
}