diff --git a/models/node.go b/models/node.go index 739a98e..0c49aad 100644 --- a/models/node.go +++ b/models/node.go @@ -1,6 +1,9 @@ package model -import "github.com/jinzhu/gorm" +import ( + "encoding/json" + "github.com/jinzhu/gorm" +) // Node 从机节点信息模型 type Node struct { @@ -12,6 +15,23 @@ type Node struct { SecretKey string `gorm:"type:text"` // 通信密钥 Aria2Enabled bool // 是否支持用作离线下载节点 Aria2Options string `gorm:"type:text"` // 离线下载配置 + + // 数据库忽略字段 + Aria2OptionsSerialized Aria2Option `gorm:"-"` +} + +// Aria2Option 非公有的Aria2配置属性 +type Aria2Option struct { + // RPC 服务器地址 + Server string `json:"server,omitempty"` + // RPC 密钥 + Token string `json:"token,omitempty"` + // 临时下载目录 + TempPath string `json:"temp_path,omitempty"` + // 附加下载配置 + Options string `json:"options,omitempty"` + // 下载监控间隔 + Interval string `json:"interval,omitempty"` } type NodeStatus int @@ -33,3 +53,20 @@ func GetNodesByStatus(status ...NodeStatus) ([]Node, error) { result := DB.Where("status in (?)", status).Find(&nodes) return nodes, result.Error } + +// AfterFind 找到节点后的钩子 +func (node *Node) AfterFind() (err error) { + // 解析离线下载设置到 Aria2OptionsSerialized + if node.Aria2Options != "" { + err = json.Unmarshal([]byte(node.Aria2Options), &node.Aria2OptionsSerialized) + } + + return err +} + +// BeforeSave Save策略前的钩子 +func (node *Node) BeforeSave() (err error) { + optionsValue, err := json.Marshal(&node.Aria2OptionsSerialized) + node.Aria2Options = string(optionsValue) + return err +} diff --git a/pkg/aria2/aria2.go b/pkg/aria2/aria2.go index 21953cf..83131b6 100644 --- a/pkg/aria2/aria2.go +++ b/pkg/aria2/aria2.go @@ -1,14 +1,11 @@ package aria2 import ( - "encoding/json" - "net/url" "sync" model "github.com/cloudreve/Cloudreve/v3/models" "github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc" "github.com/cloudreve/Cloudreve/v3/pkg/serializer" - "github.com/cloudreve/Cloudreve/v3/pkg/util" ) // Instance 默认使用的Aria2处理实例 @@ -22,6 +19,8 @@ var EventNotifier = &Notifier{} // Aria2 离线下载处理接口 type Aria2 interface { + // Init 初始化客户端连接 + Init() error // CreateTask 创建新的任务 CreateTask(task *model.Download, options map[string]interface{}) (string, error) // 返回状态信息 @@ -67,6 +66,10 @@ var ( type DummyAria2 struct { } +func (instance *DummyAria2) Init() error { + return nil +} + // CreateTask 创建新任务,此处直接返回未开启错误 func (instance *DummyAria2) CreateTask(model *model.Download, options map[string]interface{}) (string, error) { return "", ErrNotEnabled @@ -89,53 +92,6 @@ func (instance *DummyAria2) Select(task *model.Download, files []int) error { // Init 初始化 func Init(isReload bool) { - Lock.Lock() - defer Lock.Unlock() - - // 关闭上个初始连接 - if previousClient, ok := Instance.(*RPCService); ok { - if previousClient.Caller != nil { - util.Log().Debug("关闭上个 aria2 连接") - previousClient.Caller.Close() - } - } - - options := model.GetSettingByNames("aria2_rpcurl", "aria2_token", "aria2_options") - timeout := model.GetIntSetting("aria2_call_timeout", 5) - if options["aria2_rpcurl"] == "" { - Instance = &DummyAria2{} - return - } - - util.Log().Info("初始化 aria2 RPC 服务[%s]", options["aria2_rpcurl"]) - client := &RPCService{} - - // 解析RPC服务地址 - server, err := url.Parse(options["aria2_rpcurl"]) - if err != nil { - util.Log().Warning("无法解析 aria2 RPC 服务地址,%s", err) - Instance = &DummyAria2{} - return - } - server.Path = "/jsonrpc" - - // 加载自定义下载配置 - var globalOptions map[string]interface{} - err = json.Unmarshal([]byte(options["aria2_options"]), &globalOptions) - if err != nil { - util.Log().Warning("无法解析 aria2 全局配置,%s", err) - Instance = &DummyAria2{} - return - } - - if err := client.Init(server.String(), options["aria2_token"], timeout, globalOptions); err != nil { - util.Log().Warning("初始化 aria2 RPC 服务失败,%s", err) - Instance = &DummyAria2{} - return - } - - Instance = client - if !isReload { // 从数据库中读取未完成任务,创建监控 unfinished := model.GetDownloadsByStatus(Ready, Paused, Downloading) diff --git a/pkg/cluster/master.go b/pkg/cluster/master.go index a044112..6b2cf08 100644 --- a/pkg/cluster/master.go +++ b/pkg/cluster/master.go @@ -1,12 +1,51 @@ package cluster import ( + "context" + "encoding/json" model "github.com/cloudreve/Cloudreve/v3/models" + "github.com/cloudreve/Cloudreve/v3/pkg/aria2" + "github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc" "github.com/cloudreve/Cloudreve/v3/pkg/serializer" + "github.com/cloudreve/Cloudreve/v3/pkg/util" + "net/url" + "sync" + "time" ) type MasterNode struct { - Model *model.Node + Model *model.Node + aria2RPC rpcService + lock sync.RWMutex +} + +// RPCService 通过RPC服务的Aria2任务管理器 +type rpcService struct { + Caller rpc.Client + Initialized bool + + parent *MasterNode + options *clientOptions +} + +type clientOptions struct { + Options map[string]interface{} // 创建下载时额外添加的设置 +} + +// Init 初始化节点 +func (node *MasterNode) Init(nodeModel *model.Node) { + node.lock.Lock() + node.Model = nodeModel + node.aria2RPC.parent = node + node.lock.Unlock() + + node.lock.RLock() + if node.Model.Aria2Enabled { + node.lock.RUnlock() + node.aria2RPC.Init() + return + } + node.lock.RUnlock() } func (node *MasterNode) Ping(req *serializer.NodePingReq) (*serializer.NodePingResp, error) { @@ -30,7 +69,70 @@ func (node *MasterNode) IsActive() bool { return true } -// InitAria2RPCClient 初始化主机 Aria2 RPC 服务 -func (node *MasterNode) InitAria2RPCClient() error { - return nil +// GetAria2Instance 获取主机Aria2实例 +func (node *MasterNode) GetAria2Instance() (aria2.Aria2, error) { + if !node.Model.Aria2Enabled { + return &aria2.DummyAria2{}, nil + } + + node.lock.RLock() + defer node.lock.RUnlock() + if !node.aria2RPC.Initialized { + return &aria2.DummyAria2{}, nil + } + + return &node.aria2RPC, nil +} + +func (r *rpcService) Init() error { + r.parent.lock.Lock() + defer r.parent.lock.Unlock() + r.Initialized = false + + // 客户端已存在,则关闭先前连接 + if r.Caller != nil { + r.Caller.Close() + } + + // 解析RPC服务地址 + server, err := url.Parse(r.parent.Model.Aria2OptionsSerialized.Server) + if err != nil { + util.Log().Warning("无法解析主机 Aria2 RPC 服务地址,%s", err) + return err + } + server.Path = "/jsonrpc" + + // 加载自定义下载配置 + var globalOptions map[string]interface{} + err = json.Unmarshal([]byte(r.parent.Model.Aria2OptionsSerialized.Options), &globalOptions) + if err != nil { + util.Log().Warning("无法解析主机 Aria2 配置,%s", err) + return err + } + + r.options = &clientOptions{ + Options: globalOptions, + } + timeout := model.GetIntSetting("aria2_call_timeout", 5) + caller, err := rpc.New(context.Background(), server.String(), r.parent.Model.Aria2OptionsSerialized.Token, time.Duration(timeout)*time.Second, aria2.EventNotifier) + + r.Caller = caller + r.Initialized = true + return err +} + +func (r *rpcService) CreateTask(task *model.Download, options map[string]interface{}) (string, error) { + panic("implement me") +} + +func (r *rpcService) Status(task *model.Download) (rpc.StatusInfo, error) { + panic("implement me") +} + +func (r *rpcService) Cancel(task *model.Download) error { + panic("implement me") +} + +func (r *rpcService) Select(task *model.Download, files []int) error { + panic("implement me") } diff --git a/pkg/cluster/node.go b/pkg/cluster/node.go index dcbff47..0925623 100644 --- a/pkg/cluster/node.go +++ b/pkg/cluster/node.go @@ -2,32 +2,28 @@ package cluster import ( model "github.com/cloudreve/Cloudreve/v3/models" - "github.com/cloudreve/Cloudreve/v3/pkg/auth" - "github.com/cloudreve/Cloudreve/v3/pkg/request" + "github.com/cloudreve/Cloudreve/v3/pkg/aria2" "github.com/cloudreve/Cloudreve/v3/pkg/serializer" ) type Node interface { + Init(node *model.Node) IsFeatureEnabled(feature string) bool SubscribeStatusChange(callback func(isActive bool, id uint)) Ping(req *serializer.NodePingReq) (*serializer.NodePingResp, error) IsActive() bool + GetAria2Instance() (aria2.Aria2, error) } func getNodeFromDBModel(node *model.Node) Node { switch node.Type { case model.SlaveNodeType: - slave := &SlaveNode{ - Model: node, - AuthInstance: auth.HMACAuth{SecretKey: []byte(node.SecretKey)}, - Client: request.HTTPClient{}, - Active: true, - } - go slave.StartPingLoop() + slave := &SlaveNode{} + slave.Init(node) return slave default: - return &MasterNode{ - Model: node, - } + master := &MasterNode{} + master.Init(node) + return master } } diff --git a/pkg/cluster/slave.go b/pkg/cluster/slave.go index aa74364..347d34c 100644 --- a/pkg/cluster/slave.go +++ b/pkg/cluster/slave.go @@ -4,6 +4,7 @@ import ( "encoding/json" "errors" model "github.com/cloudreve/Cloudreve/v3/models" + "github.com/cloudreve/Cloudreve/v3/pkg/aria2" "github.com/cloudreve/Cloudreve/v3/pkg/auth" "github.com/cloudreve/Cloudreve/v3/pkg/request" "github.com/cloudreve/Cloudreve/v3/pkg/serializer" @@ -26,6 +27,21 @@ type SlaveNode struct { lock sync.RWMutex } +// Init 初始化节点 +func (node *SlaveNode) Init(nodeModel *model.Node) { + node.lock.Lock() + node.Model = nodeModel + node.AuthInstance = auth.HMACAuth{SecretKey: []byte(nodeModel.SecretKey)} + node.Client = request.HTTPClient{} + node.Active = true + if node.close != nil { + node.close <- true + } + node.lock.Unlock() + + go node.StartPingLoop() +} + // IsFeatureEnabled 查询节点的某项功能是否启用 func (node *SlaveNode) IsFeatureEnabled(feature string) bool { switch feature { @@ -167,3 +183,8 @@ loop: } } } + +// GetAria2Instance 获取从机Aria2实例 +func (node *SlaveNode) GetAria2Instance() (aria2.Aria2, error) { + return nil, nil +}