Test: balancer / auth / controller in pkg
This commit is contained in:
parent
f0089045d7
commit
416f4c1dd2
15 changed files with 375 additions and 285 deletions
|
@ -9,6 +9,7 @@ import (
|
|||
"github.com/cloudreve/Cloudreve/v3/pkg/conf"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/crontab"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/email"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/mq"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/task"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
@ -53,7 +54,7 @@ func Init(path string) {
|
|||
{
|
||||
"master",
|
||||
func() {
|
||||
aria2.Init(false)
|
||||
aria2.Init(false, cluster.Default, mq.GlobalMQ)
|
||||
},
|
||||
},
|
||||
{
|
||||
|
|
|
@ -33,7 +33,7 @@ func GetLoadBalancer() balancer.Balancer {
|
|||
}
|
||||
|
||||
// Init 初始化
|
||||
func Init(isReload bool) {
|
||||
func Init(isReload bool, pool cluster.Pool, mqClient mq.MQ) {
|
||||
Lock.Lock()
|
||||
LB = balancer.NewBalancer("RoundRobin")
|
||||
Lock.Unlock()
|
||||
|
@ -44,7 +44,7 @@ func Init(isReload bool) {
|
|||
|
||||
for i := 0; i < len(unfinished); i++ {
|
||||
// 创建任务监控
|
||||
monitor.NewMonitor(&unfinished[i], cluster.Default, mq.GlobalMQ)
|
||||
monitor.NewMonitor(&unfinished[i], pool, mqClient)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -2,14 +2,15 @@ package aria2
|
|||
|
||||
import (
|
||||
"database/sql"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/mocks"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/mq"
|
||||
"github.com/stretchr/testify/assert"
|
||||
testMock "github.com/stretchr/testify/mock"
|
||||
"testing"
|
||||
|
||||
"github.com/DATA-DOG/go-sqlmock"
|
||||
model "github.com/cloudreve/Cloudreve/v3/models"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/monitor"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/cache"
|
||||
"github.com/jinzhu/gorm"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
var mock sqlmock.Sqlmock
|
||||
|
@ -27,66 +28,39 @@ func TestMain(m *testing.M) {
|
|||
m.Run()
|
||||
}
|
||||
|
||||
func TestDummyAria2(t *testing.T) {
|
||||
asserts := assert.New(t)
|
||||
instance := DummyAria2{}
|
||||
asserts.Error(instance.CreateTask(nil, nil))
|
||||
_, err := instance.Status(nil)
|
||||
asserts.Error(err)
|
||||
asserts.Error(instance.Cancel(nil))
|
||||
asserts.Error(instance.Select(nil, nil))
|
||||
}
|
||||
|
||||
func TestInit(t *testing.T) {
|
||||
monitor.MAX_RETRY = 0
|
||||
asserts := assert.New(t)
|
||||
cache.Set("setting_aria2_token", "1", 0)
|
||||
cache.Set("setting_aria2_call_timeout", "5", 0)
|
||||
cache.Set("setting_aria2_options", `[]`, 0)
|
||||
a := assert.New(t)
|
||||
mockPool := &mocks.NodePoolMock{}
|
||||
mockPool.On("GetNodeByID", testMock.Anything).Return(nil)
|
||||
mockQueue := mq.NewMQ()
|
||||
|
||||
// 未指定RPC地址,跳过
|
||||
mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1))
|
||||
Init(false, mockPool, mockQueue)
|
||||
a.NoError(mock.ExpectationsWereMet())
|
||||
mockPool.AssertExpectations(t)
|
||||
}
|
||||
|
||||
func TestTestRPCConnection(t *testing.T) {
|
||||
a := assert.New(t)
|
||||
|
||||
// url not legal
|
||||
{
|
||||
cache.Set("setting_aria2_rpcurl", "", 0)
|
||||
Init(false)
|
||||
asserts.IsType(&DummyAria2{}, Instance)
|
||||
res, err := TestRPCConnection(string([]byte{0x7f}), "", 10)
|
||||
a.Error(err)
|
||||
a.Empty(res.Version)
|
||||
}
|
||||
|
||||
// 无法解析服务器地址
|
||||
// rpc failed
|
||||
{
|
||||
cache.Set("setting_aria2_rpcurl", string(byte(0x7f)), 0)
|
||||
Init(false)
|
||||
asserts.IsType(&DummyAria2{}, Instance)
|
||||
}
|
||||
|
||||
// 无法解析全局配置
|
||||
{
|
||||
Instance = &RPCService{}
|
||||
cache.Set("setting_aria2_options", "?", 0)
|
||||
cache.Set("setting_aria2_rpcurl", "ws://127.0.0.1:1234", 0)
|
||||
Init(false)
|
||||
asserts.IsType(&DummyAria2{}, Instance)
|
||||
}
|
||||
|
||||
// 连接失败
|
||||
{
|
||||
cache.Set("setting_aria2_options", "{}", 0)
|
||||
cache.Set("setting_aria2_rpcurl", "http://127.0.0.1:1234", 0)
|
||||
cache.Set("setting_aria2_call_timeout", "1", 0)
|
||||
cache.Set("setting_aria2_interval", "100", 0)
|
||||
mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"g_id"}).AddRow("1"))
|
||||
Init(false)
|
||||
asserts.NoError(mock.ExpectationsWereMet())
|
||||
asserts.IsType(&RPCService{}, Instance)
|
||||
res, err := TestRPCConnection("ws://0.0.0.0", "", 0)
|
||||
a.Error(err)
|
||||
a.Empty(res.Version)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetStatus(t *testing.T) {
|
||||
asserts := assert.New(t)
|
||||
asserts.Equal(4, GetStatus("complete"))
|
||||
asserts.Equal(1, GetStatus("active"))
|
||||
asserts.Equal(0, GetStatus("waiting"))
|
||||
asserts.Equal(2, GetStatus("paused"))
|
||||
asserts.Equal(3, GetStatus("error"))
|
||||
asserts.Equal(5, GetStatus("removed"))
|
||||
asserts.Equal(6, GetStatus("?"))
|
||||
func TestGetLoadBalancer(t *testing.T) {
|
||||
a := assert.New(t)
|
||||
a.NotPanics(func() {
|
||||
GetLoadBalancer()
|
||||
})
|
||||
}
|
||||
|
|
|
@ -1,114 +0,0 @@
|
|||
package aria2
|
||||
|
||||
import (
|
||||
"context"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
model "github.com/cloudreve/Cloudreve/v3/models"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/mq"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/util"
|
||||
)
|
||||
|
||||
// RPCService 通过RPC服务的Aria2任务管理器
|
||||
type RPCService struct {
|
||||
options *clientOptions
|
||||
Caller rpc.Client
|
||||
}
|
||||
|
||||
type clientOptions struct {
|
||||
Options map[string]interface{} // 创建下载时额外添加的设置
|
||||
}
|
||||
|
||||
// Init 初始化
|
||||
func (client *RPCService) Init(server, secret string, timeout int, options map[string]interface{}) error {
|
||||
// 客户端已存在,则关闭先前连接
|
||||
if client.Caller != nil {
|
||||
client.Caller.Close()
|
||||
}
|
||||
|
||||
client.options = &clientOptions{
|
||||
Options: options,
|
||||
}
|
||||
caller, err := rpc.New(context.Background(), server, secret, time.Duration(timeout)*time.Second,
|
||||
mq.GlobalMQ)
|
||||
client.Caller = caller
|
||||
return err
|
||||
}
|
||||
|
||||
// Status 查询下载状态
|
||||
func (client *RPCService) Status(task *model.Download) (rpc.StatusInfo, error) {
|
||||
res, err := client.Caller.TellStatus(task.GID)
|
||||
if err != nil {
|
||||
// 失败后重试
|
||||
util.Log().Debug("无法获取离线下载状态,%s,10秒钟后重试", err)
|
||||
time.Sleep(time.Duration(10) * time.Second)
|
||||
res, err = client.Caller.TellStatus(task.GID)
|
||||
}
|
||||
|
||||
return res, err
|
||||
}
|
||||
|
||||
// Cancel 取消下载
|
||||
func (client *RPCService) Cancel(task *model.Download) error {
|
||||
// 取消下载任务
|
||||
_, err := client.Caller.Remove(task.GID)
|
||||
if err != nil {
|
||||
util.Log().Warning("无法取消离线下载任务[%s], %s", task.GID, err)
|
||||
}
|
||||
|
||||
//// 删除临时文件
|
||||
//util.Log().Debug("离线下载任务[%s]已取消,1 分钟后删除临时文件", task.GID)
|
||||
//go func(task *model.Download) {
|
||||
// select {
|
||||
// case <-time.After(time.Duration(60) * time.Second):
|
||||
// err := os.RemoveAll(task.Parent)
|
||||
// if err != nil {
|
||||
// util.Log().Warning("无法删除离线下载临时目录[%s], %s", task.Parent, err)
|
||||
// }
|
||||
// }
|
||||
//}(task)
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// Select 选取要下载的文件
|
||||
func (client *RPCService) Select(task *model.Download, files []int) error {
|
||||
var selected = make([]string, len(files))
|
||||
for i := 0; i < len(files); i++ {
|
||||
selected[i] = strconv.Itoa(files[i])
|
||||
}
|
||||
_, err := client.Caller.ChangeOption(task.GID, map[string]interface{}{"select-file": strings.Join(selected, ",")})
|
||||
return err
|
||||
}
|
||||
|
||||
// CreateTask 创建新任务
|
||||
func (client *RPCService) CreateTask(task *model.Download, groupOptions map[string]interface{}) (string, error) {
|
||||
// 生成存储路径
|
||||
path := filepath.Join(
|
||||
model.GetSettingByName("aria2_temp_path"),
|
||||
"aria2",
|
||||
strconv.FormatInt(time.Now().UnixNano(), 10),
|
||||
)
|
||||
|
||||
// 创建下载任务
|
||||
options := map[string]interface{}{
|
||||
"dir": path,
|
||||
}
|
||||
for k, v := range client.options.Options {
|
||||
options[k] = v
|
||||
}
|
||||
for k, v := range groupOptions {
|
||||
options[k] = v
|
||||
}
|
||||
|
||||
gid, err := client.Caller.AddURI(task.Source, options)
|
||||
if err != nil || gid == "" {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return gid, nil
|
||||
}
|
|
@ -1,52 +0,0 @@
|
|||
package aria2
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
model "github.com/cloudreve/Cloudreve/v3/models"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/cache"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestRPCService_Init(t *testing.T) {
|
||||
asserts := assert.New(t)
|
||||
caller := &RPCService{}
|
||||
asserts.Error(caller.Init("ws://", "", 1, nil))
|
||||
asserts.NoError(caller.Init("http://127.0.0.1", "", 1, nil))
|
||||
}
|
||||
|
||||
func TestRPCService_Status(t *testing.T) {
|
||||
asserts := assert.New(t)
|
||||
caller := &RPCService{}
|
||||
asserts.NoError(caller.Init("http://127.0.0.1", "", 1, nil))
|
||||
|
||||
_, err := caller.Status(&model.Download{})
|
||||
asserts.Error(err)
|
||||
}
|
||||
|
||||
func TestRPCService_Cancel(t *testing.T) {
|
||||
asserts := assert.New(t)
|
||||
caller := &RPCService{}
|
||||
asserts.NoError(caller.Init("http://127.0.0.1", "", 1, nil))
|
||||
|
||||
err := caller.Cancel(&model.Download{Parent: "test"})
|
||||
asserts.Error(err)
|
||||
}
|
||||
|
||||
func TestRPCService_Select(t *testing.T) {
|
||||
asserts := assert.New(t)
|
||||
caller := &RPCService{}
|
||||
asserts.NoError(caller.Init("http://127.0.0.1", "", 1, nil))
|
||||
|
||||
err := caller.Select(&model.Download{Parent: "test"}, []int{1, 2, 3})
|
||||
asserts.Error(err)
|
||||
}
|
||||
|
||||
func TestRPCService_CreateTask(t *testing.T) {
|
||||
asserts := assert.New(t)
|
||||
caller := &RPCService{}
|
||||
asserts.NoError(caller.Init("http://127.0.0.1", "", 1, nil))
|
||||
cache.Set("setting_aria2_temp_path", "test", 0)
|
||||
err := caller.CreateTask(&model.Download{Parent: "test"}, map[string]interface{}{"1": "1"})
|
||||
asserts.Error(err)
|
||||
}
|
|
@ -1,52 +0,0 @@
|
|||
package aria2
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestNotifier_Notify(t *testing.T) {
|
||||
asserts := assert.New(t)
|
||||
notifier2 := &Notifier{}
|
||||
notifyChan := make(chan StatusEvent, 10)
|
||||
notifier2.Subscribe(notifyChan, "1")
|
||||
|
||||
// 未订阅
|
||||
{
|
||||
notifier2.Notify([]rpc.Event{rpc.Event{Gid: ""}}, 1)
|
||||
asserts.Len(notifyChan, 0)
|
||||
}
|
||||
|
||||
// 订阅
|
||||
{
|
||||
notifier2.Notify([]rpc.Event{{Gid: "1"}}, 1)
|
||||
asserts.Len(notifyChan, 1)
|
||||
<-notifyChan
|
||||
|
||||
notifier2.OnBtDownloadComplete([]rpc.Event{{Gid: "1"}})
|
||||
asserts.Len(notifyChan, 1)
|
||||
<-notifyChan
|
||||
|
||||
notifier2.OnDownloadStart([]rpc.Event{{Gid: "1"}})
|
||||
asserts.Len(notifyChan, 1)
|
||||
<-notifyChan
|
||||
|
||||
notifier2.OnDownloadPause([]rpc.Event{{Gid: "1"}})
|
||||
asserts.Len(notifyChan, 1)
|
||||
<-notifyChan
|
||||
|
||||
notifier2.OnDownloadStop([]rpc.Event{{Gid: "1"}})
|
||||
asserts.Len(notifyChan, 1)
|
||||
<-notifyChan
|
||||
|
||||
notifier2.OnDownloadComplete([]rpc.Event{{Gid: "1"}})
|
||||
asserts.Len(notifyChan, 1)
|
||||
<-notifyChan
|
||||
|
||||
notifier2.OnDownloadError([]rpc.Event{{Gid: "1"}})
|
||||
asserts.Len(notifyChan, 1)
|
||||
<-notifyChan
|
||||
}
|
||||
}
|
|
@ -17,8 +17,10 @@ import (
|
|||
)
|
||||
|
||||
var (
|
||||
ErrAuthFailed = serializer.NewError(serializer.CodeNoPermissionErr, "鉴权失败", nil)
|
||||
ErrExpired = serializer.NewError(serializer.CodeSignExpired, "签名已过期", nil)
|
||||
ErrAuthFailed = serializer.NewError(serializer.CodeNoPermissionErr, "鉴权失败", nil)
|
||||
ErrAuthHeaderMissing = serializer.NewError(serializer.CodeNoPermissionErr, "authorization header is missing", nil)
|
||||
ErrExpiresMissing = serializer.NewError(serializer.CodeNoPermissionErr, "expire timestamp is missing", nil)
|
||||
ErrExpired = serializer.NewError(serializer.CodeSignExpired, "签名已过期", nil)
|
||||
)
|
||||
|
||||
// General 通用的认证接口
|
||||
|
@ -55,7 +57,7 @@ func CheckRequest(instance Auth, r *http.Request) error {
|
|||
ok bool
|
||||
)
|
||||
if sign, ok = r.Header["Authorization"]; !ok || len(sign) == 0 {
|
||||
return ErrAuthFailed
|
||||
return ErrAuthHeaderMissing
|
||||
}
|
||||
sign[0] = strings.TrimPrefix(sign[0], "Bearer ")
|
||||
|
||||
|
|
|
@ -80,6 +80,19 @@ func TestCheckRequest(t *testing.T) {
|
|||
asserts := assert.New(t)
|
||||
General = HMACAuth{SecretKey: []byte(util.RandStringRunes(256))}
|
||||
|
||||
// 缺少请求头
|
||||
{
|
||||
req, err := http.NewRequest(
|
||||
"POST",
|
||||
"http://127.0.0.1/api/v3/upload",
|
||||
strings.NewReader("I am body."),
|
||||
)
|
||||
asserts.NoError(err)
|
||||
err = CheckRequest(General, req)
|
||||
asserts.Error(err)
|
||||
asserts.Equal(ErrAuthHeaderMissing, err)
|
||||
}
|
||||
|
||||
// 非上传请求 验证成功
|
||||
{
|
||||
req, err := http.NewRequest(
|
||||
|
|
|
@ -33,7 +33,7 @@ func (auth HMACAuth) Check(body string, sign string) error {
|
|||
signSlice := strings.Split(sign, ":")
|
||||
// 如果未携带expires字段
|
||||
if signSlice[len(signSlice)-1] == "" {
|
||||
return ErrAuthFailed
|
||||
return ErrExpiresMissing
|
||||
}
|
||||
|
||||
// 验证是否过期
|
||||
|
|
12
pkg/balancer/balancer_test.go
Normal file
12
pkg/balancer/balancer_test.go
Normal file
|
@ -0,0 +1,12 @@
|
|||
package balancer
|
||||
|
||||
import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestNewBalancer(t *testing.T) {
|
||||
a := assert.New(t)
|
||||
a.NotNil(NewBalancer(""))
|
||||
a.IsType(&RoundRobin{}, NewBalancer("RoundRobin"))
|
||||
}
|
42
pkg/balancer/roundrobin_test.go
Normal file
42
pkg/balancer/roundrobin_test.go
Normal file
|
@ -0,0 +1,42 @@
|
|||
package balancer
|
||||
|
||||
import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestRoundRobin_NextIndex(t *testing.T) {
|
||||
a := assert.New(t)
|
||||
r := &RoundRobin{}
|
||||
total := 5
|
||||
for i := 1; i < total; i++ {
|
||||
a.Equal(i, r.NextIndex(total))
|
||||
}
|
||||
for i := 0; i < total; i++ {
|
||||
a.Equal(i, r.NextIndex(total))
|
||||
}
|
||||
}
|
||||
|
||||
func TestRoundRobin_NextPeer(t *testing.T) {
|
||||
a := assert.New(t)
|
||||
r := &RoundRobin{}
|
||||
|
||||
// not slice
|
||||
{
|
||||
err, _ := r.NextPeer("s")
|
||||
a.Equal(ErrInputNotSlice, err)
|
||||
}
|
||||
|
||||
// no nodes
|
||||
{
|
||||
err, _ := r.NextPeer([]string{})
|
||||
a.Equal(ErrNoAvaliableNode, err)
|
||||
}
|
||||
|
||||
// pass
|
||||
{
|
||||
err, res := r.NextPeer([]string{"a"})
|
||||
a.NoError(err)
|
||||
a.Equal("a", res.(string))
|
||||
}
|
||||
}
|
254
pkg/cluster/controller_test.go
Normal file
254
pkg/cluster/controller_test.go
Normal file
|
@ -0,0 +1,254 @@
|
|||
package cluster
|
||||
|
||||
import (
|
||||
model "github.com/cloudreve/Cloudreve/v3/models"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/common"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/auth"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/mq"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/request"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
|
||||
"github.com/stretchr/testify/assert"
|
||||
testMock "github.com/stretchr/testify/mock"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestInitController(t *testing.T) {
|
||||
assert.NotPanics(t, func() {
|
||||
InitController()
|
||||
})
|
||||
}
|
||||
|
||||
func TestSlaveController_HandleHeartBeat(t *testing.T) {
|
||||
a := assert.New(t)
|
||||
c := &slaveController{
|
||||
masters: make(map[string]MasterInfo),
|
||||
}
|
||||
|
||||
// first heart beat
|
||||
{
|
||||
_, err := c.HandleHeartBeat(&serializer.NodePingReq{
|
||||
SiteID: "1",
|
||||
Node: &model.Node{},
|
||||
})
|
||||
a.NoError(err)
|
||||
|
||||
_, err = c.HandleHeartBeat(&serializer.NodePingReq{
|
||||
SiteID: "2",
|
||||
Node: &model.Node{},
|
||||
})
|
||||
a.NoError(err)
|
||||
|
||||
a.Len(c.masters, 2)
|
||||
}
|
||||
|
||||
// second heart beat, no fresh
|
||||
{
|
||||
_, err := c.HandleHeartBeat(&serializer.NodePingReq{
|
||||
SiteID: "1",
|
||||
SiteURL: "http://127.0.0.1",
|
||||
Node: &model.Node{},
|
||||
})
|
||||
a.NoError(err)
|
||||
a.Len(c.masters, 2)
|
||||
a.Empty(c.masters["1"].URL)
|
||||
}
|
||||
|
||||
// second heart beat, fresh
|
||||
{
|
||||
_, err := c.HandleHeartBeat(&serializer.NodePingReq{
|
||||
SiteID: "1",
|
||||
IsUpdate: true,
|
||||
SiteURL: "http://127.0.0.1",
|
||||
Node: &model.Node{},
|
||||
})
|
||||
a.NoError(err)
|
||||
a.Len(c.masters, 2)
|
||||
a.Equal("http://127.0.0.1", c.masters["1"].URL.String())
|
||||
}
|
||||
|
||||
// second heart beat, fresh, url illegal
|
||||
{
|
||||
_, err := c.HandleHeartBeat(&serializer.NodePingReq{
|
||||
SiteID: "1",
|
||||
IsUpdate: true,
|
||||
SiteURL: string([]byte{0x7f}),
|
||||
Node: &model.Node{},
|
||||
})
|
||||
a.Error(err)
|
||||
a.Len(c.masters, 2)
|
||||
a.Equal("http://127.0.0.1", c.masters["1"].URL.String())
|
||||
}
|
||||
}
|
||||
|
||||
type nodeMock struct {
|
||||
testMock.Mock
|
||||
}
|
||||
|
||||
func (n nodeMock) Init(node *model.Node) {
|
||||
n.Called(node)
|
||||
}
|
||||
|
||||
func (n nodeMock) IsFeatureEnabled(feature string) bool {
|
||||
args := n.Called(feature)
|
||||
return args.Bool(0)
|
||||
}
|
||||
|
||||
func (n nodeMock) SubscribeStatusChange(callback func(isActive bool, id uint)) {
|
||||
n.Called(callback)
|
||||
}
|
||||
|
||||
func (n nodeMock) Ping(req *serializer.NodePingReq) (*serializer.NodePingResp, error) {
|
||||
args := n.Called(req)
|
||||
return args.Get(0).(*serializer.NodePingResp), args.Error(1)
|
||||
}
|
||||
|
||||
func (n nodeMock) IsActive() bool {
|
||||
args := n.Called()
|
||||
return args.Bool(0)
|
||||
}
|
||||
|
||||
func (n nodeMock) GetAria2Instance() common.Aria2 {
|
||||
args := n.Called()
|
||||
return args.Get(0).(common.Aria2)
|
||||
}
|
||||
|
||||
func (n nodeMock) ID() uint {
|
||||
args := n.Called()
|
||||
return args.Get(0).(uint)
|
||||
}
|
||||
|
||||
func (n nodeMock) Kill() {
|
||||
n.Called()
|
||||
}
|
||||
|
||||
func (n nodeMock) IsMater() bool {
|
||||
args := n.Called()
|
||||
return args.Bool(0)
|
||||
}
|
||||
|
||||
func (n nodeMock) MasterAuthInstance() auth.Auth {
|
||||
args := n.Called()
|
||||
return args.Get(0).(auth.Auth)
|
||||
}
|
||||
|
||||
func (n nodeMock) SlaveAuthInstance() auth.Auth {
|
||||
args := n.Called()
|
||||
return args.Get(0).(auth.Auth)
|
||||
}
|
||||
|
||||
func (n nodeMock) DBModel() *model.Node {
|
||||
args := n.Called()
|
||||
return args.Get(0).(*model.Node)
|
||||
}
|
||||
|
||||
func TestSlaveController_GetAria2Instance(t *testing.T) {
|
||||
a := assert.New(t)
|
||||
mockNode := &nodeMock{}
|
||||
mockNode.On("GetAria2Instance").Return(&common.DummyAria2{})
|
||||
c := &slaveController{
|
||||
masters: map[string]MasterInfo{
|
||||
"1": {Instance: mockNode},
|
||||
},
|
||||
}
|
||||
|
||||
// node node found
|
||||
{
|
||||
res, err := c.GetAria2Instance("2")
|
||||
a.Nil(res)
|
||||
a.Equal(ErrMasterNotFound, err)
|
||||
}
|
||||
|
||||
// node found
|
||||
{
|
||||
res, err := c.GetAria2Instance("1")
|
||||
a.NotNil(res)
|
||||
a.NoError(err)
|
||||
mockNode.AssertExpectations(t)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
type requestMock struct {
|
||||
testMock.Mock
|
||||
}
|
||||
|
||||
func (r requestMock) Request(method, target string, body io.Reader, opts ...request.Option) *request.Response {
|
||||
return r.Called(method, target, body, opts).Get(0).(*request.Response)
|
||||
}
|
||||
|
||||
func TestSlaveController_SendNotification(t *testing.T) {
|
||||
a := assert.New(t)
|
||||
c := &slaveController{
|
||||
masters: map[string]MasterInfo{
|
||||
"1": {},
|
||||
},
|
||||
}
|
||||
|
||||
// node not exit
|
||||
{
|
||||
a.Equal(ErrMasterNotFound, c.SendNotification("2", "", mq.Message{}))
|
||||
}
|
||||
|
||||
// gob encode error
|
||||
{
|
||||
type randomType struct{}
|
||||
a.Error(c.SendNotification("1", "", mq.Message{
|
||||
Content: randomType{},
|
||||
}))
|
||||
}
|
||||
|
||||
// return none 200
|
||||
{
|
||||
mockRequest := &requestMock{}
|
||||
mockRequest.On("Request", "PUT", "/api/v3/slave/notification/s1", testMock.Anything, testMock.Anything).Return(&request.Response{
|
||||
Response: &http.Response{StatusCode: http.StatusConflict},
|
||||
})
|
||||
c := &slaveController{
|
||||
masters: map[string]MasterInfo{
|
||||
"1": {Client: mockRequest},
|
||||
},
|
||||
}
|
||||
a.Error(c.SendNotification("1", "s1", mq.Message{}))
|
||||
mockRequest.AssertExpectations(t)
|
||||
}
|
||||
|
||||
// master return error
|
||||
{
|
||||
mockRequest := &requestMock{}
|
||||
mockRequest.On("Request", "PUT", "/api/v3/slave/notification/s2", testMock.Anything, testMock.Anything).Return(&request.Response{
|
||||
Response: &http.Response{
|
||||
StatusCode: 200,
|
||||
Body: ioutil.NopCloser(strings.NewReader("{\"code\":1}")),
|
||||
},
|
||||
})
|
||||
c := &slaveController{
|
||||
masters: map[string]MasterInfo{
|
||||
"1": {Client: mockRequest},
|
||||
},
|
||||
}
|
||||
a.Equal(1, c.SendNotification("1", "s2", mq.Message{}).(serializer.AppError).Code)
|
||||
mockRequest.AssertExpectations(t)
|
||||
}
|
||||
|
||||
// success
|
||||
{
|
||||
mockRequest := &requestMock{}
|
||||
mockRequest.On("Request", "PUT", "/api/v3/slave/notification/s3", testMock.Anything, testMock.Anything).Return(&request.Response{
|
||||
Response: &http.Response{
|
||||
StatusCode: 200,
|
||||
Body: ioutil.NopCloser(strings.NewReader("{\"code\":0}")),
|
||||
},
|
||||
})
|
||||
c := &slaveController{
|
||||
masters: map[string]MasterInfo{
|
||||
"1": {Client: mockRequest},
|
||||
},
|
||||
}
|
||||
a.NoError(c.SendNotification("1", "s3", mq.Message{}))
|
||||
mockRequest.AssertExpectations(t)
|
||||
}
|
||||
}
|
|
@ -8,9 +8,11 @@ import (
|
|||
"github.com/cloudreve/Cloudreve/v3/pkg/balancer"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/cluster"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/mq"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/request"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/task"
|
||||
testMock "github.com/stretchr/testify/mock"
|
||||
"io"
|
||||
)
|
||||
|
||||
type SlaveControllerMock struct {
|
||||
|
@ -184,3 +186,11 @@ func (t TaskPoolMock) Add(num int) {
|
|||
func (t TaskPoolMock) Submit(job task.Job) {
|
||||
t.Called(job)
|
||||
}
|
||||
|
||||
type RequestMock struct {
|
||||
testMock.Mock
|
||||
}
|
||||
|
||||
func (r RequestMock) Request(method, target string, body io.Reader, opts ...request.Option) *request.Response {
|
||||
return r.Called(method, target, body, opts).Get(0).(*request.Response)
|
||||
}
|
||||
|
|
|
@ -1,6 +1,8 @@
|
|||
package controllers
|
||||
|
||||
import (
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/cluster"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/mq"
|
||||
"io"
|
||||
|
||||
model "github.com/cloudreve/Cloudreve/v3/models"
|
||||
|
@ -72,7 +74,7 @@ func AdminReloadService(c *gin.Context) {
|
|||
case "email":
|
||||
email.Init()
|
||||
case "aria2":
|
||||
aria2.Init(true)
|
||||
aria2.Init(true, cluster.Default, mq.GlobalMQ)
|
||||
}
|
||||
|
||||
c.JSON(200, serializer.Response{})
|
||||
|
|
|
@ -48,9 +48,7 @@ func (service *AddURLService) Add(c *gin.Context, taskType int) serializer.Respo
|
|||
}
|
||||
|
||||
// 获取 Aria2 负载均衡器
|
||||
aria2.Lock.RLock()
|
||||
lb := aria2.LB
|
||||
aria2.Lock.RUnlock()
|
||||
lb := aria2.GetLoadBalancer()
|
||||
|
||||
// 获取 Aria2 实例
|
||||
err, node := cluster.Default.BalanceNodeByFeature("aria2", lb)
|
||||
|
|
Loading…
Add table
Reference in a new issue