diff --git a/bootstrap/init.go b/bootstrap/init.go index 8c8958e..553bc81 100644 --- a/bootstrap/init.go +++ b/bootstrap/init.go @@ -24,7 +24,7 @@ func Init(path string) { if conf.SystemConfig.Mode == "master" { model.Init() task.Init() - aria2.Init() + aria2.Init(false) email.Init() crontab.Init() } diff --git a/models/group.go b/models/group.go index b45991b..c190975 100644 --- a/models/group.go +++ b/models/group.go @@ -24,15 +24,15 @@ type Group struct { // GroupOption 用户组其他配置 type GroupOption struct { - ArchiveDownload bool `json:"archive_download,omitempty"` // 打包下载 - ArchiveTask bool `json:"archive_task,omitempty"` // 在线压缩 - CompressSize uint64 `json:"compress_size,omitempty"` // 可压缩大小 - DecompressSize uint64 `json:"decompress_size,omitempty"` - OneTimeDownload bool `json:"one_time_download,omitempty"` - ShareDownload bool `json:"share_download,omitempty"` - ShareFree bool `json:"share_free,omitempty"` - Aria2 bool `json:"aria2,omitempty"` // 离线下载 - Aria2Options []interface{} `json:"aria2_options,omitempty"` // 离线下载用户组配置 + ArchiveDownload bool `json:"archive_download,omitempty"` // 打包下载 + ArchiveTask bool `json:"archive_task,omitempty"` // 在线压缩 + CompressSize uint64 `json:"compress_size,omitempty"` // 可压缩大小 + DecompressSize uint64 `json:"decompress_size,omitempty"` + OneTimeDownload bool `json:"one_time_download,omitempty"` + ShareDownload bool `json:"share_download,omitempty"` + ShareFree bool `json:"share_free,omitempty"` + Aria2 bool `json:"aria2,omitempty"` // 离线下载 + Aria2Options map[string]interface{} `json:"aria2_options,omitempty"` // 离线下载用户组配置 } // GetGroupByID 用ID获取用户组 diff --git a/models/migration.go b/models/migration.go index fc25ca4..c6de4cd 100644 --- a/models/migration.go +++ b/models/migration.go @@ -140,10 +140,11 @@ Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; verti {Name: "gravatar_server", Value: `https://gravatar.loli.net/`, Type: "avatar"}, {Name: "defaultTheme", Value: `#3f51b5`, Type: "basic"}, {Name: "themes", Value: `{"#3f51b5":{"palette":{"primary":{"main":"#3f51b5"},"secondary":{"main":"#f50057"}}},"#2196f3":{"palette":{"primary":{"main":"#2196f3"},"secondary":{"main":"#FFC107"}}},"#673AB7":{"palette":{"primary":{"main":"#673AB7"},"secondary":{"main":"#2196F3"}}},"#E91E63":{"palette":{"primary":{"main":"#E91E63"},"secondary":{"main":"#42A5F5","contrastText":"#fff"}}},"#FF5722":{"palette":{"primary":{"main":"#FF5722"},"secondary":{"main":"#3F51B5"}}},"#FFC107":{"palette":{"primary":{"main":"#FFC107"},"secondary":{"main":"#26C6DA"}}},"#8BC34A":{"palette":{"primary":{"main":"#8BC34A","contrastText":"#fff"},"secondary":{"main":"#FF8A65","contrastText":"#fff"}}},"#009688":{"palette":{"primary":{"main":"#009688"},"secondary":{"main":"#4DD0E1","contrastText":"#fff"}}},"#607D8B":{"palette":{"primary":{"main":"#607D8B"},"secondary":{"main":"#F06292"}}},"#795548":{"palette":{"primary":{"main":"#795548"},"secondary":{"main":"#4CAF50","contrastText":"#fff"}}}}`, Type: "basic"}, - {Name: "aria2_token", Value: `your token`, Type: "aria2"}, + {Name: "aria2_token", Value: ``, Type: "aria2"}, + {Name: "aria2_rpcurl", Value: ``, Type: "aria2"}, {Name: "aria2_temp_path", Value: ``, Type: "aria2"}, - {Name: "aria2_options", Value: `[]`, Type: "aria2"}, - {Name: "aria2_interval", Value: `10`, Type: "aria2"}, + {Name: "aria2_options", Value: `{}`, Type: "aria2"}, + {Name: "aria2_interval", Value: `60`, Type: "aria2"}, {Name: "max_worker_num", Value: `10`, Type: "task"}, {Name: "max_parallel_transfer", Value: `4`, Type: "task"}, {Name: "secret_key", Value: util.RandStringRunes(256), Type: "auth"}, diff --git a/pkg/aria2/aria2.go b/pkg/aria2/aria2.go index d492938..413b5a2 100644 --- a/pkg/aria2/aria2.go +++ b/pkg/aria2/aria2.go @@ -7,18 +7,22 @@ import ( "github.com/HFO4/cloudreve/pkg/serializer" "github.com/HFO4/cloudreve/pkg/util" "net/url" + "sync" ) // Instance 默认使用的Aria2处理实例 var Instance Aria2 = &DummyAria2{} +// Lock Instance的读写锁 +var Lock sync.RWMutex + // EventNotifier 任务状态更新通知处理器 var EventNotifier = &Notifier{} // Aria2 离线下载处理接口 type Aria2 interface { // CreateTask 创建新的任务 - CreateTask(task *model.Download, options []interface{}) error + CreateTask(task *model.Download, options map[string]interface{}) error // 返回状态信息 Status(task *model.Download) (rpc.StatusInfo, error) // 取消任务 @@ -63,7 +67,7 @@ type DummyAria2 struct { } // CreateTask 创建新任务,此处直接返回未开启错误 -func (instance *DummyAria2) CreateTask(model *model.Download, options []interface{}) error { +func (instance *DummyAria2) CreateTask(model *model.Download, options map[string]interface{}) error { return ErrNotEnabled } @@ -83,7 +87,16 @@ func (instance *DummyAria2) Select(task *model.Download, files []int) error { } // Init 初始化 -func Init() { +func Init(isReload bool) { + Lock.Lock() + defer Lock.Unlock() + + // 关闭上个初始连接 + if previousClient, ok := Instance.(*RPCService); ok { + util.Log().Debug("关闭上个 aria2 连接") + previousClient.caller.Close() + } + options := model.GetSettingByNames("aria2_rpcurl", "aria2_token", "aria2_options") timeout := model.GetIntSetting("aria2_call_timeout", 5) if options["aria2_rpcurl"] == "" { @@ -93,9 +106,6 @@ func Init() { util.Log().Info("初始化 aria2 RPC 服务[%s]", options["aria2_rpcurl"]) client := &RPCService{} - if previousClient, ok := Instance.(*RPCService); ok { - client = previousClient - } // 解析RPC服务地址 server, err := url.Parse(options["aria2_rpcurl"]) @@ -107,7 +117,7 @@ func Init() { server.Path = "/jsonrpc" // 加载自定义下载配置 - var globalOptions []interface{} + var globalOptions map[string]interface{} err = json.Unmarshal([]byte(options["aria2_options"]), &globalOptions) if err != nil { util.Log().Warning("无法解析 aria2 全局配置,%s", err) @@ -123,13 +133,16 @@ func Init() { Instance = client - // 从数据库中读取未完成任务,创建监控 - unfinished := model.GetDownloadsByStatus(Ready, Paused, Downloading) + if !isReload { + // 从数据库中读取未完成任务,创建监控 + unfinished := model.GetDownloadsByStatus(Ready, Paused, Downloading) - for i := 0; i < len(unfinished); i++ { - // 创建任务监控 - NewMonitor(&unfinished[i]) + for i := 0; i < len(unfinished); i++ { + // 创建任务监控 + NewMonitor(&unfinished[i]) + } } + } // getStatus 将给定的状态字符串转换为状态标识数字 diff --git a/pkg/aria2/aria2_test.go b/pkg/aria2/aria2_test.go index aebbe37..0cca187 100644 --- a/pkg/aria2/aria2_test.go +++ b/pkg/aria2/aria2_test.go @@ -45,14 +45,14 @@ func TestInit(t *testing.T) { // 未指定RPC地址,跳过 { cache.Set("setting_aria2_rpcurl", "", 0) - Init() + Init(false) asserts.IsType(&DummyAria2{}, Instance) } // 无法解析服务器地址 { cache.Set("setting_aria2_rpcurl", string(byte(0x7f)), 0) - Init() + Init(false) asserts.IsType(&DummyAria2{}, Instance) } @@ -61,7 +61,7 @@ func TestInit(t *testing.T) { Instance = &RPCService{} cache.Set("setting_aria2_options", "?", 0) cache.Set("setting_aria2_rpcurl", "ws://127.0.0.1:1234", 0) - Init() + Init(false) asserts.IsType(&DummyAria2{}, Instance) } @@ -72,7 +72,7 @@ func TestInit(t *testing.T) { cache.Set("setting_aria2_call_timeout", "1", 0) cache.Set("setting_aria2_interval", "100", 0) mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"g_id"}).AddRow("1")) - Init() + Init(false) asserts.NoError(mock.ExpectationsWereMet()) asserts.IsType(&RPCService{}, Instance) } diff --git a/pkg/aria2/caller.go b/pkg/aria2/caller.go index bb068dc..638792a 100644 --- a/pkg/aria2/caller.go +++ b/pkg/aria2/caller.go @@ -18,11 +18,11 @@ type RPCService struct { } type clientOptions struct { - Options []interface{} // 创建下载时额外添加的设置 + Options map[string]interface{} // 创建下载时额外添加的设置 } // Init 初始化 -func (client *RPCService) Init(server, secret string, timeout int, options []interface{}) error { +func (client *RPCService) Init(server, secret string, timeout int, options map[string]interface{}) error { // 客户端已存在,则关闭先前连接 if client.caller != nil { client.caller.Close() @@ -84,7 +84,7 @@ func (client *RPCService) Select(task *model.Download, files []int) error { } // CreateTask 创建新任务 -func (client *RPCService) CreateTask(task *model.Download, groupOptions []interface{}) error { +func (client *RPCService) CreateTask(task *model.Download, groupOptions map[string]interface{}) error { // 生成存储路径 path := filepath.Join( model.GetSettingByName("aria2_temp_path"), @@ -93,13 +93,17 @@ func (client *RPCService) CreateTask(task *model.Download, groupOptions []interf ) // 创建下载任务 - options := []interface{}{map[string]string{"dir": path}} - if len(client.options.Options) > 0 { - options = append(options, client.options.Options...) + options := map[string]interface{}{ + "dir": path, + } + for k, v := range client.options.Options { + options[k] = v + } + for k, v := range groupOptions { + options[k] = v } - options = append(options, groupOptions...) - gid, err := client.caller.AddURI(task.Source, options...) + gid, err := client.caller.AddURI(task.Source, options) if err != nil || gid == "" { return err } diff --git a/pkg/aria2/caller_test.go b/pkg/aria2/caller_test.go index 7bfec67..ca065b8 100644 --- a/pkg/aria2/caller_test.go +++ b/pkg/aria2/caller_test.go @@ -46,6 +46,6 @@ func TestRPCService_CreateTask(t *testing.T) { caller := &RPCService{} asserts.NoError(caller.Init("http://127.0.0.1", "", 1, nil)) cache.Set("setting_aria2_temp_path", "test", 0) - err := caller.CreateTask(&model.Download{Parent: "test"}, []interface{}{map[string]string{"1": "1"}}) + err := caller.CreateTask(&model.Download{Parent: "test"}, map[string]interface{}{"1": "1"}) asserts.Error(err) } diff --git a/pkg/aria2/monitor.go b/pkg/aria2/monitor.go index f5ce324..eeb9226 100644 --- a/pkg/aria2/monitor.go +++ b/pkg/aria2/monitor.go @@ -69,7 +69,10 @@ func (monitor *Monitor) Loop() { // Update 更新状态,返回值表示是否退出监控 func (monitor *Monitor) Update() bool { + Lock.RLock() status, err := Instance.Status(monitor.Task) + Lock.RUnlock() + if err != nil { monitor.retried++ util.Log().Warning("无法获取下载任务[%s]的状态,%s", monitor.Task.GID, err) @@ -160,7 +163,9 @@ func (monitor *Monitor) UpdateTaskInfo(status rpc.StatusInfo) error { // 文件大小更新后,对文件限制等进行校验 if err := monitor.ValidateFile(); err != nil { // 验证失败时取消任务 + Lock.RLock() Instance.Cancel(monitor.Task) + Lock.RUnlock() return err } } diff --git a/routers/controllers/admin.go b/routers/controllers/admin.go index a0c1349..749264f 100644 --- a/routers/controllers/admin.go +++ b/routers/controllers/admin.go @@ -1,6 +1,7 @@ package controllers import ( + "github.com/HFO4/cloudreve/pkg/aria2" "github.com/HFO4/cloudreve/pkg/email" "github.com/HFO4/cloudreve/pkg/request" "github.com/HFO4/cloudreve/pkg/serializer" @@ -66,6 +67,8 @@ func AdminReloadService(c *gin.Context) { switch service { case "email": email.Init() + case "aria2": + aria2.Init(true) } c.JSON(200, serializer.Response{}) diff --git a/service/aria2/add.go b/service/aria2/add.go index 2b67578..8f552a2 100644 --- a/service/aria2/add.go +++ b/service/aria2/add.go @@ -41,9 +41,13 @@ func (service *AddURLService) Add(c *gin.Context, taskType int) serializer.Respo UserID: fs.User.ID, Source: service.URL, } + + aria2.Lock.RLock() if err := aria2.Instance.CreateTask(task, fs.User.Group.OptionsSerialized.Aria2Options); err != nil { + aria2.Lock.RUnlock() return serializer.Err(serializer.CodeNotSet, "任务创建失败", err) } + aria2.Lock.RUnlock() return serializer.Response{} } diff --git a/service/aria2/manage.go b/service/aria2/manage.go index f36ab30..4d3ab7d 100644 --- a/service/aria2/manage.go +++ b/service/aria2/manage.go @@ -52,6 +52,8 @@ func (service *DownloadTaskService) Delete(c *gin.Context) serializer.Response { } // 取消任务 + aria2.Lock.RLock() + defer aria2.Lock.RUnlock() if err := aria2.Instance.Cancel(download); err != nil { return serializer.Err(serializer.CodeNotSet, "操作失败", err) } @@ -75,6 +77,8 @@ func (service *SelectFileService) Select(c *gin.Context) serializer.Response { } // 选取下载 + aria2.Lock.RLock() + defer aria2.Lock.RUnlock() if err := aria2.Instance.Select(download, service.Indexes); err != nil { return serializer.Err(serializer.CodeNotSet, "操作失败", err) }