From fa56d81381b5697714af7b44da3e31e3d3fb4f33 Mon Sep 17 00:00:00 2001 From: Cian John Date: Sat, 20 Mar 2021 16:30:48 +0800 Subject: [PATCH] Feat(remotearia2): add task --- bootstrap/init.go | 5 ++ models/download.go | 7 +++ models/migration.go | 2 + pkg/aria2/aria2.go | 96 +++++++++++++++++++++++++++++--- pkg/aria2/caller.go | 2 +- pkg/aria2/remote_caller.go | 103 +++++++++++++++++++++++++++++++++++ pkg/conf/conf.go | 2 + pkg/serializer/slave.go | 5 ++ routers/controllers/slave.go | 12 ++++ routers/router.go | 6 ++ service/admin/aria2.go | 52 ++++++++++-------- service/aria2/add.go | 5 ++ service/slave/aria2.go | 28 ++++++++++ 13 files changed, 294 insertions(+), 31 deletions(-) create mode 100644 pkg/aria2/remote_caller.go create mode 100644 service/slave/aria2.go diff --git a/bootstrap/init.go b/bootstrap/init.go index 98f2c8d..c71e0d6 100644 --- a/bootstrap/init.go +++ b/bootstrap/init.go @@ -28,6 +28,11 @@ func Init(path string) { email.Init() crontab.Init() InitStatic() + } else { + if conf.SlaveConfig.Aria2 { + model.Init() + aria2.Init(false) + } } auth.Init() } diff --git a/models/download.go b/models/download.go index 8b6599d..d981e38 100644 --- a/models/download.go +++ b/models/download.go @@ -100,6 +100,13 @@ func GetDownloadByGid(gid string, uid uint) (*Download, error) { return download, result.Error } +// GetDownloadById 根据ID查找下载 +func GetDownloadById(id uint) (*Download, error) { + var download Download + result := DB.First(&download, id) + return &download, result.Error +} + // GetOwner 获取下载任务所属用户 func (task *Download) GetOwner() *User { if task.User == nil { diff --git a/models/migration.go b/models/migration.go index 4522c8e..c8d0e17 100644 --- a/models/migration.go +++ b/models/migration.go @@ -136,6 +136,8 @@ Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; verti {Name: "aria2_temp_path", Value: ``, Type: "aria2"}, {Name: "aria2_options", Value: `{}`, Type: "aria2"}, {Name: "aria2_interval", Value: `60`, Type: "aria2"}, + {Name: "aria2_remote_enabled", Value: `0`, Type: "aria2"}, + {Name: "aria2_remote_id", Value: `0`, Type: "aria2"}, {Name: "max_worker_num", Value: `10`, Type: "task"}, {Name: "max_parallel_transfer", Value: `4`, Type: "task"}, {Name: "secret_key", Value: util.RandStringRunes(256), Type: "auth"}, diff --git a/pkg/aria2/aria2.go b/pkg/aria2/aria2.go index 40ce36a..a2513a3 100644 --- a/pkg/aria2/aria2.go +++ b/pkg/aria2/aria2.go @@ -2,6 +2,7 @@ package aria2 import ( "encoding/json" + "github.com/cloudreve/Cloudreve/v3/pkg/conf" "net/url" "sync" @@ -89,6 +90,21 @@ func (instance *DummyAria2) Select(task *model.Download, files []int) error { // Init 初始化 func Init(isReload bool) { + if conf.SystemConfig.Mode == "master" { + MasterInit(isReload) + } else { + SlaveInit(isReload) + } +} + +// SlaveInit 从机初始化 +func SlaveInit(isReload bool) { + if !model.IsTrueVal(model.GetSettingByName("aria2_remote_enabled")) { + return + } + if conf.SlaveConfig.SlaveId == 0 || model.GetIntSetting("aria2_remote_id", 0) != int(conf.SlaveConfig.SlaveId) { + return + } Lock.Lock() defer Lock.Unlock() @@ -136,16 +152,82 @@ func Init(isReload bool) { Instance = client - if !isReload { - // 从数据库中读取未完成任务,创建监控 - unfinished := model.GetDownloadsByStatus(Ready, Paused, Downloading) + // monitor +} - for i := 0; i < len(unfinished); i++ { - // 创建任务监控 - NewMonitor(&unfinished[i]) +// MasterInit 主机初始化 +func MasterInit(isReload bool) { + Lock.Lock() + defer Lock.Unlock() + + if !model.IsTrueVal(model.GetSettingByName("aria2_remote_enabled")) { + // 关闭上个初始连接 + if previousClient, ok := Instance.(*RPCService); ok { + if previousClient.Caller != nil { + util.Log().Debug("关闭上个 aria2 连接") + previousClient.Caller.Close() + } } - } + options := model.GetSettingByNames("aria2_rpcurl", "aria2_token", "aria2_options") + timeout := model.GetIntSetting("aria2_call_timeout", 5) + if options["aria2_rpcurl"] == "" { + Instance = &DummyAria2{} + return + } + + util.Log().Info("初始化 aria2 RPC 服务[%s]", options["aria2_rpcurl"]) + client := &RPCService{} + + // 解析RPC服务地址 + server, err := url.Parse(options["aria2_rpcurl"]) + if err != nil { + util.Log().Warning("无法解析 aria2 RPC 服务地址,%s", err) + Instance = &DummyAria2{} + return + } + server.Path = "/jsonrpc" + + // 加载自定义下载配置 + var globalOptions map[string]interface{} + err = json.Unmarshal([]byte(options["aria2_options"]), &globalOptions) + if err != nil { + util.Log().Warning("无法解析 aria2 全局配置,%s", err) + Instance = &DummyAria2{} + return + } + + if err := client.Init(server.String(), options["aria2_token"], timeout, globalOptions); err != nil { + util.Log().Warning("初始化 aria2 RPC 服务失败,%s", err) + Instance = &DummyAria2{} + return + } + + Instance = client + + if !isReload { + // 从数据库中读取未完成任务,创建监控 + unfinished := model.GetDownloadsByStatus(Ready, Paused, Downloading) + + for i := 0; i < len(unfinished); i++ { + // 创建任务监控 + NewMonitor(&unfinished[i]) + } + } + } else { + util.Log().Info("初始化 从机 aria2 RPC 服务") + remote, err := model.GetPolicyByID(uint(model.GetIntSetting("aria2_remote_id", 0))) + if err != nil { + util.Log().Warning("初始化 从机 aria2 RPC 服务失败,%s", err) + Instance = &DummyAria2{} + return + } + + client := &RemoteService{} + + client.Init(&remote) + Instance = client + } } // getStatus 将给定的状态字符串转换为状态标识数字 diff --git a/pkg/aria2/caller.go b/pkg/aria2/caller.go index 6e287a2..63151f8 100644 --- a/pkg/aria2/caller.go +++ b/pkg/aria2/caller.go @@ -111,7 +111,7 @@ func (client *RPCService) CreateTask(task *model.Download, groupOptions map[stri // 保存到数据库 task.GID = gid - _, err = task.Create() + err = task.Save() if err != nil { return err } diff --git a/pkg/aria2/remote_caller.go b/pkg/aria2/remote_caller.go new file mode 100644 index 0000000..fd11eff --- /dev/null +++ b/pkg/aria2/remote_caller.go @@ -0,0 +1,103 @@ +package aria2 + +import ( + "encoding/json" + "errors" + model "github.com/cloudreve/Cloudreve/v3/models" + "github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc" + "github.com/cloudreve/Cloudreve/v3/pkg/auth" + "github.com/cloudreve/Cloudreve/v3/pkg/request" + "github.com/cloudreve/Cloudreve/v3/pkg/serializer" + "net/url" + "path" + "strings" +) + +// RemoteService 通过从机RPC服务的Aria2任务管理器 +type RemoteService struct { + Policy *model.Policy + Client request.Client + AuthInstance auth.Auth +} + +func (client *RemoteService) Init(policy *model.Policy) { + client.Policy = policy + client.Client = request.HTTPClient{} + client.AuthInstance = auth.HMACAuth{SecretKey: []byte(client.Policy.SecretKey)} +} + +func (client *RemoteService) CreateTask(task *model.Download, options map[string]interface{}) error { + reqBody := serializer.RemoteAria2AddRequest{ + TaskId: task.ID, + Options: options, + } + reqBodyEncoded, err := json.Marshal(reqBody) + if err != nil { + return err + } + + // 发送列表请求 + bodyReader := strings.NewReader(string(reqBodyEncoded)) + signTTL := model.GetIntSetting("slave_api_timeout", 60) + resp, err := client.Client.Request( + "POST", + client.getAPIUrl("add"), + bodyReader, + request.WithCredential(client.AuthInstance, int64(signTTL)), + ).CheckHTTPResponse(200).DecodeResponse() + if err != nil { + return err + } + + // 处理列取结果 + if resp.Code != 0 { + return errors.New(resp.Error) + } + + if resStr, ok := resp.Data.(string); ok { + var res serializer.Response + err = json.Unmarshal([]byte(resStr), &res) + if err != nil { + return err + } + if res.Code != 0 { + return errors.New(res.Msg) + } + } + + return nil +} + +func (client *RemoteService) Status(task *model.Download) (rpc.StatusInfo, error) { + panic("implement me") +} + +func (client *RemoteService) Cancel(task *model.Download) error { + panic("implement me") +} + +func (client *RemoteService) Select(task *model.Download, files []int) error { + panic("implement me") +} + +// getAPIUrl 获取接口请求地址 +func (client *RemoteService) getAPIUrl(scope string, routes ...string) string { + serverURL, err := url.Parse(client.Policy.Server) + if err != nil { + return "" + } + var controller *url.URL + + switch scope { + case "add": + controller, _ = url.Parse("/api/v3/slave/aria2/add") + default: + controller = serverURL + } + + for _, r := range routes { + controller.Path = path.Join(controller.Path, r) + } + + return serverURL.ResolveReference(controller).String() +} diff --git a/pkg/conf/conf.go b/pkg/conf/conf.go index b9be97c..da6d325 100644 --- a/pkg/conf/conf.go +++ b/pkg/conf/conf.go @@ -42,6 +42,8 @@ type slave struct { Secret string `validate:"omitempty,gte=64"` CallbackTimeout int `validate:"omitempty,gte=1"` SignatureTTL int `validate:"omitempty,gte=1"` + SlaveId uint `validate:"omitempty"` + Aria2 bool `validate:"omitempty"` } // captcha 验证码配置 diff --git a/pkg/serializer/slave.go b/pkg/serializer/slave.go index e23e809..1cf835b 100644 --- a/pkg/serializer/slave.go +++ b/pkg/serializer/slave.go @@ -10,3 +10,8 @@ type ListRequest struct { Path string `json:"path"` Recursive bool `json:"recursive"` } + +type RemoteAria2AddRequest struct { + TaskId uint `json:"task_id"` + Options map[string]interface{} `json:"options"` +} diff --git a/routers/controllers/slave.go b/routers/controllers/slave.go index e10e2b0..71652fe 100644 --- a/routers/controllers/slave.go +++ b/routers/controllers/slave.go @@ -2,6 +2,7 @@ package controllers import ( "context" + "github.com/cloudreve/Cloudreve/v3/service/slave" "net/url" "strconv" @@ -175,3 +176,14 @@ func SlaveList(c *gin.Context) { c.JSON(200, ErrorResponse(err)) } } + +// SlaveAria2Add 从机创建远程下载任务 +func SlaveAria2Add(c *gin.Context) { + var service slave.Aria2AddService + if err := c.ShouldBindJSON(&service); err == nil { + res := service.Add() + c.JSON(200, res) + } else { + c.JSON(200, ErrorResponse(err)) + } +} diff --git a/routers/router.go b/routers/router.go index 8cf5a9d..cbffe30 100644 --- a/routers/router.go +++ b/routers/router.go @@ -50,6 +50,12 @@ func InitSlaveRouter() *gin.Engine { // 列出文件 v3.POST("list", controllers.SlaveList) } + + aria2 := v3.Group("aria2") + aria2.Use(middleware.SignRequired()) + { + aria2.POST("add", controllers.SlaveAria2Add) + } return r } diff --git a/service/admin/aria2.go b/service/admin/aria2.go index 8801c96..6c3308d 100644 --- a/service/admin/aria2.go +++ b/service/admin/aria2.go @@ -1,6 +1,7 @@ package admin import ( + model "github.com/cloudreve/Cloudreve/v3/models" "net/url" "github.com/cloudreve/Cloudreve/v3/pkg/aria2" @@ -15,29 +16,34 @@ type Aria2TestService struct { // Test 测试aria2连接 func (service *Aria2TestService) Test() serializer.Response { - testRPC := aria2.RPCService{} + if !model.IsTrueVal(model.GetSettingByName("aria2_remote_enabled")) { + testRPC := aria2.RPCService{} - // 解析RPC服务地址 - server, err := url.Parse(service.Server) - if err != nil { - return serializer.ParamErr("无法解析 aria2 RPC 服务地址, "+err.Error(), nil) + // 解析RPC服务地址 + server, err := url.Parse(service.Server) + if err != nil { + return serializer.ParamErr("无法解析 aria2 RPC 服务地址, "+err.Error(), nil) + } + server.Path = "/jsonrpc" + + if err := testRPC.Init(server.String(), service.Token, 5, map[string]interface{}{}); err != nil { + return serializer.ParamErr("无法初始化连接, "+err.Error(), nil) + } + + defer testRPC.Caller.Close() + + info, err := testRPC.Caller.GetVersion() + if err != nil { + return serializer.ParamErr("无法请求 RPC 服务, "+err.Error(), nil) + } + + if info.Version == "" { + return serializer.ParamErr("RPC 服务返回非预期响应", nil) + } + + return serializer.Response{Data: info.Version} + } else { + // TODO + return serializer.Response{Data: "TODO"} } - server.Path = "/jsonrpc" - - if err := testRPC.Init(server.String(), service.Token, 5, map[string]interface{}{}); err != nil { - return serializer.ParamErr("无法初始化连接, "+err.Error(), nil) - } - - defer testRPC.Caller.Close() - - info, err := testRPC.Caller.GetVersion() - if err != nil { - return serializer.ParamErr("无法请求 RPC 服务, "+err.Error(), nil) - } - - if info.Version == "" { - return serializer.ParamErr("RPC 服务返回非预期响应", nil) - } - - return serializer.Response{Data: info.Version} } diff --git a/service/aria2/add.go b/service/aria2/add.go index be7213a..5b95fd6 100644 --- a/service/aria2/add.go +++ b/service/aria2/add.go @@ -42,6 +42,11 @@ func (service *AddURLService) Add(c *gin.Context, taskType int) serializer.Respo Source: service.URL, } + _, err = task.Create() + if err != nil { + return serializer.Err(serializer.CodeNotSet, "任务创建失败", err) + } + aria2.Lock.RLock() if err := aria2.Instance.CreateTask(task, fs.User.Group.OptionsSerialized.Aria2Options); err != nil { aria2.Lock.RUnlock() diff --git a/service/slave/aria2.go b/service/slave/aria2.go new file mode 100644 index 0000000..6e22208 --- /dev/null +++ b/service/slave/aria2.go @@ -0,0 +1,28 @@ +package slave + +import ( + model "github.com/cloudreve/Cloudreve/v3/models" + "github.com/cloudreve/Cloudreve/v3/pkg/aria2" + "github.com/cloudreve/Cloudreve/v3/pkg/serializer" + "github.com/cloudreve/Cloudreve/v3/pkg/util" +) + +type Aria2AddService struct { + TaskId uint `json:"task_id"` + Options map[string]interface{} `json:"options"` +} + +func (service *Aria2AddService) Add() serializer.Response { + task, err := model.GetDownloadById(service.TaskId) + if err != nil { + util.Log().Warning("无法获取记录, %s", err) + return serializer.Err(serializer.CodeNotSet, "任务创建失败, 无法获取记录", err) + } + aria2.Lock.RLock() + if err := aria2.Instance.CreateTask(task, service.Options); err != nil { + aria2.Lock.RUnlock() + return serializer.Err(serializer.CodeNotSet, "任务创建失败", err) + } + aria2.Lock.RUnlock() + return serializer.Response{} +}