diff --git a/middleware/websocket/websocket.go b/middleware/websocket/websocket.go index 76b2bfed..e6db66a2 100644 --- a/middleware/websocket/websocket.go +++ b/middleware/websocket/websocket.go @@ -4,9 +4,12 @@ package websocket import ( + "bufio" + "bytes" "io" "net" "net/http" + "os" "os/exec" "strings" "time" @@ -88,15 +91,18 @@ func serveWS(w http.ResponseWriter, r *http.Request, config *Config) (int, error defer conn.Close() cmd := exec.Command(config.Command, config.Arguments...) + stdout, err := cmd.StdoutPipe() if err != nil { return http.StatusBadGateway, err } + defer stdout.Close() stdin, err := cmd.StdinPipe() if err != nil { return http.StatusBadGateway, err } + defer stdin.Close() metavars, err := buildEnv(cmd.Path, r) if err != nil { @@ -109,7 +115,31 @@ func serveWS(w http.ResponseWriter, r *http.Request, config *Config) (int, error return http.StatusBadGateway, err } - reader(conn, stdout, stdin) + done := make(chan struct{}) + go pumpStdout(conn, stdout, done) + pumpStdin(conn, stdin) + + stdin.Close() // close stdin to end the process + + if err := cmd.Process.Signal(os.Interrupt); err != nil { // signal an interrupt to kill the process + return http.StatusInternalServerError, err + } + + select { + case <-done: + case <-time.After(time.Second): + // terminate with extreme prejudice. + if err := cmd.Process.Signal(os.Kill); err != nil { + return http.StatusInternalServerError, err + } + <-done + } + + // not sure what we want to do here. + // status for an "exited" process is greater + // than 0, but isn't really an error per se. + // just going to ignore it for now. + cmd.Wait() return 0, nil } @@ -163,63 +193,60 @@ func buildEnv(cmdPath string, r *http.Request) (metavars []string, err error) { return } -// reader is the guts of this package. It takes the stdin and stdout pipes -// of the cmd we created in ServeWS and pipes them between the client and server -// over websockets. -func reader(conn *websocket.Conn, stdout io.ReadCloser, stdin io.WriteCloser) { +// pumpStdin handles reading data from the websocket connection and writing +// it to stdin of the process. +func pumpStdin(conn *websocket.Conn, stdin io.WriteCloser) { // Setup our connection's websocket ping/pong handlers from our const values. + defer conn.Close() conn.SetReadLimit(maxMessageSize) conn.SetReadDeadline(time.Now().Add(pongWait)) conn.SetPongHandler(func(string) error { conn.SetReadDeadline(time.Now().Add(pongWait)); return nil }) - tickerChan := make(chan bool) - defer close(tickerChan) // make sure to close the ticker when we are done. - go ticker(conn, tickerChan) - for { - msgType, r, err := conn.NextReader() + _, message, err := conn.ReadMessage() if err != nil { - if msgType == -1 { - return // we got a disconnect from the client. We are good to close. - } - conn.WriteControl(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseGoingAway, ""), time.Time{}) - return + break } - - w, err := conn.NextWriter(msgType) - if err != nil { - conn.WriteControl(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseGoingAway, ""), time.Time{}) - return + message = append(message, '\n') + if _, err := stdin.Write(message); err != nil { + break } - - if _, err := io.Copy(stdin, r); err != nil { - conn.WriteControl(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseGoingAway, ""), time.Time{}) - return - } - - go func() { - if _, err := io.Copy(w, stdout); err != nil { - conn.WriteControl(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseGoingAway, ""), time.Time{}) - return - } - if err := w.Close(); err != nil { - conn.WriteControl(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseGoingAway, ""), time.Time{}) - return - } - }() } } -// ticker is start by the reader. Basically it is the method that simulates the websocket -// between the server and client to keep it alive with ping messages. -func ticker(conn *websocket.Conn, c chan bool) { +// pumpStdout handles reading data from stdout of the process and writing +// it to websocket connection. +func pumpStdout(conn *websocket.Conn, stdout io.Reader, done chan struct{}) { + go pinger(conn, done) + defer func() { + conn.Close() + close(done) // make sure to close the pinger when we are done. + }() + + s := bufio.NewScanner(stdout) + for s.Scan() { + conn.SetWriteDeadline(time.Now().Add(writeWait)) + if err := conn.WriteMessage(websocket.TextMessage, bytes.TrimSpace(s.Bytes())); err != nil { + break + } + } + if s.Err() != nil { + conn.WriteControl(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseGoingAway, s.Err().Error()), time.Time{}) + } +} + +// pinger simulates the websocket to keep it alive with ping messages. +func pinger(conn *websocket.Conn, done chan struct{}) { ticker := time.NewTicker(pingPeriod) defer ticker.Stop() for { // blocking loop with select to wait for stimulation. select { case <-ticker.C: - conn.WriteMessage(websocket.PingMessage, nil) - case <-c: + if err := conn.WriteControl(websocket.PingMessage, []byte{}, time.Now().Add(writeWait)); err != nil { + conn.WriteControl(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseGoingAway, err.Error()), time.Time{}) + return + } + case <-done: return // clean up this routine. } }