Feat(remotearia2): add task
This commit is contained in:
parent
6fb419d998
commit
fa56d81381
13 changed files with 294 additions and 31 deletions
|
@ -28,6 +28,11 @@ func Init(path string) {
|
|||
email.Init()
|
||||
crontab.Init()
|
||||
InitStatic()
|
||||
} else {
|
||||
if conf.SlaveConfig.Aria2 {
|
||||
model.Init()
|
||||
aria2.Init(false)
|
||||
}
|
||||
}
|
||||
auth.Init()
|
||||
}
|
||||
|
|
|
@ -100,6 +100,13 @@ func GetDownloadByGid(gid string, uid uint) (*Download, error) {
|
|||
return download, result.Error
|
||||
}
|
||||
|
||||
// GetDownloadById 根据ID查找下载
|
||||
func GetDownloadById(id uint) (*Download, error) {
|
||||
var download Download
|
||||
result := DB.First(&download, id)
|
||||
return &download, result.Error
|
||||
}
|
||||
|
||||
// GetOwner 获取下载任务所属用户
|
||||
func (task *Download) GetOwner() *User {
|
||||
if task.User == nil {
|
||||
|
|
|
@ -136,6 +136,8 @@ Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; verti
|
|||
{Name: "aria2_temp_path", Value: ``, Type: "aria2"},
|
||||
{Name: "aria2_options", Value: `{}`, Type: "aria2"},
|
||||
{Name: "aria2_interval", Value: `60`, Type: "aria2"},
|
||||
{Name: "aria2_remote_enabled", Value: `0`, Type: "aria2"},
|
||||
{Name: "aria2_remote_id", Value: `0`, Type: "aria2"},
|
||||
{Name: "max_worker_num", Value: `10`, Type: "task"},
|
||||
{Name: "max_parallel_transfer", Value: `4`, Type: "task"},
|
||||
{Name: "secret_key", Value: util.RandStringRunes(256), Type: "auth"},
|
||||
|
|
|
@ -2,6 +2,7 @@ package aria2
|
|||
|
||||
import (
|
||||
"encoding/json"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/conf"
|
||||
"net/url"
|
||||
"sync"
|
||||
|
||||
|
@ -89,6 +90,21 @@ func (instance *DummyAria2) Select(task *model.Download, files []int) error {
|
|||
|
||||
// Init 初始化
|
||||
func Init(isReload bool) {
|
||||
if conf.SystemConfig.Mode == "master" {
|
||||
MasterInit(isReload)
|
||||
} else {
|
||||
SlaveInit(isReload)
|
||||
}
|
||||
}
|
||||
|
||||
// SlaveInit 从机初始化
|
||||
func SlaveInit(isReload bool) {
|
||||
if !model.IsTrueVal(model.GetSettingByName("aria2_remote_enabled")) {
|
||||
return
|
||||
}
|
||||
if conf.SlaveConfig.SlaveId == 0 || model.GetIntSetting("aria2_remote_id", 0) != int(conf.SlaveConfig.SlaveId) {
|
||||
return
|
||||
}
|
||||
Lock.Lock()
|
||||
defer Lock.Unlock()
|
||||
|
||||
|
@ -136,16 +152,82 @@ func Init(isReload bool) {
|
|||
|
||||
Instance = client
|
||||
|
||||
if !isReload {
|
||||
// 从数据库中读取未完成任务,创建监控
|
||||
unfinished := model.GetDownloadsByStatus(Ready, Paused, Downloading)
|
||||
// monitor
|
||||
}
|
||||
|
||||
for i := 0; i < len(unfinished); i++ {
|
||||
// 创建任务监控
|
||||
NewMonitor(&unfinished[i])
|
||||
// MasterInit 主机初始化
|
||||
func MasterInit(isReload bool) {
|
||||
Lock.Lock()
|
||||
defer Lock.Unlock()
|
||||
|
||||
if !model.IsTrueVal(model.GetSettingByName("aria2_remote_enabled")) {
|
||||
// 关闭上个初始连接
|
||||
if previousClient, ok := Instance.(*RPCService); ok {
|
||||
if previousClient.Caller != nil {
|
||||
util.Log().Debug("关闭上个 aria2 连接")
|
||||
previousClient.Caller.Close()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
options := model.GetSettingByNames("aria2_rpcurl", "aria2_token", "aria2_options")
|
||||
timeout := model.GetIntSetting("aria2_call_timeout", 5)
|
||||
if options["aria2_rpcurl"] == "" {
|
||||
Instance = &DummyAria2{}
|
||||
return
|
||||
}
|
||||
|
||||
util.Log().Info("初始化 aria2 RPC 服务[%s]", options["aria2_rpcurl"])
|
||||
client := &RPCService{}
|
||||
|
||||
// 解析RPC服务地址
|
||||
server, err := url.Parse(options["aria2_rpcurl"])
|
||||
if err != nil {
|
||||
util.Log().Warning("无法解析 aria2 RPC 服务地址,%s", err)
|
||||
Instance = &DummyAria2{}
|
||||
return
|
||||
}
|
||||
server.Path = "/jsonrpc"
|
||||
|
||||
// 加载自定义下载配置
|
||||
var globalOptions map[string]interface{}
|
||||
err = json.Unmarshal([]byte(options["aria2_options"]), &globalOptions)
|
||||
if err != nil {
|
||||
util.Log().Warning("无法解析 aria2 全局配置,%s", err)
|
||||
Instance = &DummyAria2{}
|
||||
return
|
||||
}
|
||||
|
||||
if err := client.Init(server.String(), options["aria2_token"], timeout, globalOptions); err != nil {
|
||||
util.Log().Warning("初始化 aria2 RPC 服务失败,%s", err)
|
||||
Instance = &DummyAria2{}
|
||||
return
|
||||
}
|
||||
|
||||
Instance = client
|
||||
|
||||
if !isReload {
|
||||
// 从数据库中读取未完成任务,创建监控
|
||||
unfinished := model.GetDownloadsByStatus(Ready, Paused, Downloading)
|
||||
|
||||
for i := 0; i < len(unfinished); i++ {
|
||||
// 创建任务监控
|
||||
NewMonitor(&unfinished[i])
|
||||
}
|
||||
}
|
||||
} else {
|
||||
util.Log().Info("初始化 从机 aria2 RPC 服务")
|
||||
remote, err := model.GetPolicyByID(uint(model.GetIntSetting("aria2_remote_id", 0)))
|
||||
if err != nil {
|
||||
util.Log().Warning("初始化 从机 aria2 RPC 服务失败,%s", err)
|
||||
Instance = &DummyAria2{}
|
||||
return
|
||||
}
|
||||
|
||||
client := &RemoteService{}
|
||||
|
||||
client.Init(&remote)
|
||||
Instance = client
|
||||
}
|
||||
}
|
||||
|
||||
// getStatus 将给定的状态字符串转换为状态标识数字
|
||||
|
|
|
@ -111,7 +111,7 @@ func (client *RPCService) CreateTask(task *model.Download, groupOptions map[stri
|
|||
|
||||
// 保存到数据库
|
||||
task.GID = gid
|
||||
_, err = task.Create()
|
||||
err = task.Save()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
103
pkg/aria2/remote_caller.go
Normal file
103
pkg/aria2/remote_caller.go
Normal file
|
@ -0,0 +1,103 @@
|
|||
package aria2
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
model "github.com/cloudreve/Cloudreve/v3/models"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/auth"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/request"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
|
||||
"net/url"
|
||||
"path"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// RemoteService 通过从机RPC服务的Aria2任务管理器
|
||||
type RemoteService struct {
|
||||
Policy *model.Policy
|
||||
Client request.Client
|
||||
AuthInstance auth.Auth
|
||||
}
|
||||
|
||||
func (client *RemoteService) Init(policy *model.Policy) {
|
||||
client.Policy = policy
|
||||
client.Client = request.HTTPClient{}
|
||||
client.AuthInstance = auth.HMACAuth{SecretKey: []byte(client.Policy.SecretKey)}
|
||||
}
|
||||
|
||||
func (client *RemoteService) CreateTask(task *model.Download, options map[string]interface{}) error {
|
||||
reqBody := serializer.RemoteAria2AddRequest{
|
||||
TaskId: task.ID,
|
||||
Options: options,
|
||||
}
|
||||
reqBodyEncoded, err := json.Marshal(reqBody)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 发送列表请求
|
||||
bodyReader := strings.NewReader(string(reqBodyEncoded))
|
||||
signTTL := model.GetIntSetting("slave_api_timeout", 60)
|
||||
resp, err := client.Client.Request(
|
||||
"POST",
|
||||
client.getAPIUrl("add"),
|
||||
bodyReader,
|
||||
request.WithCredential(client.AuthInstance, int64(signTTL)),
|
||||
).CheckHTTPResponse(200).DecodeResponse()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 处理列取结果
|
||||
if resp.Code != 0 {
|
||||
return errors.New(resp.Error)
|
||||
}
|
||||
|
||||
if resStr, ok := resp.Data.(string); ok {
|
||||
var res serializer.Response
|
||||
err = json.Unmarshal([]byte(resStr), &res)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if res.Code != 0 {
|
||||
return errors.New(res.Msg)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (client *RemoteService) Status(task *model.Download) (rpc.StatusInfo, error) {
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
func (client *RemoteService) Cancel(task *model.Download) error {
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
func (client *RemoteService) Select(task *model.Download, files []int) error {
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
// getAPIUrl 获取接口请求地址
|
||||
func (client *RemoteService) getAPIUrl(scope string, routes ...string) string {
|
||||
serverURL, err := url.Parse(client.Policy.Server)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
var controller *url.URL
|
||||
|
||||
switch scope {
|
||||
case "add":
|
||||
controller, _ = url.Parse("/api/v3/slave/aria2/add")
|
||||
default:
|
||||
controller = serverURL
|
||||
}
|
||||
|
||||
for _, r := range routes {
|
||||
controller.Path = path.Join(controller.Path, r)
|
||||
}
|
||||
|
||||
return serverURL.ResolveReference(controller).String()
|
||||
}
|
|
@ -42,6 +42,8 @@ type slave struct {
|
|||
Secret string `validate:"omitempty,gte=64"`
|
||||
CallbackTimeout int `validate:"omitempty,gte=1"`
|
||||
SignatureTTL int `validate:"omitempty,gte=1"`
|
||||
SlaveId uint `validate:"omitempty"`
|
||||
Aria2 bool `validate:"omitempty"`
|
||||
}
|
||||
|
||||
// captcha 验证码配置
|
||||
|
|
|
@ -10,3 +10,8 @@ type ListRequest struct {
|
|||
Path string `json:"path"`
|
||||
Recursive bool `json:"recursive"`
|
||||
}
|
||||
|
||||
type RemoteAria2AddRequest struct {
|
||||
TaskId uint `json:"task_id"`
|
||||
Options map[string]interface{} `json:"options"`
|
||||
}
|
||||
|
|
|
@ -2,6 +2,7 @@ package controllers
|
|||
|
||||
import (
|
||||
"context"
|
||||
"github.com/cloudreve/Cloudreve/v3/service/slave"
|
||||
"net/url"
|
||||
"strconv"
|
||||
|
||||
|
@ -175,3 +176,14 @@ func SlaveList(c *gin.Context) {
|
|||
c.JSON(200, ErrorResponse(err))
|
||||
}
|
||||
}
|
||||
|
||||
// SlaveAria2Add 从机创建远程下载任务
|
||||
func SlaveAria2Add(c *gin.Context) {
|
||||
var service slave.Aria2AddService
|
||||
if err := c.ShouldBindJSON(&service); err == nil {
|
||||
res := service.Add()
|
||||
c.JSON(200, res)
|
||||
} else {
|
||||
c.JSON(200, ErrorResponse(err))
|
||||
}
|
||||
}
|
||||
|
|
|
@ -50,6 +50,12 @@ func InitSlaveRouter() *gin.Engine {
|
|||
// 列出文件
|
||||
v3.POST("list", controllers.SlaveList)
|
||||
}
|
||||
|
||||
aria2 := v3.Group("aria2")
|
||||
aria2.Use(middleware.SignRequired())
|
||||
{
|
||||
aria2.POST("add", controllers.SlaveAria2Add)
|
||||
}
|
||||
return r
|
||||
}
|
||||
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package admin
|
||||
|
||||
import (
|
||||
model "github.com/cloudreve/Cloudreve/v3/models"
|
||||
"net/url"
|
||||
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/aria2"
|
||||
|
@ -15,29 +16,34 @@ type Aria2TestService struct {
|
|||
|
||||
// Test 测试aria2连接
|
||||
func (service *Aria2TestService) Test() serializer.Response {
|
||||
testRPC := aria2.RPCService{}
|
||||
if !model.IsTrueVal(model.GetSettingByName("aria2_remote_enabled")) {
|
||||
testRPC := aria2.RPCService{}
|
||||
|
||||
// 解析RPC服务地址
|
||||
server, err := url.Parse(service.Server)
|
||||
if err != nil {
|
||||
return serializer.ParamErr("无法解析 aria2 RPC 服务地址, "+err.Error(), nil)
|
||||
// 解析RPC服务地址
|
||||
server, err := url.Parse(service.Server)
|
||||
if err != nil {
|
||||
return serializer.ParamErr("无法解析 aria2 RPC 服务地址, "+err.Error(), nil)
|
||||
}
|
||||
server.Path = "/jsonrpc"
|
||||
|
||||
if err := testRPC.Init(server.String(), service.Token, 5, map[string]interface{}{}); err != nil {
|
||||
return serializer.ParamErr("无法初始化连接, "+err.Error(), nil)
|
||||
}
|
||||
|
||||
defer testRPC.Caller.Close()
|
||||
|
||||
info, err := testRPC.Caller.GetVersion()
|
||||
if err != nil {
|
||||
return serializer.ParamErr("无法请求 RPC 服务, "+err.Error(), nil)
|
||||
}
|
||||
|
||||
if info.Version == "" {
|
||||
return serializer.ParamErr("RPC 服务返回非预期响应", nil)
|
||||
}
|
||||
|
||||
return serializer.Response{Data: info.Version}
|
||||
} else {
|
||||
// TODO
|
||||
return serializer.Response{Data: "TODO"}
|
||||
}
|
||||
server.Path = "/jsonrpc"
|
||||
|
||||
if err := testRPC.Init(server.String(), service.Token, 5, map[string]interface{}{}); err != nil {
|
||||
return serializer.ParamErr("无法初始化连接, "+err.Error(), nil)
|
||||
}
|
||||
|
||||
defer testRPC.Caller.Close()
|
||||
|
||||
info, err := testRPC.Caller.GetVersion()
|
||||
if err != nil {
|
||||
return serializer.ParamErr("无法请求 RPC 服务, "+err.Error(), nil)
|
||||
}
|
||||
|
||||
if info.Version == "" {
|
||||
return serializer.ParamErr("RPC 服务返回非预期响应", nil)
|
||||
}
|
||||
|
||||
return serializer.Response{Data: info.Version}
|
||||
}
|
||||
|
|
|
@ -42,6 +42,11 @@ func (service *AddURLService) Add(c *gin.Context, taskType int) serializer.Respo
|
|||
Source: service.URL,
|
||||
}
|
||||
|
||||
_, err = task.Create()
|
||||
if err != nil {
|
||||
return serializer.Err(serializer.CodeNotSet, "任务创建失败", err)
|
||||
}
|
||||
|
||||
aria2.Lock.RLock()
|
||||
if err := aria2.Instance.CreateTask(task, fs.User.Group.OptionsSerialized.Aria2Options); err != nil {
|
||||
aria2.Lock.RUnlock()
|
||||
|
|
28
service/slave/aria2.go
Normal file
28
service/slave/aria2.go
Normal file
|
@ -0,0 +1,28 @@
|
|||
package slave
|
||||
|
||||
import (
|
||||
model "github.com/cloudreve/Cloudreve/v3/models"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/aria2"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/util"
|
||||
)
|
||||
|
||||
type Aria2AddService struct {
|
||||
TaskId uint `json:"task_id"`
|
||||
Options map[string]interface{} `json:"options"`
|
||||
}
|
||||
|
||||
func (service *Aria2AddService) Add() serializer.Response {
|
||||
task, err := model.GetDownloadById(service.TaskId)
|
||||
if err != nil {
|
||||
util.Log().Warning("无法获取记录, %s", err)
|
||||
return serializer.Err(serializer.CodeNotSet, "任务创建失败, 无法获取记录", err)
|
||||
}
|
||||
aria2.Lock.RLock()
|
||||
if err := aria2.Instance.CreateTask(task, service.Options); err != nil {
|
||||
aria2.Lock.RUnlock()
|
||||
return serializer.Err(serializer.CodeNotSet, "任务创建失败", err)
|
||||
}
|
||||
aria2.Lock.RUnlock()
|
||||
return serializer.Response{}
|
||||
}
|
Loading…
Add table
Reference in a new issue