Feat: generic message queue implementation
This commit is contained in:
parent
57d12cb2de
commit
eae3688137
2 changed files with 309 additions and 0 deletions
160
pkg/mq/mq.go
Normal file
160
pkg/mq/mq.go
Normal file
|
@ -0,0 +1,160 @@
|
|||
package mq
|
||||
|
||||
import (
|
||||
"encoding/gob"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/common"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Message 消息事件正文
|
||||
type Message struct {
|
||||
// 消息触发者
|
||||
TriggeredBy string
|
||||
|
||||
// 事件标识
|
||||
Event string
|
||||
|
||||
// 消息正文
|
||||
Content interface{}
|
||||
}
|
||||
|
||||
type CallbackFunc func(Message)
|
||||
|
||||
// MQ 消息队列
|
||||
type MQ interface {
|
||||
rpc.Notifier
|
||||
|
||||
// 发布一个消息
|
||||
Publish(string, Message)
|
||||
|
||||
// 订阅一个消息主题
|
||||
Subscribe(string, int) <-chan Message
|
||||
|
||||
// 订阅一个消息主题,注册触发回调函数
|
||||
SubscribeCallback(string, CallbackFunc)
|
||||
|
||||
// 取消订阅一个消息主题
|
||||
Unsubscribe(string, <-chan Message)
|
||||
}
|
||||
|
||||
var GlobalMQ = NewMQ()
|
||||
|
||||
func NewMQ() MQ {
|
||||
return &inMemoryMQ{
|
||||
topics: make(map[string][]chan Message),
|
||||
callbacks: make(map[string][]CallbackFunc),
|
||||
}
|
||||
}
|
||||
|
||||
func init() {
|
||||
gob.Register(Message{})
|
||||
gob.Register(rpc.Event{})
|
||||
}
|
||||
|
||||
type inMemoryMQ struct {
|
||||
topics map[string][]chan Message
|
||||
callbacks map[string][]CallbackFunc
|
||||
sync.RWMutex
|
||||
}
|
||||
|
||||
func (i *inMemoryMQ) Publish(topic string, message Message) {
|
||||
i.RLock()
|
||||
subscribersChan, okChan := i.topics[topic]
|
||||
subscribersCallback, okCallback := i.callbacks[topic]
|
||||
i.RUnlock()
|
||||
|
||||
if okChan {
|
||||
go func(subscribersChan []chan Message) {
|
||||
for i := 0; i < len(subscribersChan); i++ {
|
||||
select {
|
||||
case subscribersChan[i] <- message:
|
||||
case <-time.After(time.Millisecond * 500):
|
||||
}
|
||||
}
|
||||
}(subscribersChan)
|
||||
|
||||
}
|
||||
|
||||
if okCallback {
|
||||
for i := 0; i < len(subscribersCallback); i++ {
|
||||
go subscribersCallback[i](message)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (i *inMemoryMQ) Subscribe(topic string, buffer int) <-chan Message {
|
||||
ch := make(chan Message, buffer)
|
||||
i.Lock()
|
||||
i.topics[topic] = append(i.topics[topic], ch)
|
||||
i.Unlock()
|
||||
return ch
|
||||
}
|
||||
|
||||
func (i *inMemoryMQ) SubscribeCallback(topic string, callbackFunc CallbackFunc) {
|
||||
i.Lock()
|
||||
i.callbacks[topic] = append(i.callbacks[topic], callbackFunc)
|
||||
i.Unlock()
|
||||
}
|
||||
|
||||
func (i *inMemoryMQ) Unsubscribe(topic string, sub <-chan Message) {
|
||||
i.Lock()
|
||||
defer i.Unlock()
|
||||
|
||||
subscribers, ok := i.topics[topic]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
var newSubs []chan Message
|
||||
for _, subscriber := range subscribers {
|
||||
if subscriber == sub {
|
||||
continue
|
||||
}
|
||||
newSubs = append(newSubs, subscriber)
|
||||
}
|
||||
|
||||
i.topics[topic] = newSubs
|
||||
}
|
||||
|
||||
func (i *inMemoryMQ) Aria2Notify(events []rpc.Event, status int) {
|
||||
for _, event := range events {
|
||||
i.Publish(event.Gid, Message{
|
||||
TriggeredBy: event.Gid,
|
||||
Event: strconv.FormatInt(int64(status), 10),
|
||||
Content: events,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// OnDownloadStart 下载开始
|
||||
func (i *inMemoryMQ) OnDownloadStart(events []rpc.Event) {
|
||||
i.Aria2Notify(events, common.Downloading)
|
||||
}
|
||||
|
||||
// OnDownloadPause 下载暂停
|
||||
func (i *inMemoryMQ) OnDownloadPause(events []rpc.Event) {
|
||||
i.Aria2Notify(events, common.Paused)
|
||||
}
|
||||
|
||||
// OnDownloadStop 下载停止
|
||||
func (i *inMemoryMQ) OnDownloadStop(events []rpc.Event) {
|
||||
i.Aria2Notify(events, common.Canceled)
|
||||
}
|
||||
|
||||
// OnDownloadComplete 下载完成
|
||||
func (i *inMemoryMQ) OnDownloadComplete(events []rpc.Event) {
|
||||
i.Aria2Notify(events, common.Complete)
|
||||
}
|
||||
|
||||
// OnDownloadError 下载出错
|
||||
func (i *inMemoryMQ) OnDownloadError(events []rpc.Event) {
|
||||
i.Aria2Notify(events, common.Error)
|
||||
}
|
||||
|
||||
// OnBtDownloadComplete BT下载完成
|
||||
func (i *inMemoryMQ) OnBtDownloadComplete(events []rpc.Event) {
|
||||
i.Aria2Notify(events, common.Complete)
|
||||
}
|
149
pkg/mq/mq_test.go
Normal file
149
pkg/mq/mq_test.go
Normal file
|
@ -0,0 +1,149 @@
|
|||
package mq
|
||||
|
||||
import (
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestPublishAndSubscribe(t *testing.T) {
|
||||
t.Parallel()
|
||||
asserts := assert.New(t)
|
||||
mq := NewMQ()
|
||||
|
||||
// No subscriber
|
||||
{
|
||||
asserts.NotPanics(func() {
|
||||
mq.Publish("No subscriber", Message{})
|
||||
})
|
||||
}
|
||||
|
||||
// One channel subscriber
|
||||
{
|
||||
topic := "One channel subscriber"
|
||||
msg := Message{TriggeredBy: "Tester"}
|
||||
notifier := mq.Subscribe(topic, 0)
|
||||
mq.Publish(topic, msg)
|
||||
wg := sync.WaitGroup{}
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
wg.Done()
|
||||
msgRecv := <-notifier
|
||||
asserts.Equal(msg, msgRecv)
|
||||
}()
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
// two channel subscriber
|
||||
{
|
||||
topic := "two channel subscriber"
|
||||
msg := Message{TriggeredBy: "Tester"}
|
||||
notifier := mq.Subscribe(topic, 0)
|
||||
notifier2 := mq.Subscribe(topic, 0)
|
||||
mq.Publish(topic, msg)
|
||||
wg := sync.WaitGroup{}
|
||||
wg.Add(2)
|
||||
go func() {
|
||||
wg.Done()
|
||||
msgRecv := <-notifier
|
||||
asserts.Equal(msg, msgRecv)
|
||||
}()
|
||||
go func() {
|
||||
wg.Done()
|
||||
msgRecv := <-notifier2
|
||||
asserts.Equal(msg, msgRecv)
|
||||
}()
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
// two channel subscriber, one timeout
|
||||
{
|
||||
topic := "two channel subscriber, one timeout"
|
||||
msg := Message{TriggeredBy: "Tester"}
|
||||
mq.Subscribe(topic, 0)
|
||||
notifier2 := mq.Subscribe(topic, 0)
|
||||
mq.Publish(topic, msg)
|
||||
wg := sync.WaitGroup{}
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
wg.Done()
|
||||
msgRecv := <-notifier2
|
||||
asserts.Equal(msg, msgRecv)
|
||||
}()
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
// two channel subscriber, one unsubscribe
|
||||
{
|
||||
topic := "two channel subscriber, one unsubscribe"
|
||||
msg := Message{TriggeredBy: "Tester"}
|
||||
mq.Subscribe(topic, 0)
|
||||
notifier2 := mq.Subscribe(topic, 0)
|
||||
notifier := mq.Subscribe(topic, 0)
|
||||
mq.Unsubscribe(topic, notifier)
|
||||
mq.Publish(topic, msg)
|
||||
wg := sync.WaitGroup{}
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
wg.Done()
|
||||
msgRecv := <-notifier2
|
||||
asserts.Equal(msg, msgRecv)
|
||||
}()
|
||||
wg.Wait()
|
||||
|
||||
select {
|
||||
case <-notifier:
|
||||
t.Error()
|
||||
default:
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestAria2Interface(t *testing.T) {
|
||||
t.Parallel()
|
||||
asserts := assert.New(t)
|
||||
mq := NewMQ()
|
||||
var (
|
||||
OnDownloadStart int
|
||||
OnDownloadPause int
|
||||
OnDownloadStop int
|
||||
OnDownloadComplete int
|
||||
OnDownloadError int
|
||||
)
|
||||
l := sync.Mutex{}
|
||||
|
||||
mq.SubscribeCallback("TestAria2Interface", func(message Message) {
|
||||
asserts.Equal("TestAria2Interface", message.TriggeredBy)
|
||||
l.Lock()
|
||||
defer l.Unlock()
|
||||
switch message.Event {
|
||||
case "1":
|
||||
OnDownloadStart++
|
||||
case "2":
|
||||
OnDownloadPause++
|
||||
case "5":
|
||||
OnDownloadStop++
|
||||
case "4":
|
||||
OnDownloadComplete++
|
||||
case "3":
|
||||
OnDownloadError++
|
||||
}
|
||||
})
|
||||
|
||||
mq.OnDownloadStart([]rpc.Event{{"TestAria2Interface"}, {"TestAria2Interface"}})
|
||||
mq.OnDownloadPause([]rpc.Event{{"TestAria2Interface"}, {"TestAria2Interface"}})
|
||||
mq.OnDownloadStop([]rpc.Event{{"TestAria2Interface"}, {"TestAria2Interface"}})
|
||||
mq.OnDownloadComplete([]rpc.Event{{"TestAria2Interface"}, {"TestAria2Interface"}})
|
||||
mq.OnDownloadError([]rpc.Event{{"TestAria2Interface"}, {"TestAria2Interface"}})
|
||||
mq.OnBtDownloadComplete([]rpc.Event{{"TestAria2Interface"}, {"TestAria2Interface"}})
|
||||
|
||||
time.Sleep(time.Duration(500) * time.Millisecond)
|
||||
|
||||
asserts.Equal(2, OnDownloadStart)
|
||||
asserts.Equal(2, OnDownloadPause)
|
||||
asserts.Equal(2, OnDownloadStop)
|
||||
asserts.Equal(4, OnDownloadComplete)
|
||||
asserts.Equal(2, OnDownloadError)
|
||||
}
|
Loading…
Add table
Reference in a new issue