0
Fork 0
mirror of https://github.com/dani-garcia/vaultwarden.git synced 2025-01-21 01:12:28 -05:00

Initial version of websockets notification support.

For now only folder notifications are sent (create, rename, delete).
The notifications are only tested between two web-vault sessions in different browsers, mobile apps and browser extensions are untested.

The websocket server is exposed in port 3012, while the rocket server is exposed in another port (8000 by default). To make notifications work, both should be accessible in the same port, which requires a reverse proxy.

My testing is done with Caddy server, and the following config:

```
localhost {

    # The negotiation endpoint is also proxied to Rocket
    proxy /notifications/hub/negotiate 0.0.0.0:8000 {
        transparent
    }

    # Notifications redirected to the websockets server
    proxy /notifications/hub 0.0.0.0:3012 {
        websocket
    }

    # Proxy the Root directory to Rocket
    proxy / 0.0.0.0:8000 {
        transparent
    }
}
```

This exposes the service in port 2015.
This commit is contained in:
Daniel García 2018-08-30 17:43:46 +02:00
parent f94e626021
commit d70864ac73
7 changed files with 765 additions and 168 deletions

537
Cargo.lock generated

File diff suppressed because it is too large Load diff

View file

@ -15,9 +15,18 @@ reqwest = "0.8.8"
# multipart/form-data support
multipart = "0.15.2"
# WebSockets library
ws = "0.7.8"
# MessagePack library
rmpv = "0.4.0"
# Concurrent hashmap implementation
chashmap = "2.2.0"
# A generic serialization/deserialization framework
serde = "1.0.74"
serde_derive = "1.0.74"
serde = "1.0.75"
serde_derive = "1.0.75"
serde_json = "1.0.26"
# A safe, extensible ORM and Query builder
@ -34,7 +43,7 @@ ring = { version = "= 0.11.0", features = ["rsa_signing"] }
uuid = { version = "0.6.5", features = ["v4"] }
# Date and time library for Rust
chrono = "0.4.5"
chrono = "0.4.6"
# TOTP library
oath = "0.10.2"
@ -58,9 +67,13 @@ lazy_static = "1.1.0"
num-traits = "0.2.5"
num-derive = "0.2.2"
# Number encoding library
byteorder = "1.2.6"
[patch.crates-io]
# Make jwt use ring 0.11, to match rocket
jsonwebtoken = { path = "libs/jsonwebtoken" }
rmp = { git = 'https://github.com/dani-garcia/msgpack-rust' }
# Version 0.1.2 from crates.io lacks a commit that fixes a certificate error
u2f = { git = 'https://github.com/wisespace-io/u2f-rs', rev = '193de35093a44' }

View file

@ -1,9 +1,10 @@
use rocket::State;
use rocket_contrib::{Json, Value};
use db::DbConn;
use db::models::*;
use api::{JsonResult, EmptyResult, JsonUpcase};
use api::{JsonResult, EmptyResult, JsonUpcase, WebSocketUsers, UpdateType};
use auth::Headers;
#[get("/folders")]
@ -40,23 +41,24 @@ pub struct FolderData {
}
#[post("/folders", data = "<data>")]
fn post_folders(data: JsonUpcase<FolderData>, headers: Headers, conn: DbConn) -> JsonResult {
fn post_folders(data: JsonUpcase<FolderData>, headers: Headers, conn: DbConn, ws: State<WebSocketUsers>) -> JsonResult {
let data: FolderData = data.into_inner().data;
let mut folder = Folder::new(headers.user.uuid.clone(), data.Name);
folder.save(&conn);
ws.send_folder_update(UpdateType::SyncFolderCreate, &folder);
Ok(Json(folder.to_json()))
}
#[post("/folders/<uuid>", data = "<data>")]
fn post_folder(uuid: String, data: JsonUpcase<FolderData>, headers: Headers, conn: DbConn) -> JsonResult {
put_folder(uuid, data, headers, conn)
fn post_folder(uuid: String, data: JsonUpcase<FolderData>, headers: Headers, conn: DbConn, ws: State<WebSocketUsers>) -> JsonResult {
put_folder(uuid, data, headers, conn, ws)
}
#[put("/folders/<uuid>", data = "<data>")]
fn put_folder(uuid: String, data: JsonUpcase<FolderData>, headers: Headers, conn: DbConn) -> JsonResult {
fn put_folder(uuid: String, data: JsonUpcase<FolderData>, headers: Headers, conn: DbConn, ws: State<WebSocketUsers>) -> JsonResult {
let data: FolderData = data.into_inner().data;
let mut folder = match Folder::find_by_uuid(&uuid, &conn) {
@ -71,17 +73,18 @@ fn put_folder(uuid: String, data: JsonUpcase<FolderData>, headers: Headers, conn
folder.name = data.Name;
folder.save(&conn);
ws.send_folder_update(UpdateType::SyncFolderUpdate, &folder);
Ok(Json(folder.to_json()))
}
#[post("/folders/<uuid>/delete")]
fn delete_folder_post(uuid: String, headers: Headers, conn: DbConn) -> EmptyResult {
delete_folder(uuid, headers, conn)
fn delete_folder_post(uuid: String, headers: Headers, conn: DbConn, ws: State<WebSocketUsers>) -> EmptyResult {
delete_folder(uuid, headers, conn, ws)
}
#[delete("/folders/<uuid>")]
fn delete_folder(uuid: String, headers: Headers, conn: DbConn) -> EmptyResult {
fn delete_folder(uuid: String, headers: Headers, conn: DbConn, ws: State<WebSocketUsers>) -> EmptyResult {
let folder = match Folder::find_by_uuid(&uuid, &conn) {
Some(folder) => folder,
_ => err!("Invalid folder")
@ -93,7 +96,10 @@ fn delete_folder(uuid: String, headers: Headers, conn: DbConn) -> EmptyResult {
// Delete the actual folder entry
match folder.delete(&conn) {
Ok(()) => Ok(()),
Ok(()) => {
ws.send_folder_update(UpdateType::SyncFolderDelete, &folder);
Ok(())
}
Err(_) => err!("Failed deleting folder")
}
}

View file

@ -9,6 +9,7 @@ pub use self::icons::routes as icons_routes;
pub use self::identity::routes as identity_routes;
pub use self::web::routes as web_routes;
pub use self::notifications::routes as notifications_routes;
pub use self::notifications::{start_notification_server, WebSocketUsers, UpdateType};
use rocket::response::status::BadRequest;
use rocket_contrib::Json;

View file

@ -1,9 +1,9 @@
use rocket::Route;
use rocket_contrib::Json;
use db::DbConn;
use api::JsonResult;
use auth::Headers;
use db::DbConn;
pub fn routes() -> Vec<Route> {
routes![negotiate]
@ -11,10 +11,9 @@ pub fn routes() -> Vec<Route> {
#[post("/hub/negotiate")]
fn negotiate(_headers: Headers, _conn: DbConn) -> JsonResult {
use data_encoding::BASE64URL;
use crypto;
use data_encoding::BASE64URL;
// Store this in db?
let conn_id = BASE64URL.encode(&crypto::get_random(vec![0u8; 16]));
// TODO: Implement transports
@ -23,9 +22,338 @@ fn negotiate(_headers: Headers, _conn: DbConn) -> JsonResult {
Ok(Json(json!({
"connectionId": conn_id,
"availableTransports":[
// {"transport":"WebSockets", "transferFormats":["Text","Binary"]},
{"transport":"WebSockets", "transferFormats":["Text","Binary"]},
// {"transport":"ServerSentEvents", "transferFormats":["Text"]},
// {"transport":"LongPolling", "transferFormats":["Text","Binary"]}
]
})))
}
///
/// Websockets server
///
use std::sync::Arc;
use std::thread;
use ws::{self, util::Token, Factory, Handler, Handshake, Message, Sender, WebSocket};
use chashmap::CHashMap;
use chrono::NaiveDateTime;
use serde_json::from_str;
use db::models::{Cipher, Folder, User};
use rmpv::Value;
fn serialize(val: Value) -> Vec<u8> {
use rmpv::encode::write_value;
let mut buf = Vec::new();
write_value(&mut buf, &val).expect("Error encoding MsgPack");
// Add size bytes at the start
// Extracted from BinaryMessageFormat.js
let mut size = buf.len();
let mut len_buf: Vec<u8> = Vec::new();
loop {
let mut size_part = size & 0x7f;
size = size >> 7;
if size > 0 {
size_part = size_part | 0x80;
}
len_buf.push(size_part as u8);
if size <= 0 {
break;
}
}
len_buf.append(&mut buf);
len_buf
}
fn serialize_date(date: NaiveDateTime) -> Value {
let seconds: i64 = date.timestamp();
let nanos: i64 = date.timestamp_subsec_nanos() as i64;
let timestamp = nanos << 34 | seconds;
use byteorder::{BigEndian, WriteBytesExt};
let mut bs = [0u8; 8];
bs.as_mut()
.write_i64::<BigEndian>(timestamp)
.expect("Unable to write");
// -1 is Timestamp
// https://github.com/msgpack/msgpack/blob/master/spec.md#timestamp-extension-type
Value::Ext(-1, bs.to_vec())
}
fn convert_option<T: Into<Value>>(option: Option<T>) -> Value {
match option {
Some(a) => a.into(),
None => Value::Nil,
}
}
// Server WebSocket handler
pub struct WSHandler {
out: Sender,
user_uuid: Option<String>,
users: WebSocketUsers,
}
const RECORD_SEPARATOR: u8 = 0x1e;
const INITIAL_RESPONSE: [u8; 3] = [0x7b, 0x7d, RECORD_SEPARATOR]; // {, }, <RS>
#[derive(Deserialize)]
struct InitialMessage {
protocol: String,
version: i32,
}
const PING_MS: u64 = 15_000;
const PING: Token = Token(1);
impl Handler for WSHandler {
fn on_open(&mut self, hs: Handshake) -> ws::Result<()> {
// TODO: Improve this split
let path = hs.request.resource();
let mut query_split: Vec<_> = path.split("?").nth(1).unwrap().split("&").collect();
query_split.sort();
let access_token = &query_split[0][13..];
let _id = &query_split[1][3..];
// Validate the user
use auth;
let claims = match auth::decode_jwt(access_token) {
Ok(claims) => claims,
Err(_) => {
return Err(ws::Error::new(
ws::ErrorKind::Internal,
"Invalid access token provided",
))
}
};
// Assign the user to the handler
let user_uuid = claims.sub;
self.user_uuid = Some(user_uuid.clone());
// Add the current Sender to the user list
let handler_insert = self.out.clone();
let handler_update = self.out.clone();
self.users.map.upsert(
user_uuid,
|| vec![handler_insert],
|ref mut v| v.push(handler_update),
);
// Schedule a ping to keep the connection alive
self.out.timeout(PING_MS, PING)
}
fn on_message(&mut self, msg: Message) -> ws::Result<()> {
println!("Server got message '{}'. ", msg);
if let Message::Text(text) = msg.clone() {
let json = &text[..text.len() - 1]; // Remove last char
if let Ok(InitialMessage { protocol, version }) = from_str::<InitialMessage>(json) {
if &protocol == "messagepack" && version == 1 {
return self.out.send(&INITIAL_RESPONSE[..]); // Respond to initial message
}
}
}
// If it's not the initial message, just echo the message
self.out.send(msg)
}
fn on_timeout(&mut self, event: Token) -> ws::Result<()> {
if event == PING {
// send ping
self.out.send(create_ping())?;
// reschedule the timeout
self.out.timeout(PING_MS, PING)
} else {
Err(ws::Error::new(
ws::ErrorKind::Internal,
"Invalid timeout token provided",
))
}
}
}
struct WSFactory {
pub users: WebSocketUsers,
}
impl WSFactory {
pub fn init() -> Self {
WSFactory {
users: WebSocketUsers {
map: Arc::new(CHashMap::new()),
},
}
}
}
impl Factory for WSFactory {
type Handler = WSHandler;
fn connection_made(&mut self, out: Sender) -> Self::Handler {
println!("WS: Connection made");
WSHandler {
out,
user_uuid: None,
users: self.users.clone(),
}
}
fn connection_lost(&mut self, handler: Self::Handler) {
println!("WS: Connection lost");
// Remove handler
let user_uuid = &handler.user_uuid.unwrap();
if let Some(mut user_conn) = self.users.map.get_mut(user_uuid) {
user_conn.remove_item(&handler.out);
}
}
}
#[derive(Clone)]
pub struct WebSocketUsers {
pub map: Arc<CHashMap<String, Vec<Sender>>>,
}
impl WebSocketUsers {
fn send_update(&self, user_uuid: &String, data: Vec<u8>) -> ws::Result<()> {
if let Some(user) = self.map.get(user_uuid) {
for sender in user.iter() {
sender.send(data.clone())?;
}
}
Ok(())
}
// NOTE: The last modified date needs to be updated before calling these methods
pub fn send_user_update(&self, ut: UpdateType, user: &User) {
let data = create_update(
vec![
("UserId".into(), user.uuid.clone().into()),
("Date".into(), serialize_date(user.updated_at)),
].into(),
ut,
);
self.send_update(&user.uuid.clone(), data).ok();
}
pub fn send_folder_update(&self, ut: UpdateType, folder: &Folder) {
let data = create_update(
vec![
("Id".into(), folder.uuid.clone().into()),
("UserId".into(), folder.user_uuid.clone().into()),
("RevisionDate".into(), serialize_date(folder.updated_at)),
].into(),
ut,
);
self.send_update(&folder.user_uuid, data).ok();
}
pub fn send_cipher_update(&self, ut: UpdateType, cipher: &Cipher, user_uuids: &[&String]) {
let user_uuid = convert_option(cipher.user_uuid.clone());
let org_uuid = convert_option(cipher.organization_uuid.clone());
let data = create_update(
vec![
("Id".into(), cipher.uuid.clone().into()),
("UserId".into(), user_uuid),
("OrganizationId".into(), org_uuid),
("CollectionIds".into(), Value::Nil),
("RevisionDate".into(), serialize_date(cipher.updated_at)),
].into(),
ut,
);
for user_uuid in user_uuids {
self.send_update(user_uuid, data.clone()).ok();
}
}
}
/* Message Structure
[
1, // MessageType.Invocation
{}, // Headers
null, // InvocationId
"ReceiveMessage", // Target
[ // Arguments
{
"ContextId": "app_id",
"Type": ut as i32,
"Payload": {}
}
]
]
*/
fn create_update(payload: Vec<(Value, Value)>, ut: UpdateType) -> Vec<u8> {
use rmpv::Value as V;
let value = V::Array(vec![
1.into(),
V::Array(vec![]),
V::Nil,
"ReceiveMessage".into(),
V::Array(vec![V::Map(vec![
("ContextId".into(), "app_id".into()),
("Type".into(), (ut as i32).into()),
("Payload".into(), payload.into()),
])]),
]);
serialize(value)
}
fn create_ping() -> Vec<u8> {
serialize(Value::Array(vec![6.into()]))
}
#[allow(dead_code)]
pub enum UpdateType {
SyncCipherUpdate = 0,
SyncCipherCreate = 1,
SyncLoginDelete = 2,
SyncFolderDelete = 3,
SyncCiphers = 4,
SyncVault = 5,
SyncOrgKeys = 6,
SyncFolderCreate = 7,
SyncFolderUpdate = 8,
SyncCipherDelete = 9,
SyncSettings = 10,
LogOut = 11,
}
pub fn start_notification_server() -> WebSocketUsers {
let factory = WSFactory::init();
let users = factory.users.clone();
thread::spawn(move || {
WebSocket::new(factory)
.unwrap()
.listen("0.0.0.0:3012")
.unwrap();
});
users
}

View file

@ -82,13 +82,13 @@ impl Folder {
}
}
pub fn delete(self, conn: &DbConn) -> QueryResult<()> {
pub fn delete(&self, conn: &DbConn) -> QueryResult<()> {
User::update_uuid_revision(&self.user_uuid, conn);
FolderCipher::delete_all_by_folder(&self.uuid, &conn)?;
diesel::delete(
folders::table.filter(
folders::uuid.eq(self.uuid)
folders::uuid.eq(&self.uuid)
)
).execute(&**conn).and(Ok(()))
}

View file

@ -1,10 +1,13 @@
#![feature(plugin, custom_derive)]
#![feature(plugin, custom_derive, vec_remove_item)]
#![plugin(rocket_codegen)]
#![allow(proc_macro_derive_resolution_fallback)] // TODO: Remove this when diesel update fixes warnings
extern crate rocket;
extern crate rocket_contrib;
extern crate reqwest;
extern crate multipart;
extern crate ws;
extern crate rmpv;
extern crate chashmap;
extern crate serde;
#[macro_use]
extern crate serde_derive;
@ -27,6 +30,7 @@ extern crate lazy_static;
#[macro_use]
extern crate num_derive;
extern crate num_traits;
extern crate byteorder;
use std::{env, path::Path, process::{exit, Command}};
use rocket::Rocket;
@ -47,6 +51,7 @@ fn init_rocket() -> Rocket {
.mount("/icons", api::icons_routes())
.mount("/notifications", api::notifications_routes())
.manage(db::init_pool())
.manage(api::start_notification_server())
}
// Embed the migrations from the migrations folder into the application
@ -71,7 +76,6 @@ fn main() {
check_web_vault();
migrations::run_migrations();
init_rocket().launch();
}