1
Fork 0
mirror of https://github.com/caddyserver/caddy.git synced 2024-12-16 21:56:40 -05:00
caddy/caddyhttp/websocket/websocket.go
linquize 0ba427a6f4 websocket: Enhancements, message types, and tests (#2359)
* websocket: Should reset respawn parameter when processing next config entry

* websocket: add message types: lines, text, binary

* websocket: Add unit test

* Add websocket sample files
2019-07-19 13:29:49 -06:00

482 lines
13 KiB
Go

// Copyright 2015 Light Code Labs, LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package websocket implements a WebSocket server by executing
// a command and piping its input and output through the WebSocket
// connection.
package websocket
import (
"bufio"
"bytes"
"io"
"log"
"net"
"net/http"
"os"
"os/exec"
"strings"
"time"
"unicode/utf8"
"github.com/caddyserver/caddy/caddyhttp/httpserver"
"github.com/gorilla/websocket"
)
// Timing and size limits applied to every websocket connection.
const (
	// writeWait bounds how long a single write to the peer may take.
	writeWait = 10 * time.Second

	// pongWait bounds how long we wait for the peer's next pong before the
	// read deadline expires.
	pongWait = 60 * time.Second

	// pingPeriod is how often pings are sent. It must stay below pongWait so
	// a responsive peer can answer before the read deadline passes.
	pingPeriod = pongWait * 9 / 10

	// maxMessageSize caps a single incoming message at 10 MiB.
	maxMessageSize = 10 << 20
)
// GatewayInterface is the dialect of CGI being used by the server to
// communicate with the script. It is passed to the child process as
// GATEWAY_INTERFACE. See CGI spec, 4.1.4.
var GatewayInterface string

// ServerSoftware is the name and version of the information server software
// making the CGI request. It is passed to the child process as
// SERVER_SOFTWARE. See CGI spec, 4.1.17.
var ServerSoftware string
// WebSocket is a type that holds configuration for the websocket middleware
// generally, like a list of all the websocket endpoints.
type WebSocket struct {
	// Next is the next HTTP handler in the chain, used when the request path
	// matches none of the configured websocket endpoints.
	Next httpserver.Handler
	// Sockets holds all the websocket endpoint configurations, tried in order.
	Sockets []Config
}

// Config holds the configuration for a single websocket endpoint, which may
// serve multiple websocket connections.
type Config struct {
	// Path is the request path prefix this endpoint matches.
	Path string
	// Command is the program spawned for each connection.
	Command string
	// Arguments are passed to Command.
	Arguments []string
	Respawn   bool // TODO: Not used, but parser supports it until we decide on it
	// Type selects how messages are framed: "lines", "text" or "binary".
	Type string
	// BufSize, when positive, bounds how much process output is buffered per
	// outgoing message.
	BufSize int
}

// wsGetUpgrader is implemented by a ResponseWriter that supplies its own
// upgrader; serveWS prefers it over the default gorilla upgrader.
type wsGetUpgrader interface {
	GetUpgrader() wsUpgrader
}

// wsUpgrader upgrades an HTTP request to a websocket connection.
type wsUpgrader interface {
	Upgrade(w http.ResponseWriter, r *http.Request, responseHeader http.Header) (wsConn, error)
}

// wsConn is the subset of gorilla *websocket.Conn behavior this package uses.
type wsConn interface {
	Close() error
	ReadMessage() (messageType int, p []byte, err error)
	SetPongHandler(h func(appData string) error)
	SetReadDeadline(t time.Time) error
	SetReadLimit(limit int64)
	SetWriteDeadline(t time.Time) error
	WriteControl(messageType int, data []byte, deadline time.Time) error
	WriteMessage(messageType int, data []byte) error
}
// ServeHTTP serves the request as a websocket connection when its path matches
// one of the configured endpoints (the first match wins). Requests matching no
// endpoint are handed to the next handler unchanged.
func (ws WebSocket) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) {
	reqPath := httpserver.Path(r.URL.Path)
	for i := range ws.Sockets {
		if cfg := ws.Sockets[i]; reqPath.Matches(cfg.Path) {
			return serveWS(w, r, &cfg)
		}
	}
	// No websocket endpoint claimed this path: pass it through.
	return ws.Next.ServeHTTP(w, r)
}
// serveWS upgrades the HTTP connection to a websocket connection and spawns the
// child process configured for the matched path. Websocket messages are piped
// to the process's stdin, and its stdout is piped back to the websocket.
//
// Once Upgrade succeeds the connection is hijacked, so no HTTP error response
// can be written to it any more. Every failure from that point on is therefore
// returned as (0, err): the error is still reported, but the caller does not
// try to write a status page to the hijacked connection.
func serveWS(w http.ResponseWriter, r *http.Request, config *Config) (int, error) {
	// Prefer an upgrader supplied by the ResponseWriter; otherwise use
	// gorilla's, accepting any origin.
	var u wsUpgrader
	if gu, ok := w.(wsGetUpgrader); ok && gu != nil {
		u = gu.GetUpgrader()
	} else {
		u = &realWsUpgrader{o: &websocket.Upgrader{
			ReadBufferSize:  1024,
			WriteBufferSize: 1024,
			CheckOrigin:     func(r *http.Request) bool { return true },
		}}
	}

	conn, err := u.Upgrade(w, r, nil)
	if err != nil {
		// the connection has been "handled" -- WriteHeader was called with Upgrade,
		// so don't return an error status code; just return an error
		return 0, err
	}
	defer conn.Close()

	cmd := exec.Command(config.Command, config.Arguments...)
	stdout, err := cmd.StdoutPipe()
	if err != nil {
		return 0, err
	}
	defer stdout.Close()
	stdin, err := cmd.StdinPipe()
	if err != nil {
		return 0, err
	}
	defer stdin.Close()

	metavars, err := buildEnv(cmd.Path, r)
	if err != nil {
		return 0, err
	}
	cmd.Env = metavars
	if err := cmd.Start(); err != nil {
		return 0, err
	}

	done := make(chan struct{})
	go pumpStdout(conn, stdout, done, config)
	pumpStdin(conn, stdin, config)

	_ = stdin.Close() // close stdin to end the process

	// Ask the process to exit. os.Interrupt cannot be delivered on every
	// platform (Windows does not implement it), so fall back to killing the
	// process instead of returning without ever reaping it.
	if err := cmd.Process.Signal(os.Interrupt); err != nil {
		if err := cmd.Process.Kill(); err != nil {
			return 0, err
		}
	}
	select {
	case <-done:
	case <-time.After(time.Second):
		// terminate with extreme prejudice.
		if err := cmd.Process.Kill(); err != nil {
			return 0, err
		}
		<-done
	}

	// The exit status of a process that "exited" is often non-zero but is not
	// really an error per se, so it is only logged.
	if err := cmd.Wait(); err != nil {
		log.Println("[ERROR] failed to release resources: ", err)
	}
	return 0, nil
}
// buildEnv creates the meta-variables for the child process according
// to the CGI 1.1 specification: http://tools.ietf.org/html/rfc3875#section-4.1
// cmdPath should be the path of the command being run.
// The returned string slice can be set to the command's Env property.
//
// The request is not modified. An address without a port (for example a
// RemoteAddr of "host", or a Host of "example.com" or "[::1]") produces an
// empty port variable. An error is returned only when r.RemoteAddr or r.Host
// cannot be parsed at all (e.g. "a:b:c").
//
// Every request header is also exported as HTTP_<NAME>. NAME is upper-cased
// with '-' replaced by '_'. Multiple values are joined with ", ", and newlines
// in values are replaced by spaces.
func buildEnv(cmdPath string, r *http.Request) (metavars []string, err error) {
	remoteHost, remotePort, err := splitHostOptionalPort(r.RemoteAddr)
	if err != nil {
		return nil, err
	}
	serverHost, serverPort, err := splitHostOptionalPort(r.Host)
	if err != nil {
		return nil, err
	}

	metavars = []string{
		`AUTH_TYPE=`,      // Not used
		`CONTENT_LENGTH=`, // Not used
		`CONTENT_TYPE=`,   // Not used
		`GATEWAY_INTERFACE=` + GatewayInterface,
		`PATH_INFO=`,       // TODO
		`PATH_TRANSLATED=`, // TODO
		`QUERY_STRING=` + r.URL.RawQuery,
		`REMOTE_ADDR=` + remoteHost,
		`REMOTE_HOST=` + remoteHost, // Host lookups are slow - don't do them
		`REMOTE_IDENT=`,             // Not used
		`REMOTE_PORT=` + remotePort,
		`REMOTE_USER=`, // Not used,
		`REQUEST_METHOD=` + r.Method,
		`REQUEST_URI=` + r.RequestURI,
		`SCRIPT_NAME=` + cmdPath, // path of the program being executed
		`SERVER_NAME=` + serverHost,
		`SERVER_PORT=` + serverPort,
		`SERVER_PROTOCOL=` + r.Proto,
		`SERVER_SOFTWARE=` + ServerSoftware,
	}

	// Add each HTTP header to the environment as well
	for header, values := range r.Header {
		value := strings.Join(values, ", ")
		header = strings.ToUpper(header)
		header = strings.Replace(header, "-", "_", -1)
		value = strings.Replace(value, "\n", " ", -1)
		metavars = append(metavars, "HTTP_"+header+"="+value)
	}
	return metavars, nil
}

// splitHostOptionalPort splits addr like net.SplitHostPort, but also accepts
// an address that carries no port, returning an empty port for it.
func splitHostOptionalPort(addr string) (host, port string, err error) {
	host, port, err = net.SplitHostPort(addr)
	if err == nil {
		return host, port, nil
	}
	// Retry as though an empty port had been given. This accepts "host" and
	// "[::1]" while still rejecting malformed input such as "a:b:c".
	if h, p, retryErr := net.SplitHostPort(addr + ":"); retryErr == nil {
		return h, p, nil
	}
	return "", "", err
}
// pumpStdin copies every message received on the websocket connection into
// the child process's stdin. In "lines" mode a '\n' is appended to each
// message, because the client does not send one.
//
// Before reading it applies maxMessageSize as the read limit and arms a read
// deadline of pongWait. Each received pong pushes that deadline out again.
// Copying stops at the first read or write error, and the connection is
// closed on return.
func pumpStdin(conn wsConn, stdin io.WriteCloser, config *Config) {
	defer conn.Close()

	extendReadDeadline := func() {
		if err := conn.SetReadDeadline(time.Now().Add(pongWait)); err != nil {
			log.Println("[ERROR] failed to set read deadline: ", err)
		}
	}

	conn.SetReadLimit(maxMessageSize)
	extendReadDeadline()
	conn.SetPongHandler(func(string) error {
		extendReadDeadline()
		return nil
	})

	appendNewline := config.Type == "lines"
	for {
		_, msg, readErr := conn.ReadMessage()
		if readErr != nil {
			return
		}
		if appendNewline {
			msg = append(msg, '\n')
		}
		if _, writeErr := stdin.Write(msg); writeErr != nil {
			return
		}
	}
}
// pumpStdout reads the child process's stdout and forwards it to the websocket
// connection until stdout is exhausted or a write fails. Framing depends on
// config.Type:
//
//   - "lines": one text message per newline-terminated line, with surrounding
//     whitespace trimmed. A positive config.BufSize caps the line length.
//   - "text": UTF-8 text messages of at most config.BufSize bytes (2048 when
//     unset, never less than utf8.UTFMax). A rune split across reads is held
//     back and sent at the start of the next message.
//   - "binary": binary messages of at most config.BufSize bytes (2048 when
//     unset).
//
// Any other Type forwards nothing.
//
// It starts pinger for the connection. On return it closes the connection and
// closes done, which stops the pinger and signals that stdout has been
// drained. If reading stdout failed, a CloseGoingAway frame carrying the error
// text is sent first. In "text" and "binary" mode this includes io.EOF; in
// "lines" mode bufio.Scanner reports a clean EOF as nil.
func pumpStdout(conn wsConn, stdout io.Reader, done chan struct{}, config *Config) {
	go pinger(conn, done)
	defer func() {
		_ = conn.Close()
		close(done) // make sure to close the pinger when we are done.
	}()

	var readErr error
	switch config.Type {
	case "lines":
		readErr = pumpStdoutLines(conn, stdout, config.BufSize)
	case "text":
		readErr = pumpStdoutText(conn, stdout, config.BufSize)
	case "binary":
		readErr = pumpStdoutBinary(conn, stdout, config.BufSize)
	}

	if readErr != nil {
		msg := websocket.FormatCloseMessage(websocket.CloseGoingAway, readErr.Error())
		if err := conn.WriteControl(websocket.CloseMessage, msg, time.Time{}); err != nil {
			log.Println("[ERROR] WriteControl failed: ", err)
		}
	}
}

// writeStdoutMessage sends one data message, bounded by writeWait.
func writeStdoutMessage(conn wsConn, messageType int, data []byte) error {
	if err := conn.SetWriteDeadline(time.Now().Add(writeWait)); err != nil {
		log.Println("[ERROR] failed to set write deadline: ", err)
	}
	return conn.WriteMessage(messageType, data)
}

// pumpStdoutLines forwards stdout one trimmed line per text message and
// returns the scanner's error (nil at a clean EOF).
func pumpStdoutLines(conn wsConn, stdout io.Reader, bufSize int) error {
	s := bufio.NewScanner(stdout)
	if bufSize > 0 {
		s.Buffer(make([]byte, bufSize), bufSize)
	}
	for s.Scan() {
		// message must end with '\n'
		if err := writeStdoutMessage(conn, websocket.TextMessage, bytes.TrimSpace(s.Bytes())); err != nil {
			break
		}
	}
	return s.Err()
}

// pumpStdoutText forwards stdout as UTF-8 text messages without splitting a
// multi-byte rune across messages. It returns the read error that ended the
// stream, or nil if a websocket write failed first.
func pumpStdoutText(conn wsConn, stdout io.Reader, bufSize int) error {
	if bufSize <= 0 {
		bufSize = 2048
	}
	// Up to utf8.UTFMax-1 bytes of a partial rune are carried into the next
	// buffer. A smaller buffer could be filled by that carry alone, leaving no
	// room to read and looping forever.
	if bufSize < utf8.UTFMax {
		bufSize = utf8.UTFMax
	}
	r := bufio.NewReader(stdout)
	var carry []byte // incomplete trailing rune held back from the last read
	for {
		// A fresh buffer per message: wsConn implementations are not assumed
		// to copy the data they are given.
		out := make([]byte, bufSize)
		filled := copy(out, carry)
		n, err := r.Read(out[filled:])
		filled += n
		pending := findIncompleteRuneLength(out, filled)
		carry = out[filled-pending : filled]

		// Forward complete text even when it arrives together with a read
		// error. Skip the empty message left when a read delivered only
		// part of a rune.
		if complete := out[:filled-pending]; len(complete) > 0 {
			if werr := writeStdoutMessage(conn, websocket.TextMessage, complete); werr != nil {
				return nil
			}
		}
		if err != nil {
			return err
		}
	}
}

// pumpStdoutBinary forwards stdout as binary messages. It returns the read
// error that ended the stream, or nil if a websocket write failed first.
func pumpStdoutBinary(conn wsConn, stdout io.Reader, bufSize int) error {
	if bufSize <= 0 {
		bufSize = 2048
	}
	r := bufio.NewReader(stdout)
	for {
		out := make([]byte, bufSize) // fresh buffer per message, see pumpStdoutText
		n, err := r.Read(out)
		// Forward data even when it arrives together with a read error.
		if n > 0 {
			if werr := writeStdoutMessage(conn, websocket.BinaryMessage, out[:n]); werr != nil {
				return nil
			}
		}
		if err != nil {
			return err
		}
	}
}
// pinger keeps the websocket alive by sending a ping control frame every
// pingPeriod until done is closed. If a ping cannot be written, it tries to
// send a CloseGoingAway frame carrying the error text and then stops.
func pinger(conn wsConn, done chan struct{}) {
	ticker := time.NewTicker(pingPeriod)
	defer ticker.Stop()

	for {
		select {
		case <-done:
			return // pumpStdout has finished; clean up this routine.
		case <-ticker.C:
		}

		pingErr := conn.WriteControl(websocket.PingMessage, []byte{}, time.Now().Add(writeWait))
		if pingErr == nil {
			continue
		}

		closeMsg := websocket.FormatCloseMessage(websocket.CloseGoingAway, pingErr.Error())
		if err := conn.WriteControl(websocket.CloseMessage, closeMsg, time.Time{}); err != nil {
			log.Println("[ERROR] WriteControl failed: ", err)
		}
		return
	}
}
type (
	// realWsUpgrader adapts a gorilla *websocket.Upgrader to wsUpgrader.
	realWsUpgrader struct {
		o *websocket.Upgrader
	}

	// realWsConn adapts a gorilla *websocket.Conn to wsConn by delegating
	// every method to it.
	realWsConn struct {
		o *websocket.Conn
	}
)
// Upgrade upgrades the HTTP server connection to the WebSocket protocol using
// the wrapped gorilla Upgrader. On failure it returns a nil wsConn together
// with the error, rather than a wrapper around a nil *websocket.Conn that
// would panic if used.
func (u *realWsUpgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeader http.Header) (wsConn, error) {
	c, err := u.o.Upgrade(w, r, responseHeader)
	if err != nil {
		return nil, err
	}
	return &realWsConn{o: c}, nil
}
// Close closes the underlying websocket connection.
func (rc *realWsConn) Close() error { return rc.o.Close() }

// ReadMessage reads the next data message from the underlying connection.
func (rc *realWsConn) ReadMessage() (messageType int, p []byte, err error) { return rc.o.ReadMessage() }

// SetPongHandler installs h as the handler for pong control frames.
func (rc *realWsConn) SetPongHandler(h func(appData string) error) { rc.o.SetPongHandler(h) }

// SetReadDeadline sets the read deadline on the underlying connection.
func (rc *realWsConn) SetReadDeadline(t time.Time) error { return rc.o.SetReadDeadline(t) }

// SetReadLimit sets the maximum size, in bytes, of an incoming message.
func (rc *realWsConn) SetReadLimit(limit int64) { rc.o.SetReadLimit(limit) }

// SetWriteDeadline sets the write deadline on the underlying connection.
func (rc *realWsConn) SetWriteDeadline(t time.Time) error { return rc.o.SetWriteDeadline(t) }

// WriteControl writes a control frame with the given deadline.
func (rc *realWsConn) WriteControl(messageType int, data []byte, deadline time.Time) error {
	return rc.o.WriteControl(messageType, data, deadline)
}

// WriteMessage writes a data message of the given type.
func (rc *realWsConn) WriteMessage(messageType int, data []byte) error {
	return rc.o.WriteMessage(messageType, data)
}
// findIncompleteRuneLength reports how many bytes at the end of p[:length]
// are the start of a UTF-8 encoded rune whose remaining bytes have not arrived
// yet. It returns 0 when p[:length] ends on a rune boundary.
//
// Only the last utf8.UTFMax bytes are inspected, and scanning stops at the
// nearest rune-start byte (ASCII or a multi-byte lead byte). As a result,
// ASCII that follows a stray lead byte is never held back. An invalid sequence
// counts as complete as soon as it is known to be invalid.
func findIncompleteRuneLength(p []byte, length int) int {
	tail := p[:length]
	for i := len(tail) - 1; i >= 0 && i >= len(tail)-utf8.UTFMax; i-- {
		if !utf8.RuneStart(tail[i]) {
			continue // continuation byte; keep looking for its lead byte
		}
		if utf8.FullRune(tail[i:]) {
			return 0
		}
		return len(tail) - i
	}
	// No rune-start byte within reach: nothing can be completed later.
	return 0
}