Feat: generic message queue implementation

This commit is contained in:
HFO4 2021-08-31 21:46:23 +08:00
parent 57d12cb2de
commit eae3688137
2 changed files with 309 additions and 0 deletions

160
pkg/mq/mq.go Normal file
View file

@ -0,0 +1,160 @@
package mq
import (
"encoding/gob"
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/common"
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc"
"strconv"
"sync"
"time"
)
// Message 消息事件正文
type Message struct {
// 消息触发者
TriggeredBy string
// 事件标识
Event string
// 消息正文
Content interface{}
}
type CallbackFunc func(Message)
// MQ 消息队列
type MQ interface {
rpc.Notifier
// 发布一个消息
Publish(string, Message)
// 订阅一个消息主题
Subscribe(string, int) <-chan Message
// 订阅一个消息主题,注册触发回调函数
SubscribeCallback(string, CallbackFunc)
// 取消订阅一个消息主题
Unsubscribe(string, <-chan Message)
}
var GlobalMQ = NewMQ()
func NewMQ() MQ {
return &inMemoryMQ{
topics: make(map[string][]chan Message),
callbacks: make(map[string][]CallbackFunc),
}
}
func init() {
gob.Register(Message{})
gob.Register(rpc.Event{})
}
type inMemoryMQ struct {
topics map[string][]chan Message
callbacks map[string][]CallbackFunc
sync.RWMutex
}
func (i *inMemoryMQ) Publish(topic string, message Message) {
i.RLock()
subscribersChan, okChan := i.topics[topic]
subscribersCallback, okCallback := i.callbacks[topic]
i.RUnlock()
if okChan {
go func(subscribersChan []chan Message) {
for i := 0; i < len(subscribersChan); i++ {
select {
case subscribersChan[i] <- message:
case <-time.After(time.Millisecond * 500):
}
}
}(subscribersChan)
}
if okCallback {
for i := 0; i < len(subscribersCallback); i++ {
go subscribersCallback[i](message)
}
}
}
func (i *inMemoryMQ) Subscribe(topic string, buffer int) <-chan Message {
ch := make(chan Message, buffer)
i.Lock()
i.topics[topic] = append(i.topics[topic], ch)
i.Unlock()
return ch
}
func (i *inMemoryMQ) SubscribeCallback(topic string, callbackFunc CallbackFunc) {
i.Lock()
i.callbacks[topic] = append(i.callbacks[topic], callbackFunc)
i.Unlock()
}
func (i *inMemoryMQ) Unsubscribe(topic string, sub <-chan Message) {
i.Lock()
defer i.Unlock()
subscribers, ok := i.topics[topic]
if !ok {
return
}
var newSubs []chan Message
for _, subscriber := range subscribers {
if subscriber == sub {
continue
}
newSubs = append(newSubs, subscriber)
}
i.topics[topic] = newSubs
}
func (i *inMemoryMQ) Aria2Notify(events []rpc.Event, status int) {
for _, event := range events {
i.Publish(event.Gid, Message{
TriggeredBy: event.Gid,
Event: strconv.FormatInt(int64(status), 10),
Content: events,
})
}
}
// OnDownloadStart 下载开始
func (i *inMemoryMQ) OnDownloadStart(events []rpc.Event) {
i.Aria2Notify(events, common.Downloading)
}
// OnDownloadPause 下载暂停
func (i *inMemoryMQ) OnDownloadPause(events []rpc.Event) {
i.Aria2Notify(events, common.Paused)
}
// OnDownloadStop 下载停止
func (i *inMemoryMQ) OnDownloadStop(events []rpc.Event) {
i.Aria2Notify(events, common.Canceled)
}
// OnDownloadComplete 下载完成
func (i *inMemoryMQ) OnDownloadComplete(events []rpc.Event) {
i.Aria2Notify(events, common.Complete)
}
// OnDownloadError 下载出错
func (i *inMemoryMQ) OnDownloadError(events []rpc.Event) {
i.Aria2Notify(events, common.Error)
}
// OnBtDownloadComplete BT下载完成
func (i *inMemoryMQ) OnBtDownloadComplete(events []rpc.Event) {
i.Aria2Notify(events, common.Complete)
}

149
pkg/mq/mq_test.go Normal file
View file

@ -0,0 +1,149 @@
package mq
import (
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc"
"github.com/stretchr/testify/assert"
"sync"
"testing"
"time"
)
func TestPublishAndSubscribe(t *testing.T) {
t.Parallel()
asserts := assert.New(t)
mq := NewMQ()
// No subscriber
{
asserts.NotPanics(func() {
mq.Publish("No subscriber", Message{})
})
}
// One channel subscriber
{
topic := "One channel subscriber"
msg := Message{TriggeredBy: "Tester"}
notifier := mq.Subscribe(topic, 0)
mq.Publish(topic, msg)
wg := sync.WaitGroup{}
wg.Add(1)
go func() {
wg.Done()
msgRecv := <-notifier
asserts.Equal(msg, msgRecv)
}()
wg.Wait()
}
// two channel subscriber
{
topic := "two channel subscriber"
msg := Message{TriggeredBy: "Tester"}
notifier := mq.Subscribe(topic, 0)
notifier2 := mq.Subscribe(topic, 0)
mq.Publish(topic, msg)
wg := sync.WaitGroup{}
wg.Add(2)
go func() {
wg.Done()
msgRecv := <-notifier
asserts.Equal(msg, msgRecv)
}()
go func() {
wg.Done()
msgRecv := <-notifier2
asserts.Equal(msg, msgRecv)
}()
wg.Wait()
}
// two channel subscriber, one timeout
{
topic := "two channel subscriber, one timeout"
msg := Message{TriggeredBy: "Tester"}
mq.Subscribe(topic, 0)
notifier2 := mq.Subscribe(topic, 0)
mq.Publish(topic, msg)
wg := sync.WaitGroup{}
wg.Add(1)
go func() {
wg.Done()
msgRecv := <-notifier2
asserts.Equal(msg, msgRecv)
}()
wg.Wait()
}
// two channel subscriber, one unsubscribe
{
topic := "two channel subscriber, one unsubscribe"
msg := Message{TriggeredBy: "Tester"}
mq.Subscribe(topic, 0)
notifier2 := mq.Subscribe(topic, 0)
notifier := mq.Subscribe(topic, 0)
mq.Unsubscribe(topic, notifier)
mq.Publish(topic, msg)
wg := sync.WaitGroup{}
wg.Add(1)
go func() {
wg.Done()
msgRecv := <-notifier2
asserts.Equal(msg, msgRecv)
}()
wg.Wait()
select {
case <-notifier:
t.Error()
default:
}
}
}
func TestAria2Interface(t *testing.T) {
t.Parallel()
asserts := assert.New(t)
mq := NewMQ()
var (
OnDownloadStart int
OnDownloadPause int
OnDownloadStop int
OnDownloadComplete int
OnDownloadError int
)
l := sync.Mutex{}
mq.SubscribeCallback("TestAria2Interface", func(message Message) {
asserts.Equal("TestAria2Interface", message.TriggeredBy)
l.Lock()
defer l.Unlock()
switch message.Event {
case "1":
OnDownloadStart++
case "2":
OnDownloadPause++
case "5":
OnDownloadStop++
case "4":
OnDownloadComplete++
case "3":
OnDownloadError++
}
})
mq.OnDownloadStart([]rpc.Event{{"TestAria2Interface"}, {"TestAria2Interface"}})
mq.OnDownloadPause([]rpc.Event{{"TestAria2Interface"}, {"TestAria2Interface"}})
mq.OnDownloadStop([]rpc.Event{{"TestAria2Interface"}, {"TestAria2Interface"}})
mq.OnDownloadComplete([]rpc.Event{{"TestAria2Interface"}, {"TestAria2Interface"}})
mq.OnDownloadError([]rpc.Event{{"TestAria2Interface"}, {"TestAria2Interface"}})
mq.OnBtDownloadComplete([]rpc.Event{{"TestAria2Interface"}, {"TestAria2Interface"}})
time.Sleep(time.Duration(500) * time.Millisecond)
asserts.Equal(2, OnDownloadStart)
asserts.Equal(2, OnDownloadPause)
asserts.Equal(2, OnDownloadStop)
asserts.Equal(4, OnDownloadComplete)
asserts.Equal(2, OnDownloadError)
}