// Cloudreve/pkg/aria2/rpc/call.go
//
// 274 lines
// 6.2 KiB
// Go

package rpc
import (
"context"
"errors"
"log"
"net"
"net/http"
"net/url"
"sync"
"sync/atomic"
"time"
"github.com/gorilla/websocket"
)
// caller abstracts the transport used to talk to the aria2 JSON-RPC endpoint.
type caller interface {
	// Call sends an RPC request for method with params to the aria2 daemon
	// and decodes the result into reply.
	Call(method string, params, reply interface{}) error
	// Close releases whatever the transport holds open.
	Close() error
}
// httpCaller issues JSON-RPC calls as HTTP POST requests. When constructed
// with a Notifier it also runs goroutines reading aria2 notifications from a
// websocket; cancel/wg/once manage the lifetime of those goroutines.
type httpCaller struct {
	uri string       // endpoint every Call posts to
	c   *http.Client // client shared by all calls

	cancel context.CancelFunc // stops the notifier goroutines
	wg     *sync.WaitGroup    // tracks the notifier goroutines
	once   sync.Once          // makes Close run its shutdown only once
}
// newHTTPCaller builds an httpCaller that POSTs JSON-RPC requests to u.
//
// The transport allows at most one connection to the aria2 host; timeout
// bounds both dialing and the wait for response headers. When notifer is
// non-nil, a websocket connection is opened to receive aria2 notifications.
// Failure to open it is logged and the caller is still returned, since plain
// RPC calls do not depend on it. Cancelling ctx (or calling Close) stops the
// notification goroutines.
func newHTTPCaller(ctx context.Context, u *url.URL, timeout time.Duration, notifer Notifier) *httpCaller {
	c := &http.Client{
		Transport: &http.Transport{
			MaxIdleConnsPerHost: 1,
			MaxConnsPerHost:     1,
			// TLSClientConfig: tlsConfig,
			// DialContext replaces the deprecated Dial hook and lets the
			// transport abort in-flight dials when a request is cancelled.
			DialContext: (&net.Dialer{
				Timeout:   timeout,
				KeepAlive: 60 * time.Second,
			}).DialContext,
			TLSHandshakeTimeout:   3 * time.Second,
			ResponseHeaderTimeout: timeout,
		},
	}
	var wg sync.WaitGroup
	ctx, cancel := context.WithCancel(ctx)
	h := &httpCaller{uri: u.String(), c: c, cancel: cancel, wg: &wg}
	if notifer != nil {
		// Notifications are best-effort: keep the HTTP caller usable even if
		// the websocket side channel cannot be established, but do not hide
		// the failure.
		if err := h.setNotifier(ctx, *u, notifer); err != nil {
			log.Printf("aria2 rpc: connecting notification websocket: %v", err)
		}
	}
	return h
}
// Close stops the notification goroutines (if any were started) and blocks
// until they have returned. Only the first call does any work; the result is
// always nil.
func (h *httpCaller) Close() error {
	shutdown := func() {
		h.cancel()  // signal the notifier goroutines to stop
		h.wg.Wait() // and wait for them to finish
	}
	h.once.Do(shutdown)
	return nil
}
// setNotifier opens a websocket connection to the endpoint described by u and
// forwards every aria2 notification read from it to notifer.
//
// The URL scheme is switched to its websocket counterpart: "https"/"wss"
// become "wss", anything else becomes "ws". Two goroutines tracked by h.wg are
// started. The first waits for ctx to be cancelled, sends a close frame and
// closes the connection. The second reads notifications until ctx is
// cancelled or a read fails. The returned error is the dial error, if any.
func (h *httpCaller) setNotifier(ctx context.Context, u url.URL, notifer Notifier) (err error) {
	// An endpoint served over TLS must be reached over TLS for the websocket
	// too; dialing plain "ws" against an https server would fail.
	switch u.Scheme {
	case "https", "wss":
		u.Scheme = "wss"
	default:
		u.Scheme = "ws"
	}
	conn, _, err := websocket.DefaultDialer.Dial(u.String(), nil)
	if err != nil {
		return err
	}

	h.wg.Add(1)
	go func() { // routine:close
		defer h.wg.Done()
		defer conn.Close()
		<-ctx.Done()
		// Best effort: bound the close-frame write so shutdown cannot hang.
		_ = conn.SetWriteDeadline(time.Now().Add(time.Second))
		if err := conn.WriteMessage(websocket.CloseMessage,
			websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")); err != nil {
			log.Printf("sending websocket close message: %v", err)
		}
	}()

	h.wg.Add(1)
	go func() { // routine:recv
		defer h.wg.Done()
		for {
			select {
			case <-ctx.Done():
				return
			default:
			}
			// Decode into a fresh value each time so fields missing from this
			// message cannot carry over from the previous one.
			var request websocketResponse
			if err := conn.ReadJSON(&request); err != nil {
				select {
				case <-ctx.Done():
					// The close routine shut the connection down; not an error.
					return
				default:
				}
				log.Printf("conn.ReadJSON|err:%v", err.Error())
				return
			}
			switch request.Method {
			case "aria2.onDownloadStart":
				notifer.OnDownloadStart(request.Params)
			case "aria2.onDownloadPause":
				notifer.OnDownloadPause(request.Params)
			case "aria2.onDownloadStop":
				notifer.OnDownloadStop(request.Params)
			case "aria2.onDownloadComplete":
				notifer.OnDownloadComplete(request.Params)
			case "aria2.onDownloadError":
				notifer.OnDownloadError(request.Params)
			case "aria2.onBtDownloadComplete":
				notifer.OnBtDownloadComplete(request.Params)
			default:
				log.Printf("unexpected notification: %s", request.Method)
			}
		}
	}()
	return nil
}
// Call POSTs a JSON-RPC request for method with params to the aria2 endpoint
// and decodes the response into reply. Errors from encoding the request, the
// HTTP round trip and decoding the response are returned unchanged.
//
// The receiver is a pointer because httpCaller contains a sync.Once, which
// must not be copied (go vet copylocks).
func (h *httpCaller) Call(method string, params, reply interface{}) (err error) {
	payload, err := EncodeClientRequest(method, params)
	if err != nil {
		return err
	}
	r, err := h.c.Post(h.uri, "application/json", payload)
	if err != nil {
		return err
	}
	defer r.Body.Close()
	return DecodeClientResponse(r.Body, &reply)
}
// websocketCaller multiplexes JSON-RPC calls and aria2 notifications over a
// single websocket connection. Calls are queued on sendChan for a dedicated
// send goroutine; responses are routed back to the waiting call by request id.
type websocketCaller struct {
	conn     *websocket.Conn   // the shared websocket connection
	sendChan chan *sendRequest // queue drained by the send goroutine

	cancel context.CancelFunc // stops the send/recv goroutines
	wg     *sync.WaitGroup    // tracks the send/recv goroutines
	once   sync.Once          // makes Close run its shutdown only once

	timeout time.Duration // how long Call waits for a response
}
// newWebsocketCaller dials uri and returns a caller that multiplexes JSON-RPC
// calls and aria2 notifications over that single websocket connection.
//
// Two goroutines, tracked by the caller's wait group, are started:
//   - recv reads messages. Those without an id are notifications and are
//     forwarded to notifier (when non-nil). The rest go to the response
//     processor, which runs the handler registered for that id.
//   - send writes queued requests, registering a response handler for each
//     first. Once ctx is cancelled it sends a close frame and closes the
//     connection.
//
// Each goroutine cancels the shared context on exit, so the other follows.
// timeout is used as the per-request write deadline and as Call's wait limit.
func newWebsocketCaller(ctx context.Context, uri string, timeout time.Duration, notifier Notifier) (*websocketCaller, error) {
	var header = http.Header{}
	conn, _, err := websocket.DefaultDialer.Dial(uri, header)
	if err != nil {
		return nil, err
	}

	sendChan := make(chan *sendRequest, 16)
	var wg sync.WaitGroup
	ctx, cancel := context.WithCancel(ctx)
	w := &websocketCaller{conn: conn, wg: &wg, cancel: cancel, sendChan: sendChan, timeout: timeout}
	processor := NewResponseProcessor()

	wg.Add(1)
	go func() { // routine:recv
		defer wg.Done()
		defer cancel()
		for {
			select {
			case <-ctx.Done():
				return
			default:
			}
			var resp websocketResponse
			if err := conn.ReadJSON(&resp); err != nil {
				select {
				case <-ctx.Done():
					return
				default:
				}
				log.Printf("conn.ReadJSON|err:%v", err.Error())
				return
			}
			if resp.Id == nil { // RPC notifications
				if notifier != nil {
					switch resp.Method {
					case "aria2.onDownloadStart":
						notifier.OnDownloadStart(resp.Params)
					case "aria2.onDownloadPause":
						notifier.OnDownloadPause(resp.Params)
					case "aria2.onDownloadStop":
						notifier.OnDownloadStop(resp.Params)
					case "aria2.onDownloadComplete":
						notifier.OnDownloadComplete(resp.Params)
					case "aria2.onDownloadError":
						notifier.OnDownloadError(resp.Params)
					case "aria2.onBtDownloadComplete":
						notifier.OnBtDownloadComplete(resp.Params)
					default:
						log.Printf("unexpected notification: %s", resp.Method)
					}
				}
				continue
			}
			processor.Process(resp.clientResponse)
		}
	}()

	wg.Add(1)
	go func() { // routine:send
		defer wg.Done()
		defer cancel()
		defer w.conn.Close()
		for {
			select {
			case <-ctx.Done():
				// The deadline set for the last request may already have
				// passed; without a fresh one the close frame would fail with
				// a timeout error.
				_ = w.conn.SetWriteDeadline(time.Now().Add(time.Second))
				if err := w.conn.WriteMessage(websocket.CloseMessage,
					websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")); err != nil {
					log.Printf("sending websocket close message: %v", err)
				}
				return
			case req := <-sendChan:
				// Register the handler before writing, so it is in place by
				// the time the response can arrive.
				processor.Add(req.request.Id, func(resp clientResponse) error {
					err := resp.decode(req.reply)
					// Publish the outcome before waking Call via cancel.
					req.err = err
					req.cancel()
					return err
				})
				_ = w.conn.SetWriteDeadline(time.Now().Add(timeout))
				if err := w.conn.WriteJSON(req.request); err != nil {
					log.Printf("sending websocket request: %v", err)
				}
			}
		}
	}()
	return w, nil
}

// Close cancels the send/recv goroutines, which closes the websocket
// connection, and waits for both to exit. Only the first call does any work;
// the result is always nil.
func (w *websocketCaller) Close() (err error) {
	w.once.Do(func() {
		w.cancel()
		w.wg.Wait()
	})
	return
}

// Call queues a JSON-RPC request for method/params and waits up to w.timeout
// for the matching response, which is decoded into reply.
//
// It returns:
//   - an error if the send queue is full;
//   - context.DeadlineExceeded if no response arrives in time;
//   - otherwise, the error from decoding the response into reply. This
//     presumably includes JSON-RPC error responses; it depends on
//     clientResponse.decode, so confirm.
//
// The receiver is a pointer because websocketCaller contains a sync.Once,
// which must not be copied.
func (w *websocketCaller) Call(method string, params, reply interface{}) (err error) {
	ctx, cancel := context.WithTimeout(context.Background(), w.timeout)
	defer cancel()
	req := &sendRequest{cancel: cancel, request: &clientRequest{
		Version: "2.0",
		Method:  method,
		Params:  params,
		Id:      reqid(),
	}, reply: reply}
	select {
	case w.sendChan <- req:
	default:
		return errors.New("sending channel blocking")
	}
	<-ctx.Done()
	if err := ctx.Err(); err == context.DeadlineExceeded {
		return err
	}
	// The response handler cancelled ctx. It stored the decode outcome in
	// req.err before doing so, so the value is visible here.
	return req.err
}

// sendRequest is one queued websocket call: the request to write, where to
// decode its result, and how to report completion back to Call.
type sendRequest struct {
	// cancel wakes the waiting Call once the response has been handled.
	cancel  context.CancelFunc
	request *clientRequest
	reply   interface{}
	// err is the outcome of decoding the response into reply. The response
	// handler writes it before calling cancel. Call reads it only after
	// observing that cancellation.
	err error
}
// reqid hands out JSON-RPC request ids. The counter starts at the current
// UnixNano timestamp and is advanced atomically, so each call returns the
// previous value plus one and concurrent callers never share an id.
var reqid = func() func() uint64 {
	counter := uint64(time.Now().UnixNano())
	next := func() uint64 {
		return atomic.AddUint64(&counter, 1)
	}
	return next
}()