From 416f4c1dd22ccdf3e96ee848261f3f78976403be Mon Sep 17 00:00:00 2001
From: HFO4 <912394456@qq.com>
Date: Thu, 11 Nov 2021 20:56:16 +0800
Subject: [PATCH] Test: balancer / auth / controller in pkg

---
 bootstrap/init.go               |   3 +-
 pkg/aria2/aria2.go              |   4 +-
 pkg/aria2/aria2_test.go         |  86 ++++-------
 pkg/aria2/caller.go             | 114 --------------
 pkg/aria2/caller_test.go        |  52 -------
 pkg/aria2/notification_test.go  |  52 -------
 pkg/auth/auth.go                |   8 +-
 pkg/auth/auth_test.go           |  13 ++
 pkg/auth/hmac.go                |   2 +-
 pkg/balancer/balancer_test.go   |  12 ++
 pkg/balancer/roundrobin_test.go |  42 ++++++
 pkg/cluster/controller_test.go  | 254 ++++++++++++++++++++++++++++++++
 pkg/mocks/mocks.go              |  10 ++
 routers/controllers/admin.go    |   4 +-
 service/aria2/add.go            |   4 +-
 15 files changed, 375 insertions(+), 285 deletions(-)
 delete mode 100644 pkg/aria2/caller.go
 delete mode 100644 pkg/aria2/caller_test.go
 delete mode 100644 pkg/aria2/notification_test.go
 create mode 100644 pkg/balancer/balancer_test.go
 create mode 100644 pkg/balancer/roundrobin_test.go
 create mode 100644 pkg/cluster/controller_test.go

diff --git a/bootstrap/init.go b/bootstrap/init.go
index 60ebcb8..0a51835 100644
--- a/bootstrap/init.go
+++ b/bootstrap/init.go
@@ -9,6 +9,7 @@ import (
 	"github.com/cloudreve/Cloudreve/v3/pkg/conf"
 	"github.com/cloudreve/Cloudreve/v3/pkg/crontab"
 	"github.com/cloudreve/Cloudreve/v3/pkg/email"
+	"github.com/cloudreve/Cloudreve/v3/pkg/mq"
 	"github.com/cloudreve/Cloudreve/v3/pkg/task"
 	"github.com/gin-gonic/gin"
 )
@@ -53,7 +54,7 @@ func Init(path string) {
 		{
 			"master",
 			func() {
-				aria2.Init(false)
+				aria2.Init(false, cluster.Default, mq.GlobalMQ)
 			},
 		},
 		{
diff --git a/pkg/aria2/aria2.go b/pkg/aria2/aria2.go
index ef2f0df..60d254e 100644
--- a/pkg/aria2/aria2.go
+++ b/pkg/aria2/aria2.go
@@ -33,7 +33,7 @@ func GetLoadBalancer() balancer.Balancer {
 }
 
 // Init 初始化
-func Init(isReload bool) {
+func Init(isReload bool, pool cluster.Pool, mqClient mq.MQ) {
 	Lock.Lock()
 	LB = balancer.NewBalancer("RoundRobin")
 	Lock.Unlock()
@@ -44,7 +44,7 @@ func Init(isReload bool) {
 
 		for i := 0; i < len(unfinished); i++ {
 			// 创建任务监控
-			monitor.NewMonitor(&unfinished[i], cluster.Default, mq.GlobalMQ)
+			monitor.NewMonitor(&unfinished[i], pool, mqClient)
 		}
 	}
 }
diff --git a/pkg/aria2/aria2_test.go b/pkg/aria2/aria2_test.go
index dfd71a3..b6e7092 100644
--- a/pkg/aria2/aria2_test.go
+++ b/pkg/aria2/aria2_test.go
@@ -2,14 +2,15 @@ package aria2
 
 import (
 	"database/sql"
+	"github.com/cloudreve/Cloudreve/v3/pkg/mocks"
+	"github.com/cloudreve/Cloudreve/v3/pkg/mq"
+	"github.com/stretchr/testify/assert"
+	testMock "github.com/stretchr/testify/mock"
 	"testing"
 
 	"github.com/DATA-DOG/go-sqlmock"
 	model "github.com/cloudreve/Cloudreve/v3/models"
-	"github.com/cloudreve/Cloudreve/v3/pkg/aria2/monitor"
-	"github.com/cloudreve/Cloudreve/v3/pkg/cache"
 	"github.com/jinzhu/gorm"
-	"github.com/stretchr/testify/assert"
 )
 
 var mock sqlmock.Sqlmock
@@ -27,66 +28,39 @@ func TestMain(m *testing.M) {
 	m.Run()
 }
 
-func TestDummyAria2(t *testing.T) {
-	asserts := assert.New(t)
-	instance := DummyAria2{}
-	asserts.Error(instance.CreateTask(nil, nil))
-	_, err := instance.Status(nil)
-	asserts.Error(err)
-	asserts.Error(instance.Cancel(nil))
-	asserts.Error(instance.Select(nil, nil))
-}
-
 func TestInit(t *testing.T) {
-	monitor.MAX_RETRY = 0
-	asserts := assert.New(t)
-	cache.Set("setting_aria2_token", "1", 0)
-	cache.Set("setting_aria2_call_timeout", "5", 0)
-	cache.Set("setting_aria2_options", `[]`, 0)
+	a := assert.New(t)
+	mockPool := &mocks.NodePoolMock{}
+	mockPool.On("GetNodeByID", testMock.Anything).Return(nil)
+	mockQueue := mq.NewMQ()
 
-	// 未指定RPC地址,跳过
+	mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1))
+	Init(false, mockPool, mockQueue)
+	a.NoError(mock.ExpectationsWereMet())
+	mockPool.AssertExpectations(t)
+}
+
+func TestTestRPCConnection(t *testing.T) {
+	a := assert.New(t)
+
+	// url not legal
 	{
-		cache.Set("setting_aria2_rpcurl", "", 0)
-		Init(false)
-		asserts.IsType(&DummyAria2{}, Instance)
+		res, err := TestRPCConnection(string([]byte{0x7f}), "", 10)
+		a.Error(err)
+		a.Empty(res.Version)
 	}
 
-	// 无法解析服务器地址
+	// rpc failed
 	{
-		cache.Set("setting_aria2_rpcurl", string(byte(0x7f)), 0)
-		Init(false)
-		asserts.IsType(&DummyAria2{}, Instance)
-	}
-
-	// 无法解析全局配置
-	{
-		Instance = &RPCService{}
-		cache.Set("setting_aria2_options", "?", 0)
-		cache.Set("setting_aria2_rpcurl", "ws://127.0.0.1:1234", 0)
-		Init(false)
-		asserts.IsType(&DummyAria2{}, Instance)
-	}
-
-	// 连接失败
-	{
-		cache.Set("setting_aria2_options", "{}", 0)
-		cache.Set("setting_aria2_rpcurl", "http://127.0.0.1:1234", 0)
-		cache.Set("setting_aria2_call_timeout", "1", 0)
-		cache.Set("setting_aria2_interval", "100", 0)
-		mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"g_id"}).AddRow("1"))
-		Init(false)
-		asserts.NoError(mock.ExpectationsWereMet())
-		asserts.IsType(&RPCService{}, Instance)
+		res, err := TestRPCConnection("ws://0.0.0.0", "", 0)
+		a.Error(err)
+		a.Empty(res.Version)
 	}
 }
 
-func TestGetStatus(t *testing.T) {
-	asserts := assert.New(t)
-	asserts.Equal(4, GetStatus("complete"))
-	asserts.Equal(1, GetStatus("active"))
-	asserts.Equal(0, GetStatus("waiting"))
-	asserts.Equal(2, GetStatus("paused"))
-	asserts.Equal(3, GetStatus("error"))
-	asserts.Equal(5, GetStatus("removed"))
-	asserts.Equal(6, GetStatus("?"))
+func TestGetLoadBalancer(t *testing.T) {
+	a := assert.New(t)
+	a.NotPanics(func() {
+		GetLoadBalancer()
+	})
 }
diff --git a/pkg/aria2/caller.go b/pkg/aria2/caller.go
deleted file mode 100644
index 70e0bea..0000000
--- a/pkg/aria2/caller.go
+++ /dev/null
@@ -1,114 +0,0 @@
-package aria2
-
-import (
-	"context"
-	"path/filepath"
-	"strconv"
-	"strings"
-	"time"
-
-	model "github.com/cloudreve/Cloudreve/v3/models"
-	"github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc"
-	"github.com/cloudreve/Cloudreve/v3/pkg/mq"
-	"github.com/cloudreve/Cloudreve/v3/pkg/util"
-)
-
-// RPCService 通过RPC服务的Aria2任务管理器
-type RPCService struct {
-	options *clientOptions
-	Caller  rpc.Client
-}
-
-type clientOptions struct {
-	Options map[string]interface{} // 创建下载时额外添加的设置
-}
-
-// Init 初始化
-func (client *RPCService) Init(server, secret string, timeout int, options map[string]interface{}) error {
-	// 客户端已存在,则关闭先前连接
-	if client.Caller != nil {
-		client.Caller.Close()
-	}
-
-	client.options = &clientOptions{
-		Options: options,
-	}
-	caller, err := rpc.New(context.Background(), server, secret, time.Duration(timeout)*time.Second,
-		mq.GlobalMQ)
-	client.Caller = caller
-	return err
-}
-
-// Status 查询下载状态
-func (client *RPCService) Status(task *model.Download) (rpc.StatusInfo, error) {
-	res, err := client.Caller.TellStatus(task.GID)
-	if err != nil {
-		// 失败后重试
-		util.Log().Debug("无法获取离线下载状态,%s,10秒钟后重试", err)
-		time.Sleep(time.Duration(10) * time.Second)
-		res, err = client.Caller.TellStatus(task.GID)
-	}
-
-	return res, err
-}
-
-// Cancel 取消下载
-func (client *RPCService) Cancel(task *model.Download) error {
-	// 取消下载任务
-	_, err := client.Caller.Remove(task.GID)
-	if err != nil {
-		util.Log().Warning("无法取消离线下载任务[%s], %s", task.GID, err)
-	}
-
-	//// 删除临时文件
-	//util.Log().Debug("离线下载任务[%s]已取消,1 分钟后删除临时文件", task.GID)
-	//go func(task *model.Download) {
-	//	select {
-	//	case <-time.After(time.Duration(60) * time.Second):
-	//		err := os.RemoveAll(task.Parent)
-	//		if err != nil {
-	//			util.Log().Warning("无法删除离线下载临时目录[%s], %s", task.Parent, err)
-	//		}
-	//	}
-	//}(task)
-
-	return err
-}
-
-// Select 选取要下载的文件
-func (client *RPCService) Select(task *model.Download, files []int) error {
-	var selected = make([]string, len(files))
-	for i := 0; i < len(files); i++ {
-		selected[i] = strconv.Itoa(files[i])
-	}
-	_, err := client.Caller.ChangeOption(task.GID, map[string]interface{}{"select-file": strings.Join(selected, ",")})
-	return err
-}
-
-// CreateTask 创建新任务
-func (client *RPCService) CreateTask(task *model.Download, groupOptions map[string]interface{}) (string, error) {
-	// 生成存储路径
-	path := filepath.Join(
-		model.GetSettingByName("aria2_temp_path"),
-		"aria2",
-		strconv.FormatInt(time.Now().UnixNano(), 10),
-	)
-
-	// 创建下载任务
-	options := map[string]interface{}{
-		"dir": path,
-	}
-	for k, v := range client.options.Options {
-		options[k] = v
-	}
-	for k, v := range groupOptions {
-		options[k] = v
-	}
-
-	gid, err := client.Caller.AddURI(task.Source, options)
-	if err != nil || gid == "" {
-		return "", err
-	}
-
-	return gid, nil
-}
diff --git a/pkg/aria2/caller_test.go b/pkg/aria2/caller_test.go
deleted file mode 100644
index f215689..0000000
--- a/pkg/aria2/caller_test.go
+++ /dev/null
@@ -1,52 +0,0 @@
-package aria2
-
-import (
-	"testing"
-
-	model "github.com/cloudreve/Cloudreve/v3/models"
-	"github.com/cloudreve/Cloudreve/v3/pkg/cache"
-	"github.com/stretchr/testify/assert"
-)
-
-func TestRPCService_Init(t *testing.T) {
-	asserts := assert.New(t)
-	caller := &RPCService{}
-	asserts.Error(caller.Init("ws://", "", 1, nil))
-	asserts.NoError(caller.Init("http://127.0.0.1", "", 1, nil))
-}
-
-func TestRPCService_Status(t *testing.T) {
-	asserts := assert.New(t)
-	caller := &RPCService{}
-	asserts.NoError(caller.Init("http://127.0.0.1", "", 1, nil))
-
-	_, err := caller.Status(&model.Download{})
-	asserts.Error(err)
-}
-
-func TestRPCService_Cancel(t *testing.T) {
-	asserts := assert.New(t)
-	caller := &RPCService{}
-	asserts.NoError(caller.Init("http://127.0.0.1", "", 1, nil))
-
-	err := caller.Cancel(&model.Download{Parent: "test"})
-	asserts.Error(err)
-}
-
-func TestRPCService_Select(t *testing.T) {
-	asserts := assert.New(t)
-	caller := &RPCService{}
-	asserts.NoError(caller.Init("http://127.0.0.1", "", 1, nil))
-
-	err := caller.Select(&model.Download{Parent: "test"}, []int{1, 2, 3})
-	asserts.Error(err)
-}
-
-func TestRPCService_CreateTask(t *testing.T) {
-	asserts := assert.New(t)
-	caller := &RPCService{}
-	asserts.NoError(caller.Init("http://127.0.0.1", "", 1, nil))
-	cache.Set("setting_aria2_temp_path", "test", 0)
-	err := caller.CreateTask(&model.Download{Parent: "test"}, map[string]interface{}{"1": "1"})
-	asserts.Error(err)
-}
diff --git a/pkg/aria2/notification_test.go b/pkg/aria2/notification_test.go
deleted file mode 100644
index 21a7ac1..0000000
--- a/pkg/aria2/notification_test.go
+++ /dev/null
@@ -1,52 +0,0 @@
-package aria2
-
-import (
-	"testing"
-
-	"github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc"
-	"github.com/stretchr/testify/assert"
-)
-
-func TestNotifier_Notify(t *testing.T) {
-	asserts := assert.New(t)
-	notifier2 := &Notifier{}
-	notifyChan := make(chan StatusEvent, 10)
-	notifier2.Subscribe(notifyChan, "1")
-
-	// 未订阅
-	{
-		notifier2.Notify([]rpc.Event{rpc.Event{Gid: ""}}, 1)
-		asserts.Len(notifyChan, 0)
-	}
-
-	// 订阅
-	{
-		notifier2.Notify([]rpc.Event{{Gid: "1"}}, 1)
-		asserts.Len(notifyChan, 1)
-		<-notifyChan
-
-		notifier2.OnBtDownloadComplete([]rpc.Event{{Gid: "1"}})
-		asserts.Len(notifyChan, 1)
-		<-notifyChan
-
-		notifier2.OnDownloadStart([]rpc.Event{{Gid: "1"}})
-		asserts.Len(notifyChan, 1)
-		<-notifyChan
-
-		notifier2.OnDownloadPause([]rpc.Event{{Gid: "1"}})
-		asserts.Len(notifyChan, 1)
-		<-notifyChan
-
-		notifier2.OnDownloadStop([]rpc.Event{{Gid: "1"}})
-		asserts.Len(notifyChan, 1)
-		<-notifyChan
-
-		notifier2.OnDownloadComplete([]rpc.Event{{Gid: "1"}})
-		asserts.Len(notifyChan, 1)
-		<-notifyChan
-
-		notifier2.OnDownloadError([]rpc.Event{{Gid: "1"}})
-		asserts.Len(notifyChan, 1)
-		<-notifyChan
-	}
-}
diff --git a/pkg/auth/auth.go b/pkg/auth/auth.go
index d8250e8..20b9e10 100644
--- a/pkg/auth/auth.go
+++ b/pkg/auth/auth.go
@@ -17,8 +17,10 @@ import (
 )
 
 var (
-	ErrAuthFailed = serializer.NewError(serializer.CodeNoPermissionErr, "鉴权失败", nil)
-	ErrExpired    = serializer.NewError(serializer.CodeSignExpired, "签名已过期", nil)
+	ErrAuthFailed        = serializer.NewError(serializer.CodeNoPermissionErr, "鉴权失败", nil)
+	ErrAuthHeaderMissing = serializer.NewError(serializer.CodeNoPermissionErr, "authorization header is missing", nil)
+	ErrExpiresMissing    = serializer.NewError(serializer.CodeNoPermissionErr, "expire timestamp is missing", nil)
+	ErrExpired           = serializer.NewError(serializer.CodeSignExpired, "签名已过期", nil)
 )
 
 // General 通用的认证接口
@@ -55,7 +57,7 @@ func CheckRequest(instance Auth, r *http.Request) error {
 		ok   bool
 	)
 	if sign, ok = r.Header["Authorization"]; !ok || len(sign) == 0 {
-		return ErrAuthFailed
+		return ErrAuthHeaderMissing
 	}
 	sign[0] = strings.TrimPrefix(sign[0], "Bearer ")
 
diff --git a/pkg/auth/auth_test.go b/pkg/auth/auth_test.go
index 1092cb5..46533fb 100644
--- a/pkg/auth/auth_test.go
+++ b/pkg/auth/auth_test.go
@@ -80,6 +80,19 @@ func TestCheckRequest(t *testing.T) {
 	asserts := assert.New(t)
 	General = HMACAuth{SecretKey: []byte(util.RandStringRunes(256))}
 
+	// 缺少请求头
+	{
+		req, err := http.NewRequest(
+			"POST",
+			"http://127.0.0.1/api/v3/upload",
+			strings.NewReader("I am body."),
+		)
+		asserts.NoError(err)
+		err = CheckRequest(General, req)
+		asserts.Error(err)
+		asserts.Equal(ErrAuthHeaderMissing, err)
+	}
+
 	// 非上传请求 验证成功
 	{
 		req, err := http.NewRequest(
diff --git a/pkg/auth/hmac.go b/pkg/auth/hmac.go
index e0a9573..50849cc 100644
--- a/pkg/auth/hmac.go
+++ b/pkg/auth/hmac.go
@@ -33,7 +33,7 @@ func (auth HMACAuth) Check(body string, sign string) error {
 	signSlice := strings.Split(sign, ":")
 	// 如果未携带expires字段
 	if signSlice[len(signSlice)-1] == "" {
-		return ErrAuthFailed
+		return ErrExpiresMissing
 	}
 
 	// 验证是否过期
diff --git a/pkg/balancer/balancer_test.go b/pkg/balancer/balancer_test.go
new file mode 100644
index 0000000..4493bbb
--- /dev/null
+++ b/pkg/balancer/balancer_test.go
@@ -0,0 +1,12 @@
+package balancer
+
+import (
+	"github.com/stretchr/testify/assert"
+	"testing"
+)
+
+func TestNewBalancer(t *testing.T) {
+	a := assert.New(t)
+	a.NotNil(NewBalancer(""))
+	a.IsType(&RoundRobin{}, NewBalancer("RoundRobin"))
+}
diff --git a/pkg/balancer/roundrobin_test.go b/pkg/balancer/roundrobin_test.go
new file mode 100644
index 0000000..9cdcc00
--- /dev/null
+++ b/pkg/balancer/roundrobin_test.go
@@ -0,0 +1,42 @@
+package balancer
+
+import (
+	"github.com/stretchr/testify/assert"
+	"testing"
+)
+
+func TestRoundRobin_NextIndex(t *testing.T) {
+	a := assert.New(t)
+	r := &RoundRobin{}
+	total := 5
+	for i := 1; i < total; i++ {
+		a.Equal(i, r.NextIndex(total))
+	}
+	for i := 0; i < total; i++ {
+		a.Equal(i, r.NextIndex(total))
+	}
+}
+
+func TestRoundRobin_NextPeer(t *testing.T) {
+	a := assert.New(t)
+	r := &RoundRobin{}
+
+	// not slice
+	{
+		err, _ := r.NextPeer("s")
+		a.Equal(ErrInputNotSlice, err)
+	}
+
+	// no nodes
+	{
+		err, _ := r.NextPeer([]string{})
+		a.Equal(ErrNoAvaliableNode, err)
+	}
+
+	// pass
+	{
+		err, res := r.NextPeer([]string{"a"})
+		a.NoError(err)
+		a.Equal("a", res.(string))
+	}
+}
diff --git a/pkg/cluster/controller_test.go b/pkg/cluster/controller_test.go
new file mode 100644
index 0000000..0ee8651
--- /dev/null
+++ b/pkg/cluster/controller_test.go
@@ -0,0 +1,254 @@
+package cluster
+
+import (
+	model "github.com/cloudreve/Cloudreve/v3/models"
+	"github.com/cloudreve/Cloudreve/v3/pkg/aria2/common"
+	"github.com/cloudreve/Cloudreve/v3/pkg/auth"
+	"github.com/cloudreve/Cloudreve/v3/pkg/mq"
+	"github.com/cloudreve/Cloudreve/v3/pkg/request"
+	"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
+	"github.com/stretchr/testify/assert"
+	testMock "github.com/stretchr/testify/mock"
+	"io"
+	"io/ioutil"
+	"net/http"
+	"strings"
+	"testing"
+)
+
+func TestInitController(t *testing.T) {
+	assert.NotPanics(t, func() {
+		InitController()
+	})
+}
+
+func TestSlaveController_HandleHeartBeat(t *testing.T) {
+	a := assert.New(t)
+	c := &slaveController{
+		masters: make(map[string]MasterInfo),
+	}
+
+	// first heart beat
+	{
+		_, err := c.HandleHeartBeat(&serializer.NodePingReq{
+			SiteID: "1",
+			Node:   &model.Node{},
+		})
+		a.NoError(err)
+
+		_, err = c.HandleHeartBeat(&serializer.NodePingReq{
+			SiteID: "2",
+			Node:   &model.Node{},
+		})
+		a.NoError(err)
+
+		a.Len(c.masters, 2)
+	}
+
+	// second heart beat, no fresh
+	{
+		_, err := c.HandleHeartBeat(&serializer.NodePingReq{
+			SiteID:  "1",
+			SiteURL: "http://127.0.0.1",
+			Node:    &model.Node{},
+		})
+		a.NoError(err)
+		a.Len(c.masters, 2)
+		a.Empty(c.masters["1"].URL)
+	}
+
+	// second heart beat, fresh
+	{
+		_, err := c.HandleHeartBeat(&serializer.NodePingReq{
+			SiteID:   "1",
+			IsUpdate: true,
+			SiteURL:  "http://127.0.0.1",
+			Node:     &model.Node{},
+		})
+		a.NoError(err)
+		a.Len(c.masters, 2)
+		a.Equal("http://127.0.0.1", c.masters["1"].URL.String())
+	}
+
+	// second heart beat, fresh, url illegal
+	{
+		_, err := c.HandleHeartBeat(&serializer.NodePingReq{
+			SiteID:   "1",
+			IsUpdate: true,
+			SiteURL:  string([]byte{0x7f}),
+			Node:     &model.Node{},
+		})
+		a.Error(err)
+		a.Len(c.masters, 2)
+		a.Equal("http://127.0.0.1", c.masters["1"].URL.String())
+	}
+}
+
+type nodeMock struct {
+	testMock.Mock
+}
+
+func (n nodeMock) Init(node *model.Node) {
+	n.Called(node)
+}
+
+func (n nodeMock) IsFeatureEnabled(feature string) bool {
+	args := n.Called(feature)
+	return args.Bool(0)
+}
+
+func (n nodeMock) SubscribeStatusChange(callback func(isActive bool, id uint)) {
+	n.Called(callback)
+}
+
+func (n nodeMock) Ping(req *serializer.NodePingReq) (*serializer.NodePingResp, error) {
+	args := n.Called(req)
+	return args.Get(0).(*serializer.NodePingResp), args.Error(1)
+}
+
+func (n nodeMock) IsActive() bool {
+	args := n.Called()
+	return args.Bool(0)
+}
+
+func (n nodeMock) GetAria2Instance() common.Aria2 {
+	args := n.Called()
+	return args.Get(0).(common.Aria2)
+}
+
+func (n nodeMock) ID() uint {
+	args := n.Called()
+	return args.Get(0).(uint)
+}
+
+func (n nodeMock) Kill() {
+	n.Called()
+}
+
+func (n nodeMock) IsMater() bool {
+	args := n.Called()
+	return args.Bool(0)
+}
+
+func (n nodeMock) MasterAuthInstance() auth.Auth {
+	args := n.Called()
+	return args.Get(0).(auth.Auth)
+}
+
+func (n nodeMock) SlaveAuthInstance() auth.Auth {
+	args := n.Called()
+	return args.Get(0).(auth.Auth)
+}
+
+func (n nodeMock) DBModel() *model.Node {
+	args := n.Called()
+	return args.Get(0).(*model.Node)
+}
+
+func TestSlaveController_GetAria2Instance(t *testing.T) {
+	a := assert.New(t)
+	mockNode := &nodeMock{}
+	mockNode.On("GetAria2Instance").Return(&common.DummyAria2{})
+	c := &slaveController{
+		masters: map[string]MasterInfo{
+			"1": {Instance: mockNode},
+		},
+	}
+
+	// node node found
+	{
+		res, err := c.GetAria2Instance("2")
+		a.Nil(res)
+		a.Equal(ErrMasterNotFound, err)
+	}
+
+	// node found
+	{
+		res, err := c.GetAria2Instance("1")
+		a.NotNil(res)
+		a.NoError(err)
+		mockNode.AssertExpectations(t)
+	}
+
+}
+
+type requestMock struct {
+	testMock.Mock
+}
+
+func (r requestMock) Request(method, target string, body io.Reader, opts ...request.Option) *request.Response {
+	return r.Called(method, target, body, opts).Get(0).(*request.Response)
+}
+
+func TestSlaveController_SendNotification(t *testing.T) {
+	a := assert.New(t)
+	c := &slaveController{
+		masters: map[string]MasterInfo{
+			"1": {},
+		},
+	}
+
+	// node not exit
+	{
+		a.Equal(ErrMasterNotFound, c.SendNotification("2", "", mq.Message{}))
+	}
+
+	// gob encode error
+	{
+		type randomType struct{}
+		a.Error(c.SendNotification("1", "", mq.Message{
+			Content: randomType{},
+		}))
+	}
+
+	// return none 200
+	{
+		mockRequest := &requestMock{}
+		mockRequest.On("Request", "PUT", "/api/v3/slave/notification/s1", testMock.Anything, testMock.Anything).Return(&request.Response{
+			Response: &http.Response{StatusCode: http.StatusConflict},
+		})
+		c := &slaveController{
+			masters: map[string]MasterInfo{
+				"1": {Client: mockRequest},
+			},
+		}
+		a.Error(c.SendNotification("1", "s1", mq.Message{}))
+		mockRequest.AssertExpectations(t)
+	}
+
+	// master return error
+	{
+		mockRequest := &requestMock{}
+		mockRequest.On("Request", "PUT", "/api/v3/slave/notification/s2", testMock.Anything, testMock.Anything).Return(&request.Response{
+			Response: &http.Response{
+				StatusCode: 200,
+				Body:       ioutil.NopCloser(strings.NewReader("{\"code\":1}")),
+			},
+		})
+		c := &slaveController{
+			masters: map[string]MasterInfo{
+				"1": {Client: mockRequest},
+			},
+		}
+		a.Equal(1, c.SendNotification("1", "s2", mq.Message{}).(serializer.AppError).Code)
+		mockRequest.AssertExpectations(t)
+	}
+
+	// success
+	{
+		mockRequest := &requestMock{}
+		mockRequest.On("Request", "PUT", "/api/v3/slave/notification/s3", testMock.Anything, testMock.Anything).Return(&request.Response{
+			Response: &http.Response{
+				StatusCode: 200,
+				Body:       ioutil.NopCloser(strings.NewReader("{\"code\":0}")),
+			},
+		})
+		c := &slaveController{
+			masters: map[string]MasterInfo{
+				"1": {Client: mockRequest},
+			},
+		}
+		a.NoError(c.SendNotification("1", "s3", mq.Message{}))
+		mockRequest.AssertExpectations(t)
+	}
+}
diff --git a/pkg/mocks/mocks.go b/pkg/mocks/mocks.go
index 2134e86..6b7e674 100644
--- a/pkg/mocks/mocks.go
+++ b/pkg/mocks/mocks.go
@@ -8,9 +8,11 @@ import (
 	"github.com/cloudreve/Cloudreve/v3/pkg/balancer"
 	"github.com/cloudreve/Cloudreve/v3/pkg/cluster"
 	"github.com/cloudreve/Cloudreve/v3/pkg/mq"
+	"github.com/cloudreve/Cloudreve/v3/pkg/request"
 	"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
 	"github.com/cloudreve/Cloudreve/v3/pkg/task"
 	testMock "github.com/stretchr/testify/mock"
+	"io"
 )
 
 type SlaveControllerMock struct {
@@ -184,3 +186,11 @@ func (t TaskPoolMock) Add(num int) {
 func (t TaskPoolMock) Submit(job task.Job) {
 	t.Called(job)
 }
+
+type RequestMock struct {
+	testMock.Mock
+}
+
+func (r RequestMock) Request(method, target string, body io.Reader, opts ...request.Option) *request.Response {
+	return r.Called(method, target, body, opts).Get(0).(*request.Response)
+}
diff --git a/routers/controllers/admin.go b/routers/controllers/admin.go
index a3ebfa5..fb0d6d6 100644
--- a/routers/controllers/admin.go
+++ b/routers/controllers/admin.go
@@ -1,6 +1,8 @@
 package controllers
 
 import (
+	"github.com/cloudreve/Cloudreve/v3/pkg/cluster"
+	"github.com/cloudreve/Cloudreve/v3/pkg/mq"
 	"io"
 
 	model "github.com/cloudreve/Cloudreve/v3/models"
@@ -72,7 +74,7 @@ func AdminReloadService(c *gin.Context) {
 	case "email":
 		email.Init()
 	case "aria2":
-		aria2.Init(true)
+		aria2.Init(true, cluster.Default, mq.GlobalMQ)
 	}
 
 	c.JSON(200, serializer.Response{})
diff --git a/service/aria2/add.go b/service/aria2/add.go
index 8443c14..2c72c8b 100644
--- a/service/aria2/add.go
+++ b/service/aria2/add.go
@@ -48,9 +48,7 @@ func (service *AddURLService) Add(c *gin.Context, taskType int) serializer.Respo
 	}
 
 	// 获取 Aria2 负载均衡器
-	aria2.Lock.RLock()
-	lb := aria2.LB
-	aria2.Lock.RUnlock()
+	lb := aria2.GetLoadBalancer()
 
 	// 获取 Aria2 实例
 	err, node := cluster.Default.BalanceNodeByFeature("aria2", lb)