Feat: init aria2 client in master node
This commit is contained in:
parent
cb737be9bb
commit
7f50406a31
5 changed files with 179 additions and 67 deletions
|
@ -1,6 +1,9 @@
|
|||
package model
|
||||
|
||||
import "github.com/jinzhu/gorm"
|
||||
import (
|
||||
"encoding/json"
|
||||
"github.com/jinzhu/gorm"
|
||||
)
|
||||
|
||||
// Node 从机节点信息模型
|
||||
type Node struct {
|
||||
|
@ -12,6 +15,23 @@ type Node struct {
|
|||
SecretKey string `gorm:"type:text"` // 通信密钥
|
||||
Aria2Enabled bool // 是否支持用作离线下载节点
|
||||
Aria2Options string `gorm:"type:text"` // 离线下载配置
|
||||
|
||||
// 数据库忽略字段
|
||||
Aria2OptionsSerialized Aria2Option `gorm:"-"`
|
||||
}
|
||||
|
||||
// Aria2Option 非公有的Aria2配置属性
|
||||
type Aria2Option struct {
|
||||
// RPC 服务器地址
|
||||
Server string `json:"server,omitempty"`
|
||||
// RPC 密钥
|
||||
Token string `json:"token,omitempty"`
|
||||
// 临时下载目录
|
||||
TempPath string `json:"temp_path,omitempty"`
|
||||
// 附加下载配置
|
||||
Options string `json:"options,omitempty"`
|
||||
// 下载监控间隔
|
||||
Interval string `json:"interval,omitempty"`
|
||||
}
|
||||
|
||||
type NodeStatus int
|
||||
|
@ -33,3 +53,20 @@ func GetNodesByStatus(status ...NodeStatus) ([]Node, error) {
|
|||
result := DB.Where("status in (?)", status).Find(&nodes)
|
||||
return nodes, result.Error
|
||||
}
|
||||
|
||||
// AfterFind 找到节点后的钩子
|
||||
func (node *Node) AfterFind() (err error) {
|
||||
// 解析离线下载设置到 Aria2OptionsSerialized
|
||||
if node.Aria2Options != "" {
|
||||
err = json.Unmarshal([]byte(node.Aria2Options), &node.Aria2OptionsSerialized)
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// BeforeSave Save策略前的钩子
|
||||
func (node *Node) BeforeSave() (err error) {
|
||||
optionsValue, err := json.Marshal(&node.Aria2OptionsSerialized)
|
||||
node.Aria2Options = string(optionsValue)
|
||||
return err
|
||||
}
|
||||
|
|
|
@ -1,14 +1,11 @@
|
|||
package aria2
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/url"
|
||||
"sync"
|
||||
|
||||
model "github.com/cloudreve/Cloudreve/v3/models"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/util"
|
||||
)
|
||||
|
||||
// Instance 默认使用的Aria2处理实例
|
||||
|
@ -22,6 +19,8 @@ var EventNotifier = &Notifier{}
|
|||
|
||||
// Aria2 离线下载处理接口
|
||||
type Aria2 interface {
|
||||
// Init 初始化客户端连接
|
||||
Init() error
|
||||
// CreateTask 创建新的任务
|
||||
CreateTask(task *model.Download, options map[string]interface{}) (string, error)
|
||||
// 返回状态信息
|
||||
|
@ -67,6 +66,10 @@ var (
|
|||
type DummyAria2 struct {
|
||||
}
|
||||
|
||||
func (instance *DummyAria2) Init() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// CreateTask 创建新任务,此处直接返回未开启错误
|
||||
func (instance *DummyAria2) CreateTask(model *model.Download, options map[string]interface{}) (string, error) {
|
||||
return "", ErrNotEnabled
|
||||
|
@ -89,53 +92,6 @@ func (instance *DummyAria2) Select(task *model.Download, files []int) error {
|
|||
|
||||
// Init 初始化
|
||||
func Init(isReload bool) {
|
||||
Lock.Lock()
|
||||
defer Lock.Unlock()
|
||||
|
||||
// 关闭上个初始连接
|
||||
if previousClient, ok := Instance.(*RPCService); ok {
|
||||
if previousClient.Caller != nil {
|
||||
util.Log().Debug("关闭上个 aria2 连接")
|
||||
previousClient.Caller.Close()
|
||||
}
|
||||
}
|
||||
|
||||
options := model.GetSettingByNames("aria2_rpcurl", "aria2_token", "aria2_options")
|
||||
timeout := model.GetIntSetting("aria2_call_timeout", 5)
|
||||
if options["aria2_rpcurl"] == "" {
|
||||
Instance = &DummyAria2{}
|
||||
return
|
||||
}
|
||||
|
||||
util.Log().Info("初始化 aria2 RPC 服务[%s]", options["aria2_rpcurl"])
|
||||
client := &RPCService{}
|
||||
|
||||
// 解析RPC服务地址
|
||||
server, err := url.Parse(options["aria2_rpcurl"])
|
||||
if err != nil {
|
||||
util.Log().Warning("无法解析 aria2 RPC 服务地址,%s", err)
|
||||
Instance = &DummyAria2{}
|
||||
return
|
||||
}
|
||||
server.Path = "/jsonrpc"
|
||||
|
||||
// 加载自定义下载配置
|
||||
var globalOptions map[string]interface{}
|
||||
err = json.Unmarshal([]byte(options["aria2_options"]), &globalOptions)
|
||||
if err != nil {
|
||||
util.Log().Warning("无法解析 aria2 全局配置,%s", err)
|
||||
Instance = &DummyAria2{}
|
||||
return
|
||||
}
|
||||
|
||||
if err := client.Init(server.String(), options["aria2_token"], timeout, globalOptions); err != nil {
|
||||
util.Log().Warning("初始化 aria2 RPC 服务失败,%s", err)
|
||||
Instance = &DummyAria2{}
|
||||
return
|
||||
}
|
||||
|
||||
Instance = client
|
||||
|
||||
if !isReload {
|
||||
// 从数据库中读取未完成任务,创建监控
|
||||
unfinished := model.GetDownloadsByStatus(Ready, Paused, Downloading)
|
||||
|
|
|
@ -1,12 +1,51 @@
|
|||
package cluster
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
model "github.com/cloudreve/Cloudreve/v3/models"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/aria2"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/util"
|
||||
"net/url"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
type MasterNode struct {
|
||||
Model *model.Node
|
||||
Model *model.Node
|
||||
aria2RPC rpcService
|
||||
lock sync.RWMutex
|
||||
}
|
||||
|
||||
// RPCService 通过RPC服务的Aria2任务管理器
|
||||
type rpcService struct {
|
||||
Caller rpc.Client
|
||||
Initialized bool
|
||||
|
||||
parent *MasterNode
|
||||
options *clientOptions
|
||||
}
|
||||
|
||||
type clientOptions struct {
|
||||
Options map[string]interface{} // 创建下载时额外添加的设置
|
||||
}
|
||||
|
||||
// Init 初始化节点
|
||||
func (node *MasterNode) Init(nodeModel *model.Node) {
|
||||
node.lock.Lock()
|
||||
node.Model = nodeModel
|
||||
node.aria2RPC.parent = node
|
||||
node.lock.Unlock()
|
||||
|
||||
node.lock.RLock()
|
||||
if node.Model.Aria2Enabled {
|
||||
node.lock.RUnlock()
|
||||
node.aria2RPC.Init()
|
||||
return
|
||||
}
|
||||
node.lock.RUnlock()
|
||||
}
|
||||
|
||||
func (node *MasterNode) Ping(req *serializer.NodePingReq) (*serializer.NodePingResp, error) {
|
||||
|
@ -30,7 +69,70 @@ func (node *MasterNode) IsActive() bool {
|
|||
return true
|
||||
}
|
||||
|
||||
// InitAria2RPCClient 初始化主机 Aria2 RPC 服务
|
||||
func (node *MasterNode) InitAria2RPCClient() error {
|
||||
return nil
|
||||
// GetAria2Instance 获取主机Aria2实例
|
||||
func (node *MasterNode) GetAria2Instance() (aria2.Aria2, error) {
|
||||
if !node.Model.Aria2Enabled {
|
||||
return &aria2.DummyAria2{}, nil
|
||||
}
|
||||
|
||||
node.lock.RLock()
|
||||
defer node.lock.RUnlock()
|
||||
if !node.aria2RPC.Initialized {
|
||||
return &aria2.DummyAria2{}, nil
|
||||
}
|
||||
|
||||
return &node.aria2RPC, nil
|
||||
}
|
||||
|
||||
func (r *rpcService) Init() error {
|
||||
r.parent.lock.Lock()
|
||||
defer r.parent.lock.Unlock()
|
||||
r.Initialized = false
|
||||
|
||||
// 客户端已存在,则关闭先前连接
|
||||
if r.Caller != nil {
|
||||
r.Caller.Close()
|
||||
}
|
||||
|
||||
// 解析RPC服务地址
|
||||
server, err := url.Parse(r.parent.Model.Aria2OptionsSerialized.Server)
|
||||
if err != nil {
|
||||
util.Log().Warning("无法解析主机 Aria2 RPC 服务地址,%s", err)
|
||||
return err
|
||||
}
|
||||
server.Path = "/jsonrpc"
|
||||
|
||||
// 加载自定义下载配置
|
||||
var globalOptions map[string]interface{}
|
||||
err = json.Unmarshal([]byte(r.parent.Model.Aria2OptionsSerialized.Options), &globalOptions)
|
||||
if err != nil {
|
||||
util.Log().Warning("无法解析主机 Aria2 配置,%s", err)
|
||||
return err
|
||||
}
|
||||
|
||||
r.options = &clientOptions{
|
||||
Options: globalOptions,
|
||||
}
|
||||
timeout := model.GetIntSetting("aria2_call_timeout", 5)
|
||||
caller, err := rpc.New(context.Background(), server.String(), r.parent.Model.Aria2OptionsSerialized.Token, time.Duration(timeout)*time.Second, aria2.EventNotifier)
|
||||
|
||||
r.Caller = caller
|
||||
r.Initialized = true
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *rpcService) CreateTask(task *model.Download, options map[string]interface{}) (string, error) {
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
func (r *rpcService) Status(task *model.Download) (rpc.StatusInfo, error) {
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
func (r *rpcService) Cancel(task *model.Download) error {
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
func (r *rpcService) Select(task *model.Download, files []int) error {
|
||||
panic("implement me")
|
||||
}
|
||||
|
|
|
@ -2,32 +2,28 @@ package cluster
|
|||
|
||||
import (
|
||||
model "github.com/cloudreve/Cloudreve/v3/models"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/auth"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/request"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/aria2"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
|
||||
)
|
||||
|
||||
type Node interface {
|
||||
Init(node *model.Node)
|
||||
IsFeatureEnabled(feature string) bool
|
||||
SubscribeStatusChange(callback func(isActive bool, id uint))
|
||||
Ping(req *serializer.NodePingReq) (*serializer.NodePingResp, error)
|
||||
IsActive() bool
|
||||
GetAria2Instance() (aria2.Aria2, error)
|
||||
}
|
||||
|
||||
func getNodeFromDBModel(node *model.Node) Node {
|
||||
switch node.Type {
|
||||
case model.SlaveNodeType:
|
||||
slave := &SlaveNode{
|
||||
Model: node,
|
||||
AuthInstance: auth.HMACAuth{SecretKey: []byte(node.SecretKey)},
|
||||
Client: request.HTTPClient{},
|
||||
Active: true,
|
||||
}
|
||||
go slave.StartPingLoop()
|
||||
slave := &SlaveNode{}
|
||||
slave.Init(node)
|
||||
return slave
|
||||
default:
|
||||
return &MasterNode{
|
||||
Model: node,
|
||||
}
|
||||
master := &MasterNode{}
|
||||
master.Init(node)
|
||||
return master
|
||||
}
|
||||
}
|
||||
|
|
|
@ -4,6 +4,7 @@ import (
|
|||
"encoding/json"
|
||||
"errors"
|
||||
model "github.com/cloudreve/Cloudreve/v3/models"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/aria2"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/auth"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/request"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
|
||||
|
@ -26,6 +27,21 @@ type SlaveNode struct {
|
|||
lock sync.RWMutex
|
||||
}
|
||||
|
||||
// Init 初始化节点
|
||||
func (node *SlaveNode) Init(nodeModel *model.Node) {
|
||||
node.lock.Lock()
|
||||
node.Model = nodeModel
|
||||
node.AuthInstance = auth.HMACAuth{SecretKey: []byte(nodeModel.SecretKey)}
|
||||
node.Client = request.HTTPClient{}
|
||||
node.Active = true
|
||||
if node.close != nil {
|
||||
node.close <- true
|
||||
}
|
||||
node.lock.Unlock()
|
||||
|
||||
go node.StartPingLoop()
|
||||
}
|
||||
|
||||
// IsFeatureEnabled 查询节点的某项功能是否启用
|
||||
func (node *SlaveNode) IsFeatureEnabled(feature string) bool {
|
||||
switch feature {
|
||||
|
@ -167,3 +183,8 @@ loop:
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
// GetAria2Instance 获取从机Aria2实例
|
||||
func (node *SlaveNode) GetAria2Instance() (aria2.Aria2, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
|
Loading…
Add table
Reference in a new issue