#116 refactor: better error handling

Merged
timo merged 2 commits from error-handling into master 1 month ago
  1. +1335
    -1442
      src/client_server.rs
  2. +52
    -43
      src/database.rs
  3. +20
    -12
      src/database/account_data.rs
  4. +13
    -5
      src/database/global_edus.rs
  5. +13
    -13
      src/database/globals.rs
  6. +26
    -13
      src/database/media.rs
  7. +355
    -236
      src/database/rooms.rs
  8. +45
    -18
      src/database/rooms/edus.rs
  9. +57
    -13
      src/database/uiaa.rs
  10. +125
    -61
      src/database/users.rs
  11. +72
    -25
      src/error.rs
  12. +2
    -2
      src/main.rs
  13. +18
    -11
      src/pdu.rs
  14. +12
    -25
      src/ruma_wrapper.rs
  15. +15
    -17
      src/utils.rs
  16. +1
    -2
      sytest/sytest-whitelist

+ 1335
- 1442
src/client_server.rs
File diff suppressed because it is too large
View File


+ 52
- 43
src/database.rs View File

@@ -6,6 +6,7 @@ pub(self) mod rooms;
pub(self) mod uiaa;
pub(self) mod users;

use crate::{Error, Result};
use directories::ProjectDirs;
use log::info;
use std::fs::remove_dir_all;
@@ -25,84 +26,92 @@ pub struct Database {

impl Database {
/// Tries to remove the old database but ignores all errors.
pub fn try_remove(server_name: &str) {
pub fn try_remove(server_name: &str) -> Result<()> {
let mut path = ProjectDirs::from("xyz", "koesters", "conduit")
.unwrap()
.ok_or(Error::BadConfig(
"The OS didn't return a valid home directory path.",
))?
.data_dir()
.to_path_buf();
path.push(server_name);
let _ = remove_dir_all(path);

Ok(())
}

/// Load an existing database or create a new one.
pub fn load_or_create(config: &Config) -> Self {
pub fn load_or_create(config: &Config) -> Result<Self> {
let server_name = config.get_str("server_name").unwrap_or("localhost");

let path = config
.get_str("database_path")
.map(|x| x.to_owned())
.map(|x| Ok::<_, Error>(x.to_owned()))
.unwrap_or_else(|_| {
let path = ProjectDirs::from("xyz", "koesters", "conduit")
.unwrap()
.ok_or(Error::BadConfig(
"The OS didn't return a valid home directory path.",
))?
.data_dir()
.join(server_name);
path.to_str().unwrap().to_owned()
});

let db = sled::open(&path).unwrap();
Ok(path
.to_str()
.ok_or(Error::BadConfig("Database path contains invalid unicode."))?
.to_owned())
})?;

let db = sled::open(&path)?;
info!("Opened sled database at {}", path);

Self {
globals: globals::Globals::load(db.open_tree("global").unwrap(), config),
Ok(Self {
globals: globals::Globals::load(db.open_tree("global")?, config)?,
users: users::Users {
userid_password: db.open_tree("userid_password").unwrap(),
userid_displayname: db.open_tree("userid_displayname").unwrap(),
userid_avatarurl: db.open_tree("userid_avatarurl").unwrap(),
userdeviceid_token: db.open_tree("userdeviceid_token").unwrap(),
userdeviceid_metadata: db.open_tree("userdeviceid_metadata").unwrap(),
token_userdeviceid: db.open_tree("token_userdeviceid").unwrap(),
onetimekeyid_onetimekeys: db.open_tree("onetimekeyid_onetimekeys").unwrap(),
userdeviceid_devicekeys: db.open_tree("userdeviceid_devicekeys").unwrap(),
devicekeychangeid_userid: db.open_tree("devicekeychangeid_userid").unwrap(),
todeviceid_events: db.open_tree("todeviceid_events").unwrap(),
userid_password: db.open_tree("userid_password")?,
userid_displayname: db.open_tree("userid_displayname")?,
userid_avatarurl: db.open_tree("userid_avatarurl")?,
userdeviceid_token: db.open_tree("userdeviceid_token")?,
userdeviceid_metadata: db.open_tree("userdeviceid_metadata")?,
token_userdeviceid: db.open_tree("token_userdeviceid")?,
onetimekeyid_onetimekeys: db.open_tree("onetimekeyid_onetimekeys")?,
userdeviceid_devicekeys: db.open_tree("userdeviceid_devicekeys")?,
devicekeychangeid_userid: db.open_tree("devicekeychangeid_userid")?,
todeviceid_events: db.open_tree("todeviceid_events")?,
},
uiaa: uiaa::Uiaa {
userdeviceid_uiaainfo: db.open_tree("userdeviceid_uiaainfo").unwrap(),
userdeviceid_uiaainfo: db.open_tree("userdeviceid_uiaainfo")?,
},
rooms: rooms::Rooms {
edus: rooms::RoomEdus {
roomuserid_lastread: db.open_tree("roomuserid_lastread").unwrap(), // "Private" read receipt
roomlatestid_roomlatest: db.open_tree("roomlatestid_roomlatest").unwrap(), // Read receipts
roomactiveid_userid: db.open_tree("roomactiveid_userid").unwrap(), // Typing notifs
roomid_lastroomactiveupdate: db
.open_tree("roomid_lastroomactiveupdate")
.unwrap(),
roomuserid_lastread: db.open_tree("roomuserid_lastread")?, // "Private" read receipt
roomlatestid_roomlatest: db.open_tree("roomlatestid_roomlatest")?, // Read receipts
roomactiveid_userid: db.open_tree("roomactiveid_userid")?, // Typing notifs
roomid_lastroomactiveupdate: db.open_tree("roomid_lastroomactiveupdate")?,
},
pduid_pdu: db.open_tree("pduid_pdu").unwrap(),
eventid_pduid: db.open_tree("eventid_pduid").unwrap(),
roomid_pduleaves: db.open_tree("roomid_pduleaves").unwrap(),
roomstateid_pdu: db.open_tree("roomstateid_pdu").unwrap(),
pduid_pdu: db.open_tree("pduid_pdu")?,
eventid_pduid: db.open_tree("eventid_pduid")?,
roomid_pduleaves: db.open_tree("roomid_pduleaves")?,
roomstateid_pdu: db.open_tree("roomstateid_pdu")?,

alias_roomid: db.open_tree("alias_roomid").unwrap(),
aliasid_alias: db.open_tree("alias_roomid").unwrap(),
publicroomids: db.open_tree("publicroomids").unwrap(),
alias_roomid: db.open_tree("alias_roomid")?,
aliasid_alias: db.open_tree("alias_roomid")?,
publicroomids: db.open_tree("publicroomids")?,

userroomid_joined: db.open_tree("userroomid_joined").unwrap(),
roomuserid_joined: db.open_tree("roomuserid_joined").unwrap(),
userroomid_invited: db.open_tree("userroomid_invited").unwrap(),
roomuserid_invited: db.open_tree("roomuserid_invited").unwrap(),
userroomid_left: db.open_tree("userroomid_left").unwrap(),
userroomid_joined: db.open_tree("userroomid_joined")?,
roomuserid_joined: db.open_tree("roomuserid_joined")?,
userroomid_invited: db.open_tree("userroomid_invited")?,
roomuserid_invited: db.open_tree("roomuserid_invited")?,
userroomid_left: db.open_tree("userroomid_left")?,
},
account_data: account_data::AccountData {
roomuserdataid_accountdata: db.open_tree("roomuserdataid_accountdata").unwrap(),
roomuserdataid_accountdata: db.open_tree("roomuserdataid_accountdata")?,
},
global_edus: global_edus::GlobalEdus {
presenceid_presence: db.open_tree("presenceid_presence").unwrap(), // Presence
presenceid_presence: db.open_tree("presenceid_presence")?, // Presence
},
media: media::Media {
mediaid_file: db.open_tree("mediaid_file").unwrap(),
mediaid_file: db.open_tree("mediaid_file")?,
},
_db: db,
}
})
}
}

+ 20
- 12
src/database/account_data.rs View File

@@ -1,5 +1,6 @@
use crate::{utils, Error, Result};
use ruma::{
api::client::error::ErrorKind,
events::{collections::only::Event as EduEvent, EventJson, EventType},
identifiers::{RoomId, UserId},
};
@@ -20,7 +21,10 @@ impl AccountData {
globals: &super::globals::Globals,
) -> Result<()> {
if json.get("content").is_none() {
return Err(Error::BadRequest("json needs to have a content field"));
return Err(Error::BadRequest(
ErrorKind::BadJson,
"Json needs to have a content field.",
));
}
json.insert("type".to_owned(), kind.to_string().into());

@@ -62,9 +66,10 @@ impl AccountData {
key.push(0xff);
key.extend_from_slice(kind.to_string().as_bytes());

self.roomuserdataid_accountdata
.insert(key, &*serde_json::to_string(&json)?)
.unwrap();
self.roomuserdataid_accountdata.insert(
key,
&*serde_json::to_string(&json).expect("Map::to_string always works"),
)?;

Ok(())
}
@@ -109,17 +114,20 @@ impl AccountData {
.take_while(move |(k, _)| k.starts_with(&prefix))
.map(|(k, v)| {
Ok::<_, Error>((
EventType::try_from(utils::string_from_bytes(
k.rsplit(|&b| b == 0xff)
.next()
.ok_or(Error::BadDatabase("roomuserdataid is invalid"))?,
)?)
.map_err(|_| Error::BadDatabase("roomuserdataid is invalid"))?,
serde_json::from_slice::<EventJson<EduEvent>>(&v).unwrap(),
EventType::try_from(
utils::string_from_bytes(k.rsplit(|&b| b == 0xff).next().ok_or_else(
|| Error::bad_database("RoomUserData ID in db is invalid."),
)?)
.map_err(|_| Error::bad_database("RoomUserData ID in db is invalid."))?,
)
.map_err(|_| Error::bad_database("RoomUserData ID in db is invalid."))?,
serde_json::from_slice::<EventJson<EduEvent>>(&v).map_err(|_| {
Error::bad_database("Database contains invalid account data.")
})?,
))
})
{
let (kind, data) = r.unwrap();
let (kind, data) = r?;
userdata.insert(kind, data);
}



+ 13
- 5
src/database/global_edus.rs View File

@@ -1,4 +1,4 @@
use crate::Result;
use crate::{Error, Result};
use ruma::events::EventJson;

pub struct GlobalEdus {
@@ -21,7 +21,10 @@ impl GlobalEdus {
.rev()
.filter_map(|r| r.ok())
.find(|key| {
key.rsplit(|&b| b == 0xff).next().unwrap() == presence.sender.to_string().as_bytes()
key.rsplit(|&b| b == 0xff)
.next()
.expect("rsplit always returns an element")
== presence.sender.to_string().as_bytes()
})
{
// This is the old global_latest
@@ -32,8 +35,10 @@ impl GlobalEdus {
presence_id.push(0xff);
presence_id.extend_from_slice(&presence.sender.to_string().as_bytes());

self.presenceid_presence
.insert(presence_id, &*serde_json::to_string(&presence)?)?;
self.presenceid_presence.insert(
presence_id,
&*serde_json::to_string(&presence).expect("PresenceEvent can be serialized"),
)?;

Ok(())
}
@@ -50,6 +55,9 @@ impl GlobalEdus {
.presenceid_presence
.range(&*first_possible_edu..)
.filter_map(|r| r.ok())
.map(|(_, v)| Ok(serde_json::from_slice(&v)?)))
.map(|(_, v)| {
Ok(serde_json::from_slice(&v)
.map_err(|_| Error::bad_database("Invalid presence event in db."))?)
}))
}
}

+ 13
- 13
src/database/globals.rs View File

@@ -1,4 +1,4 @@
use crate::{utils, Result};
use crate::{utils, Error, Result};

pub const COUNTER: &str = "c";

@@ -11,17 +11,16 @@ pub struct Globals {
}

impl Globals {
pub fn load(globals: sled::Tree, config: &rocket::Config) -> Self {
pub fn load(globals: sled::Tree, config: &rocket::Config) -> Result<Self> {
let keypair = ruma::signatures::Ed25519KeyPair::new(
&*globals
.update_and_fetch("keypair", utils::generate_keypair)
.unwrap()
.unwrap(),
.update_and_fetch("keypair", utils::generate_keypair)?
.expect("utils::generate_keypair always returns Some"),
"key1".to_owned(),
)
.unwrap();
.map_err(|_| Error::bad_database("Private or public keys are invalid."))?;

Self {
Ok(Self {
globals,
keypair,
reqwest_client: reqwest::Client::new(),
@@ -30,7 +29,7 @@ impl Globals {
.unwrap_or("localhost")
.to_owned(),
registration_disabled: config.get_bool("registration_disabled").unwrap_or(false),
}
})
}

/// Returns this server's keypair.
@@ -49,14 +48,15 @@ impl Globals {
.globals
.update_and_fetch(COUNTER, utils::increment)?
.expect("utils::increment will always put in a value"),
))
)
.map_err(|_| Error::bad_database("Count has invalid bytes."))?)
}

pub fn current_count(&self) -> Result<u64> {
Ok(self
.globals
.get(COUNTER)?
.map_or(0_u64, |bytes| utils::u64_from_bytes(&bytes)))
self.globals.get(COUNTER)?.map_or(Ok(0_u64), |bytes| {
Ok(utils::u64_from_bytes(&bytes)
.map_err(|_| Error::bad_database("Count has invalid bytes."))?)
})
}

pub fn server_name(&self) -> &str {


+ 26
- 13
src/database/media.rs View File

@@ -43,16 +43,20 @@ impl Media {
let content_type = utils::string_from_bytes(
parts
.next()
.ok_or(Error::BadDatabase("mediaid is invalid"))?,
)?;
.ok_or_else(|| Error::bad_database("Media ID in db is invalid."))?,
)
.map_err(|_| Error::bad_database("Content type in mediaid_file is invalid unicode."))?;

let filename_bytes = parts
.next()
.ok_or(Error::BadDatabase("mediaid is invalid"))?;
.ok_or_else(|| Error::bad_database("Media ID in db is invalid."))?;

let filename = if filename_bytes.is_empty() {
None
} else {
Some(utils::string_from_bytes(filename_bytes)?)
Some(utils::string_from_bytes(filename_bytes).map_err(|_| {
Error::bad_database("Filename in mediaid_file is invalid unicode.")
})?)
};

Ok(Some((filename, content_type, file.to_vec())))
@@ -89,16 +93,21 @@ impl Media {
let content_type = utils::string_from_bytes(
parts
.next()
.ok_or(Error::BadDatabase("mediaid is invalid"))?,
)?;
.ok_or_else(|| Error::bad_database("Invalid Media ID in db"))?,
)
.map_err(|_| Error::bad_database("Content type in mediaid_file is invalid unicode."))?;

let filename_bytes = parts
.next()
.ok_or(Error::BadDatabase("mediaid is invalid"))?;
.ok_or_else(|| Error::bad_database("Media ID in db is invalid."))?;

let filename = if filename_bytes.is_empty() {
None
} else {
Some(utils::string_from_bytes(filename_bytes)?)
Some(
utils::string_from_bytes(filename_bytes)
.map_err(|_| Error::bad_database("Filename in db is invalid."))?,
)
};

Ok(Some((filename, content_type, file.to_vec())))
@@ -110,16 +119,20 @@ impl Media {
let content_type = utils::string_from_bytes(
parts
.next()
.ok_or(Error::BadDatabase("mediaid is invalid"))?,
)?;
.ok_or_else(|| Error::bad_database("Media ID in db is invalid."))?,
)
.map_err(|_| Error::bad_database("Content type in mediaid_file is invalid unicode."))?;

let filename_bytes = parts
.next()
.ok_or(Error::BadDatabase("mediaid is invalid"))?;
.ok_or_else(|| Error::bad_database("Media ID in db is invalid."))?;

let filename = if filename_bytes.is_empty() {
None
} else {
Some(utils::string_from_bytes(filename_bytes)?)
Some(utils::string_from_bytes(filename_bytes).map_err(|_| {
Error::bad_database("Filename in mediaid_file is invalid unicode.")
})?)
};

if let Ok(image) = image::load_from_memory(&file) {
@@ -132,7 +145,7 @@ impl Media {
let width_index = thumbnail_key
.iter()
.position(|&b| b == 0xff)
.ok_or(Error::BadDatabase("mediaid is invalid"))?
.ok_or_else(|| Error::bad_database("Media in db is invalid."))?
+ 1;
let mut widthheight = width.to_be_bytes().to_vec();
widthheight.extend_from_slice(&height.to_be_bytes());


+ 355
- 236
src/database/rooms.rs View File

@@ -5,6 +5,7 @@ pub use edus::RoomEdus;
use crate::{utils, Error, PduEvent, Result};
use log::error;
use ruma::{
api::client::error::ErrorKind,
events::{
room::{
join_rules, member,
@@ -61,30 +62,34 @@ impl Rooms {
.roomstateid_pdu
.scan_prefix(&room_id.to_string().as_bytes())
.values()
.map(|value| Ok::<_, Error>(serde_json::from_slice::<PduEvent>(&value?)?))
.map(|value| {
Ok::<_, Error>(
serde_json::from_slice::<PduEvent>(&value?)
.map_err(|_| Error::bad_database("Invalid PDU in db."))?,
)
})
{
let pdu = pdu?;
hashmap.insert(
(
pdu.kind.clone(),
pdu.state_key
.clone()
.expect("state events have a state key"),
),
pdu,
);
let state_key = pdu.state_key.clone().ok_or_else(|| {
Error::bad_database("Room state contains event without state_key.")
})?;
hashmap.insert((pdu.kind.clone(), state_key), pdu);
}
Ok(hashmap)
}

/// Returns the `count` of this pdu's id.
pub fn get_pdu_count(&self, event_id: &EventId) -> Result<Option<u64>> {
Ok(self
.eventid_pduid
self.eventid_pduid
.get(event_id.to_string().as_bytes())?
.map(|pdu_id| {
utils::u64_from_bytes(&pdu_id[pdu_id.len() - mem::size_of::<u64>()..pdu_id.len()])
}))
.map_or(Ok(None), |pdu_id| {
Ok(Some(
utils::u64_from_bytes(
&pdu_id[pdu_id.len() - mem::size_of::<u64>()..pdu_id.len()],
)
.map_err(|_| Error::bad_database("PDU has invalid count bytes."))?,
))
})
}

/// Returns the json of a pdu.
@@ -92,11 +97,12 @@ impl Rooms {
self.eventid_pduid
.get(event_id.to_string().as_bytes())?
.map_or(Ok(None), |pdu_id| {
Ok(Some(serde_json::from_slice(
&self.pduid_pdu.get(pdu_id)?.ok_or(Error::BadDatabase(
"eventid_pduid points to nonexistent pdu",
))?,
)?))
Ok(Some(
serde_json::from_slice(&self.pduid_pdu.get(pdu_id)?.ok_or_else(|| {
Error::bad_database("eventid_pduid points to nonexistent pdu.")
})?)
.map_err(|_| Error::bad_database("Invalid PDU in db."))?,
))
})
}

@@ -112,28 +118,37 @@ impl Rooms {
self.eventid_pduid
.get(event_id.to_string().as_bytes())?
.map_or(Ok(None), |pdu_id| {
Ok(Some(serde_json::from_slice(
&self.pduid_pdu.get(pdu_id)?.ok_or(Error::BadDatabase(
"eventid_pduid points to nonexistent pdu",
))?,
)?))
Ok(Some(
serde_json::from_slice(&self.pduid_pdu.get(pdu_id)?.ok_or_else(|| {
Error::bad_database("eventid_pduid points to nonexistent pdu.")
})?)
.map_err(|_| Error::bad_database("Invalid PDU in db."))?,
))
})
}
/// Returns the pdu.
pub fn get_pdu_from_id(&self, pdu_id: &IVec) -> Result<Option<PduEvent>> {
self.pduid_pdu
.get(pdu_id)?
.map_or(Ok(None), |pdu| Ok(Some(serde_json::from_slice(&pdu)?)))
self.pduid_pdu.get(pdu_id)?.map_or(Ok(None), |pdu| {
Ok(Some(
serde_json::from_slice(&pdu)
.map_err(|_| Error::bad_database("Invalid PDU in db."))?,
))
})
}

/// Returns the pdu.
pub fn replace_pdu(&self, pdu_id: &IVec, pdu: &PduEvent) -> Result<()> {
/// Removes a pdu and creates a new one with the same id.
fn replace_pdu(&self, pdu_id: &IVec, pdu: &PduEvent) -> Result<()> {
if self.pduid_pdu.get(&pdu_id)?.is_some() {
self.pduid_pdu
.insert(&pdu_id, &*serde_json::to_string(pdu)?)?;
self.pduid_pdu.insert(
&pdu_id,
&*serde_json::to_string(pdu).expect("PduEvent::to_string always works"),
)?;
Ok(())
} else {
Err(Error::BadRequest("pdu does not exist"))
Err(Error::BadRequest(
ErrorKind::NotFound,
"PDU does not exist.",
))
}
}

@@ -148,7 +163,14 @@ impl Rooms {
.roomid_pduleaves
.scan_prefix(prefix)
.values()
.map(|bytes| Ok::<_, Error>(EventId::try_from(&*utils::string_from_bytes(&bytes?)?)?))
.map(|bytes| {
Ok::<_, Error>(
EventId::try_from(utils::string_from_bytes(&bytes?).map_err(|_| {
Error::bad_database("EventID in roomid_pduleaves is invalid unicode.")
})?)
.map_err(|_| Error::bad_database("EventId in roomid_pduleaves is invalid."))?,
)
})
{
events.push(event?);
}
@@ -214,174 +236,205 @@ impl Rooms {
Ok(
serde_json::from_value::<EventJson<PowerLevelsEventContent>>(
power_levels.content.clone(),
)?
.deserialize()?,
)
.expect("EventJson::from_value always works.")
.deserialize()
.map_err(|_| Error::bad_database("Invalid PowerLevels event in db."))?,
)
},
)?;
{
let sender_membership = self
.room_state(&room_id)?
.get(&(EventType::RoomMember, sender.to_string()))
.map_or(Ok::<_, Error>(member::MembershipState::Leave), |pdu| {
Ok(
serde_json::from_value::<EventJson<member::MemberEventContent>>(
pdu.content.clone(),
)?
.deserialize()?
.membership,
let sender_membership = self
.room_state(&room_id)?
.get(&(EventType::RoomMember, sender.to_string()))
.map_or(Ok::<_, Error>(member::MembershipState::Leave), |pdu| {
Ok(
serde_json::from_value::<EventJson<member::MemberEventContent>>(
pdu.content.clone(),
)
.expect("EventJson::from_value always works.")
.deserialize()
.map_err(|_| Error::bad_database("Invalid Member event in db."))?
.membership,
)
})?;

let sender_power = power_levels.users.get(&sender).map_or_else(
|| {
if sender_membership != member::MembershipState::Join {
None
} else {
Some(&power_levels.users_default)
}
},
// If it's okay, wrap with Some(_)
Some,
);

if !match event_type {
EventType::RoomMember => {
let target_user_id = UserId::try_from(&**state_key).map_err(|_| {
Error::BadRequest(
ErrorKind::InvalidParam,
"State key of member event does not contain user id.",
)
})?;

let sender_power = power_levels.users.get(&sender).map_or_else(
|| {
if sender_membership != member::MembershipState::Join {
None
} else {
Some(&power_levels.users_default)
}
},
// If it's okay, wrap with Some(_)
Some,
);

if !match event_type {
EventType::RoomMember => {
let target_user_id = UserId::try_from(&**state_key)?;

let current_membership = self
.room_state(&room_id)?
.get(&(EventType::RoomMember, target_user_id.to_string()))
.map_or(Ok::<_, Error>(member::MembershipState::Leave), |pdu| {
Ok(serde_json::from_value::<
EventJson<member::MemberEventContent>,
>(pdu.content.clone())?
.deserialize()?
.membership)
})?;
let current_membership = self
.room_state(&room_id)?
.get(&(EventType::RoomMember, target_user_id.to_string()))
.map_or(Ok::<_, Error>(member::MembershipState::Leave), |pdu| {
Ok(
serde_json::from_value::<EventJson<member::MemberEventContent>>(
pdu.content.clone(),
)
.expect("EventJson::from_value always works.")
.deserialize()
.map_err(|_| Error::bad_database("Invalid Member event in db."))?
.membership,
)
})?;

let target_membership = serde_json::from_value::<
EventJson<member::MemberEventContent>,
>(content.clone())
.expect("EventJson::from_value always works.")
.deserialize()
.map_err(|_| Error::bad_database("Invalid Member event in db."))?
.membership;

let target_power = power_levels.users.get(&target_user_id).map_or_else(
|| {
if target_membership != member::MembershipState::Join {
None
} else {
Some(&power_levels.users_default)
}
},
// If it's okay, wrap with Some(_)
Some,
);

let target_membership = serde_json::from_value::<
EventJson<member::MemberEventContent>,
>(content.clone())?
.deserialize()?
.membership;

let target_power = power_levels.users.get(&target_user_id).map_or_else(
|| {
if target_membership != member::MembershipState::Join {
None
} else {
Some(&power_levels.users_default)
}
},
// If it's okay, wrap with Some(_)
Some,
);

let join_rules = self
.room_state(&room_id)?
let join_rules =
self.room_state(&room_id)?
.get(&(EventType::RoomJoinRules, "".to_owned()))
.map_or(join_rules::JoinRule::Public, |pdu| {
serde_json::from_value::<
.map_or(Ok::<_, Error>(join_rules::JoinRule::Public), |pdu| {
Ok(serde_json::from_value::<
EventJson<join_rules::JoinRulesEventContent>,
>(pdu.content.clone())
.unwrap()
.expect("EventJson::from_value always works.")
.deserialize()
.unwrap()
.join_rule
});

let authorized = if target_membership == member::MembershipState::Join {
let mut prev_events = prev_events.iter();
let prev_event = self
.get_pdu(prev_events.next().ok_or(Error::BadRequest(
"membership can't be the first event",
))?)?
.ok_or(Error::BadDatabase("pdu leave points to valid event"))?;
if prev_event.kind == EventType::RoomCreate
&& prev_event.prev_events.is_empty()
{
true
} else if sender != target_user_id {
false
} else if let member::MembershipState::Ban = current_membership {
false
} else {
join_rules == join_rules::JoinRule::Invite
&& (current_membership == member::MembershipState::Join
|| current_membership == member::MembershipState::Invite)
|| join_rules == join_rules::JoinRule::Public
}
} else if target_membership == member::MembershipState::Invite {
if let Some(third_party_invite_json) = content.get("third_party_invite")
{
if current_membership == member::MembershipState::Ban {
false
} else {
let _third_party_invite =
serde_json::from_value::<member::ThirdPartyInvite>(
third_party_invite_json.clone(),
)?;
todo!("handle third party invites");
}
} else if sender_membership != member::MembershipState::Join
|| current_membership == member::MembershipState::Join
|| current_membership == member::MembershipState::Ban
{
false
} else {
sender_power
.filter(|&p| p >= &power_levels.invite)
.is_some()
}
} else if target_membership == member::MembershipState::Leave {
if sender == target_user_id {
current_membership == member::MembershipState::Join
|| current_membership == member::MembershipState::Invite
} else if sender_membership != member::MembershipState::Join
|| current_membership == member::MembershipState::Ban
&& sender_power.filter(|&p| p < &power_levels.ban).is_some()
{
false
} else {
sender_power.filter(|&p| p >= &power_levels.kick).is_some()
&& target_power < sender_power
}
} else if target_membership == member::MembershipState::Ban {
if sender_membership != member::MembershipState::Join {
.map_err(|_| {
Error::bad_database("Database contains invalid JoinRules event")
})?
.join_rule)
})?;

let authorized = if target_membership == member::MembershipState::Join {
let mut prev_events = prev_events.iter();
let prev_event = self
.get_pdu(prev_events.next().ok_or(Error::BadRequest(
ErrorKind::Unknown,
"Membership can't be the first event",
))?)?
.ok_or_else(|| {
Error::bad_database("PDU leaf points to invalid event!")
})?;
if prev_event.kind == EventType::RoomCreate
&& prev_event.prev_events.is_empty()
{
true
} else if sender != target_user_id {
false
} else if let member::MembershipState::Ban = current_membership {
false
} else {
join_rules == join_rules::JoinRule::Invite
&& (current_membership == member::MembershipState::Join
|| current_membership == member::MembershipState::Invite)
|| join_rules == join_rules::JoinRule::Public
}
} else if target_membership == member::MembershipState::Invite {
if let Some(third_party_invite_json) = content.get("third_party_invite") {
if current_membership == member::MembershipState::Ban {
false
} else {
sender_power.filter(|&p| p >= &power_levels.ban).is_some()
&& target_power < sender_power
let _third_party_invite =
serde_json::from_value::<member::ThirdPartyInvite>(
third_party_invite_json.clone(),
)
.map_err(|_| {
Error::BadRequest(
ErrorKind::InvalidParam,
"ThirdPartyInvite is invalid",
)
})?;
todo!("handle third party invites");
}
} else if sender_membership != member::MembershipState::Join
|| current_membership == member::MembershipState::Join
|| current_membership == member::MembershipState::Ban
{
false
} else {
sender_power
.filter(|&p| p >= &power_levels.invite)
.is_some()
}
} else if target_membership == member::MembershipState::Leave {
if sender == target_user_id {
current_membership == member::MembershipState::Join
|| current_membership == member::MembershipState::Invite
} else if sender_membership != member::MembershipState::Join
|| current_membership == member::MembershipState::Ban
&& sender_power.filter(|&p| p < &power_levels.ban).is_some()
{
false
};

if authorized {
// Update our membership info
self.update_membership(&room_id, &target_user_id, &target_membership)?;
} else {
sender_power.filter(|&p| p >= &power_levels.kick).is_some()
&& target_power < sender_power
}
} else if target_membership == member::MembershipState::Ban {
if sender_membership != member::MembershipState::Join {
false
} else {
sender_power.filter(|&p| p >= &power_levels.ban).is_some()
&& target_power < sender_power
}
} else {
false
};

authorized
if authorized {
// Update our membership info
self.update_membership(&room_id, &target_user_id, &target_membership)?;
}
EventType::RoomCreate => prev_events.is_empty(),
// Not allow any of the following events if the sender is not joined.
_ if sender_membership != member::MembershipState::Join => false,

_ => {
// TODO
sender_power.unwrap_or(&power_levels.users_default)
>= &power_levels.state_default
}
} {
error!("Unauthorized");
// Not authorized
return Err(Error::BadRequest("event not authorized"));

authorized
}
EventType::RoomCreate => prev_events.is_empty(),
// Not allow any of the following events if the sender is not joined.
_ if sender_membership != member::MembershipState::Join => false,

_ => {
// TODO
sender_power.unwrap_or(&power_levels.users_default)
>= &power_levels.state_default
}
} {
error!("Unauthorized");
// Not authorized
return Err(Error::BadRequest(
ErrorKind::Forbidden,
"Event is not authorized",
));
}
} else if !self.is_joined(&sender, &room_id)? {
return Err(Error::BadRequest("event not authorized"));
// TODO: auth rules apply to all events, not only those with a state key
error!("Unauthorized");
return Err(Error::BadRequest(
ErrorKind::Forbidden,
"Event is not authorized",
));
}

// Our depth is the maximum depth of prev_events + 1
@@ -410,14 +463,14 @@ impl Rooms {
origin: globals.server_name().to_owned(),
origin_server_ts: utils::millis_since_unix_epoch()
.try_into()
.expect("this only fails many years in the future"),
.expect("time is valid"),
kind: event_type.clone(),
content: content.clone(),
state_key,
prev_events,
depth: depth
.try_into()
.expect("depth can overflow and should be deprecated..."),
.map_err(|_| Error::bad_database("Depth is invalid"))?,
auth_events: Vec::new(),
redacts: redacts.clone(),
unsigned,
@@ -430,18 +483,20 @@ impl Rooms {
// Generate event id
pdu.event_id = EventId::try_from(&*format!(
"${}",
ruma::signatures::reference_hash(&serde_json::to_value(&pdu)?)
.expect("ruma can calculate reference hashes")
ruma::signatures::reference_hash(
&serde_json::to_value(&pdu).expect("event is valid, we just created it")
)
.expect("ruma can calculate reference hashes")
))
.expect("ruma's reference hashes are correct");
.expect("ruma's reference hashes are valid event ids");

let mut pdu_json = serde_json::to_value(&pdu)?;
let mut pdu_json = serde_json::to_value(&pdu).expect("event is valid, we just created it");
ruma::signatures::hash_and_sign_event(
globals.server_name(),
globals.keypair(),
&mut pdu_json,
)
.expect("our new event can be hashed and signed");
.expect("event is valid, we just created it");

self.replace_pdu_leaves(&room_id, &pdu.event_id)?;

@@ -473,8 +528,15 @@ impl Rooms {
// TODO: Reason
let _reason = serde_json::from_value::<
EventJson<redaction::RedactionEventContent>,
>(content)?
.deserialize()?
>(content)
.expect("EventJson::from_value always works.")
.deserialize()
.map_err(|_| {
Error::BadRequest(
ErrorKind::InvalidParam,
"Invalid redaction event content.",
)
})?
.reason;

self.redact_pdu(&redact_id)?;
@@ -528,7 +590,10 @@ impl Rooms {
})
.filter_map(|r| r.ok())
.take_while(move |(k, _)| k.starts_with(&prefix))
.map(|(_, v)| Ok(serde_json::from_slice(&v)?)))
.map(|(_, v)| {
Ok(serde_json::from_slice(&v)
.map_err(|_| Error::bad_database("PDU in db is invalid."))?)
}))
}

/// Returns an iterator over all events in a room that happened before the event with id
@@ -552,7 +617,10 @@ impl Rooms {
.rev()
.filter_map(|r| r.ok())
.take_while(move |(k, _)| k.starts_with(&prefix))
.map(|(_, v)| Ok(serde_json::from_slice(&v)?))
.map(|(_, v)| {
Ok(serde_json::from_slice(&v)
.map_err(|_| Error::bad_database("PDU in db is invalid."))?)
})
}

/// Returns an iterator over all events in a room that happened after the event with id
@@ -575,7 +643,10 @@ impl Rooms {
.range(current..)
.filter_map(|r| r.ok())
.take_while(move |(k, _)| k.starts_with(&prefix))
.map(|(_, v)| Ok(serde_json::from_slice(&v)?))
.map(|(_, v)| {
Ok(serde_json::from_slice(&v)
.map_err(|_| Error::bad_database("PDU in db is invalid."))?)
})
}

/// Replace a PDU with the redacted form.
@@ -583,12 +654,15 @@ impl Rooms {
if let Some(pdu_id) = self.get_pdu_id(event_id)? {
let mut pdu = self
.get_pdu_from_id(&pdu_id)?
.ok_or(Error::BadDatabase("pduid points to invalid pdu"))?;
pdu.redact();
.ok_or_else(|| Error::bad_database("PDU ID points to invalid PDU."))?;
pdu.redact()?;
self.replace_pdu(&pdu_id, &pdu)?;
Ok(())
} else {
Err(Error::BadRequest("eventid does not exist"))
Err(Error::BadRequest(
ErrorKind::NotFound,
"Event ID does not exist.",
))
}
}

@@ -664,7 +738,10 @@ impl Rooms {
let room_id = self
.alias_roomid
.remove(alias.alias())?
.ok_or(Error::BadRequest("Alias does not exist"))?;
.ok_or(Error::BadRequest(
ErrorKind::NotFound,
"Alias does not exist.",
))?;

for key in self.aliasid_alias.scan_prefix(room_id).keys() {
self.aliasid_alias.remove(key?)?;
@@ -678,7 +755,12 @@ impl Rooms {
self.alias_roomid
.get(alias.alias())?
.map_or(Ok(None), |bytes| {
Ok(Some(RoomId::try_from(utils::string_from_bytes(&bytes)?)?))
Ok(Some(
RoomId::try_from(utils::string_from_bytes(&bytes).map_err(|_| {
Error::bad_database("Room ID in alias_roomid is invalid unicode.")
})?)
.map_err(|_| Error::bad_database("Room ID in alias_roomid is invalid."))?,
))
})
}

@@ -689,7 +771,10 @@ impl Rooms {
self.aliasid_alias
.scan_prefix(prefix)
.values()
.map(|bytes| Ok(RoomAliasId::try_from(utils::string_from_bytes(&bytes?)?)?))
.map(|bytes| {
Ok(serde_json::from_slice(&bytes?)
.map_err(|_| Error::bad_database("Alias in aliasid_alias is invalid."))?)
})
}

pub fn set_public(&self, room_id: &RoomId, public: bool) -> Result<()> {
@@ -707,54 +792,76 @@ impl Rooms {
}

pub fn public_rooms(&self) -> impl Iterator<Item = Result<RoomId>> {
self.publicroomids
.iter()
.keys()
.map(|bytes| Ok(RoomId::try_from(utils::string_from_bytes(&bytes?)?)?))
self.publicroomids.iter().keys().map(|bytes| {
Ok(
RoomId::try_from(utils::string_from_bytes(&bytes?).map_err(|_| {
Error::bad_database("Room ID in publicroomids is invalid unicode.")
})?)
.map_err(|_| Error::bad_database("Room ID in publicroomids is invalid."))?,
)
})
}

/// Returns an iterator over all rooms a user joined.
/// Returns an iterator over all joined members of a room.
pub fn room_members(&self, room_id: &RoomId) -> impl Iterator<Item = Result<UserId>> {
self.roomuserid_joined
.scan_prefix(room_id.to_string())
.values()
.keys()
.map(|key| {
Ok(UserId::try_from(&*utils::string_from_bytes(
&key?
.rsplit(|&b| b == 0xff)
.next()
.ok_or(Error::BadDatabase("userroomid is invalid"))?,
)?)?)
Ok(UserId::try_from(
utils::string_from_bytes(
&key?
.rsplit(|&b| b == 0xff)
.next()
.expect("rsplit always returns an element"),
)
.map_err(|_| {
Error::bad_database("User ID in roomuserid_joined is invalid unicode.")
})?,
)
.map_err(|_| Error::bad_database("User ID in roomuserid_joined is invalid."))?)
})
}

/// Returns an iterator over all rooms a user joined.
/// Returns an iterator over all invited members of a room.
pub fn room_members_invited(&self, room_id: &RoomId) -> impl Iterator<Item = Result<UserId>> {
self.roomuserid_invited
.scan_prefix(room_id.to_string())
.keys()
.map(|key| {
Ok(UserId::try_from(&*utils::string_from_bytes(
&key?
.rsplit(|&b| b == 0xff)
.next()
.ok_or(Error::BadDatabase("userroomid is invalid"))?,
)?)?)
Ok(UserId::try_from(
utils::string_from_bytes(
&key?
.rsplit(|&b| b == 0xff)
.next()
.expect("rsplit always returns an element"),
)
.map_err(|_| {
Error::bad_database("User ID in roomuserid_invited is invalid unicode.")
})?,
)
.map_err(|_| Error::bad_database("User ID in roomuserid_invited is invalid."))?)
})
}

/// Returns an iterator over all rooms a user joined.
/// Returns an iterator over all left members of a room.
pub fn rooms_joined(&self, user_id: &UserId) -> impl Iterator<Item = Result<RoomId>> {
self.userroomid_joined
.scan_prefix(user_id.to_string())
.keys()
.map(|key| {
Ok(RoomId::try_from(&*utils::string_from_bytes(
&key?
.rsplit(|&b| b == 0xff)
.next()
.ok_or(Error::BadDatabase("userroomid is invalid"))?,
)?)?)
Ok(RoomId::try_from(
utils::string_from_bytes(
&key?
.rsplit(|&b| b == 0xff)
.next()
.expect("rsplit always returns an element"),
)
.map_err(|_| {
Error::bad_database("Room ID in userroomid_joined is invalid unicode.")
})?,
)
.map_err(|_| Error::bad_database("Room ID in userroomid_joined is invalid."))?)
})
}

@@ -764,12 +871,18 @@ impl Rooms {
.scan_prefix(&user_id.to_string())
.keys()
.map(|key| {
Ok(RoomId::try_from(&*utils::string_from_bytes(
&key?
.rsplit(|&b| b == 0xff)
.next()
.ok_or(Error::BadDatabase("userroomid is invalid"))?,
)?)?)
Ok(RoomId::try_from(
utils::string_from_bytes(
&key?
.rsplit(|&b| b == 0xff)
.next()
.expect("rsplit always returns an element"),
)
.map_err(|_| {
Error::bad_database("Room ID in userroomid_invited is invalid unicode.")
})?,
)
.map_err(|_| Error::bad_database("Room ID in userroomid_invited is invalid."))?)
})
}

@@ -779,12 +892,18 @@ impl Rooms {
.scan_prefix(&user_id.to_string())
.keys()
.map(|key| {
Ok(RoomId::try_from(&*utils::string_from_bytes(
&key?
.rsplit(|&b| b == 0xff)
.next()
.ok_or(Error::BadDatabase("userroomid is invalid"))?,
)?)?)
Ok(RoomId::try_from(
utils::string_from_bytes(
&key?
.rsplit(|&b| b == 0xff)
.next()
.expect("rsplit always returns an element"),
)
.map_err(|_| {
Error::bad_database("Room ID in userroomid_left is invalid unicode.")
})?,
)
.map_err(|_| Error::bad_database("Room ID in userroomid_left is invalid."))?)
})
}



+ 45
- 18
src/database/rooms/edus.rs View File

@@ -33,7 +33,10 @@ impl RoomEdus {
.filter_map(|r| r.ok())
.take_while(|key| key.starts_with(&prefix))
.find(|key| {
key.rsplit(|&b| b == 0xff).next().unwrap() == user_id.to_string().as_bytes()
key.rsplit(|&b| b == 0xff)
.next()
.expect("rsplit always returns an element")
== user_id.to_string().as_bytes()
})
{
// This is the old room_latest
@@ -45,8 +48,10 @@ impl RoomEdus {
room_latest_id.push(0xff);
room_latest_id.extend_from_slice(&user_id.to_string().as_bytes());

self.roomlatestid_roomlatest
.insert(room_latest_id, &*serde_json::to_string(&event)?)?;
self.roomlatestid_roomlatest.insert(
room_latest_id,
&*serde_json::to_string(&event).expect("EduEvent::to_string always works"),
)?;

Ok(())
}
@@ -68,7 +73,11 @@ impl RoomEdus {
.range(&*first_possible_edu..)
.filter_map(|r| r.ok())
.take_while(move |(k, _)| k.starts_with(&prefix))
.map(|(_, v)| Ok(serde_json::from_slice(&v)?)))
.map(|(_, v)| {
Ok(serde_json::from_slice(&v).map_err(|_| {
Error::bad_database("Read receipt in roomlatestid_roomlatest is invalid.")
})?)
}))
}

/// Sets a user as typing until the timeout timestamp is reached or roomactive_remove is
@@ -152,17 +161,21 @@ impl RoomEdus {
.roomactiveid_userid
.scan_prefix(&prefix)
.keys()
.filter_map(|r| r.ok())
.take_while(|k| {
utils::u64_from_bytes(
k.split(|&c| c == 0xff)
.nth(1)
.expect("roomactive has valid timestamp and delimiters"),
) < current_timestamp
.map(|key| {
let key = key?;
Ok::<_, Error>((
key.clone(),
utils::u64_from_bytes(key.split(|&b| b == 0xff).nth(1).ok_or_else(|| {
Error::bad_database("RoomActive has invalid timestamp or delimiters.")
})?)
.map_err(|_| Error::bad_database("RoomActive has invalid timestamp bytes."))?,
))
})
.filter_map(|r| r.ok())
.take_while(|&(_, timestamp)| timestamp < current_timestamp)
{
// This is an outdated edu (time > timestamp)
self.roomactiveid_userid.remove(outdated_edu)?;
self.roomactiveid_userid.remove(outdated_edu.0)?;
found_outdated = true;
}

@@ -187,7 +200,11 @@ impl RoomEdus {
Ok(self
.roomid_lastroomactiveupdate
.get(&room_id.to_string().as_bytes())?
.map(|bytes| utils::u64_from_bytes(&bytes))
.map_or(Ok::<_, Error>(None), |bytes| {
Ok(Some(utils::u64_from_bytes(&bytes).map_err(|_| {
Error::bad_database("Count in roomid_lastroomactiveupdate is invalid.")
})?))
})?
.unwrap_or(0))
}

@@ -202,7 +219,16 @@ impl RoomEdus {
.roomactiveid_userid
.scan_prefix(prefix)
.values()
.map(|user_id| Ok::<_, Error>(UserId::try_from(utils::string_from_bytes(&user_id?)?)?))
.map(|user_id| {
Ok::<_, Error>(
UserId::try_from(utils::string_from_bytes(&user_id?).map_err(|_| {
Error::bad_database("User ID in roomactiveid_userid is invalid unicode.")
})?)
.map_err(|_| {
Error::bad_database("User ID in roomactiveid_userid is invalid.")
})?,
)
})
{
user_ids.push(user_id?);
}
@@ -230,9 +256,10 @@ impl RoomEdus {
key.push(0xff);
key.extend_from_slice(&user_id.to_string().as_bytes());

Ok(self
.roomuserid_lastread
.get(key)?
.map(|v| utils::u64_from_bytes(&v)))
self.roomuserid_lastread.get(key)?.map_or(Ok(None), |v| {
Ok(Some(utils::u64_from_bytes(&v).map_err(|_| {
Error::bad_database("Invalid private read marker bytes")
})?))
})
}
}

+ 57
- 13
src/database/uiaa.rs View File

@@ -43,15 +43,51 @@ impl Uiaa {
// Find out what the user completed
match &**kind {
"m.login.password" => {
if auth_parameters["identifier"]["type"] != "m.id.user" {
panic!("identifier not supported");
let identifier = auth_parameters.get("identifier").ok_or(Error::BadRequest(
ErrorKind::MissingParam,
"m.login.password needs identifier.",
))?;

let identifier_type = identifier.get("type").ok_or(Error::BadRequest(
ErrorKind::MissingParam,
"Identifier needs a type.",
))?;

if identifier_type != "m.id.user" {
return Err(Error::BadRequest(
ErrorKind::Unrecognized,
"Identifier type not recognized.",
));
}

let user_id = UserId::parse_with_server_name(
auth_parameters["identifier"]["user"].as_str().unwrap(),
globals.server_name(),
)?;
let password = auth_parameters["password"].as_str().unwrap();
let username = identifier
.get("user")
.ok_or(Error::BadRequest(
ErrorKind::MissingParam,
"Identifier needs user field.",
))?
.as_str()
.ok_or(Error::BadRequest(
ErrorKind::BadJson,
"User is not a string.",
))?;

let user_id = UserId::parse_with_server_name(username, globals.server_name())
.map_err(|_| {
Error::BadRequest(ErrorKind::InvalidParam, "User ID is invalid.")
})?;

let password = auth_parameters
.get("password")
.ok_or(Error::BadRequest(
ErrorKind::MissingParam,
"Password is missing.",
))?
.as_str()
.ok_or(Error::BadRequest(
ErrorKind::BadJson,
"Password is not a string.",
))?;

// Check if password is correct
if let Some(hash) = users.password_hash(&user_id)? {
@@ -59,7 +95,6 @@ impl Uiaa {
argon2::verify_encoded(&hash, password.as_bytes()).unwrap_or(false);

if !hash_matches {
debug!("Invalid password.");
uiaainfo.auth_error = Some(ruma::api::client::error::ErrorBody {
kind: ErrorKind::Forbidden,
message: "Invalid username or password.".to_owned(),
@@ -113,8 +148,10 @@ impl Uiaa {
userdeviceid.extend_from_slice(device_id.as_bytes());

if let Some(uiaainfo) = uiaainfo {
self.userdeviceid_uiaainfo
.insert(&userdeviceid, &*serde_json::to_string(&uiaainfo)?)?;
self.userdeviceid_uiaainfo.insert(
&userdeviceid,
&*serde_json::to_string(&uiaainfo).expect("UiaaInfo::to_string always works"),
)?;
} else {
self.userdeviceid_uiaainfo.remove(&userdeviceid)?;
}
@@ -136,8 +173,12 @@ impl Uiaa {
&self
.userdeviceid_uiaainfo
.get(&userdeviceid)?
.ok_or(Error::BadRequest("session does not exist"))?,
)?;
.ok_or(Error::BadRequest(
ErrorKind::Forbidden,
"UIAA session does not exist.",
))?,
)
.map_err(|_| Error::bad_database("UiaaInfo in userdeviceid_uiaainfo is invalid."))?;

if uiaainfo
.session
@@ -145,7 +186,10 @@ impl Uiaa {
.filter(|&s| s == session)
.is_none()
{
return Err(Error::BadRequest("wrong session token"));
return Err(Error::BadRequest(
ErrorKind::Forbidden,
"UIAA session token invalid.",
));
}

Ok(uiaainfo)


+ 125
- 61
src/database/users.rs View File

@@ -43,24 +43,36 @@ impl Users {
.get(token)?
.map_or(Ok(None), |bytes| {
let mut parts = bytes.split(|&b| b == 0xff);
let user_bytes = parts
.next()
.ok_or(Error::BadDatabase("token_userdeviceid value invalid"))?;
let device_bytes = parts
.next()
.ok_or(Error::BadDatabase("token_userdeviceid value invalid"))?;
let user_bytes = parts.next().ok_or_else(|| {
Error::bad_database("User ID in token_userdeviceid is invalid.")
})?;
let device_bytes = parts.next().ok_or_else(|| {
Error::bad_database("Device ID in token_userdeviceid is invalid.")
})?;

Ok(Some((
UserId::try_from(utils::string_from_bytes(&user_bytes)?)?,
utils::string_from_bytes(&device_bytes)?,
UserId::try_from(utils::string_from_bytes(&user_bytes).map_err(|_| {
Error::bad_database("User ID in token_userdeviceid is invalid unicode.")
})?)
.map_err(|_| {
Error::bad_database("User ID in token_userdeviceid is invalid.")
})?,
utils::string_from_bytes(&device_bytes).map_err(|_| {
Error::bad_database("Device ID in token_userdeviceid is invalid.")
})?,
)))
})
}

/// Returns an iterator over all users on this homeserver.
pub fn iter(&self) -> impl Iterator<Item = Result<UserId>> {
self.userid_password.iter().keys().map(|r| {
utils::string_from_bytes(&r?).and_then(|string| Ok(UserId::try_from(&*string)?))
self.userid_password.iter().keys().map(|bytes| {
Ok(
UserId::try_from(utils::string_from_bytes(&bytes?).map_err(|_| {
Error::bad_database("User ID in userid_password is invalid unicode.")
})?)
.map_err(|_| Error::bad_database("User ID in userid_password is invalid."))?,
)
})
}

@@ -68,14 +80,22 @@ impl Users {
pub fn password_hash(&self, user_id: &UserId) -> Result<Option<String>> {
self.userid_password
.get(user_id.to_string())?
.map_or(Ok(None), |bytes| utils::string_from_bytes(&bytes).map(Some))
.map_or(Ok(None), |bytes| {
Ok(Some(utils::string_from_bytes(&bytes).map_err(|_| {
Error::bad_database("Password hash in db is not valid string.")
})?))
})
}

/// Returns the displayname of a user on this homeserver.
pub fn displayname(&self, user_id: &UserId) -> Result<Option<String>> {
self.userid_displayname
.get(user_id.to_string())?
.map_or(Ok(None), |bytes| utils::string_from_bytes(&bytes).map(Some))
.map_or(Ok(None), |bytes| {
Ok(Some(utils::string_from_bytes(&bytes).map_err(|_| {
Error::bad_database("Displayname in db is invalid.")
})?))
})
}

/// Sets a new displayname or removes it if displayname is None. You still need to nofify all rooms of this change.
@@ -94,7 +114,11 @@ impl Users {
pub fn avatar_url(&self, user_id: &UserId) -> Result<Option<String>> {
self.userid_avatarurl
.get(user_id.to_string())?
.map_or(Ok(None), |bytes| utils::string_from_bytes(&bytes).map(Some))
.map_or(Ok(None), |bytes| {
Ok(Some(utils::string_from_bytes(&bytes).map_err(|_| {
Error::bad_database("Avatar URL in db is invalid.")
})?))
})
}

/// Sets a new avatar_url or removes it if avatar_url is None.
@@ -117,11 +141,8 @@ impl Users {
token: &str,
initial_device_display_name: Option<String>,
) -> Result<()> {
if !self.exists(user_id)? {
return Err(Error::BadRequest(
"tried to create device for nonexistent user",
));
}
// This method should never be called for nonexistent users.
assert!(self.exists(user_id)?);

let mut userdeviceid = user_id.to_string().as_bytes().to_vec();
userdeviceid.push(0xff);
@@ -134,7 +155,8 @@ impl Users {
display_name: initial_device_display_name,
last_seen_ip: None, // TODO
last_seen_ts: Some(SystemTime::now()),
})?
})
.expect("Device::to_string never fails.")
.as_bytes(),
)?;

@@ -185,23 +207,22 @@ impl Users {
&*bytes?
.rsplit(|&b| b == 0xff)
.next()
.ok_or(Error::BadDatabase("userdeviceid is invalid"))?,
)?)
.ok_or_else(|| Error::bad_database("UserDevice ID in db is invalid."))?,
)
.map_err(|_| {
Error::bad_database("Device ID in userdeviceid_metadata is invalid.")
})?)
})
}

/// Replaces the access token of one device.
pub fn set_token(&self, user_id: &UserId, device_id: &str, token: &str) -> Result<()> {
fn set_token(&self, user_id: &UserId, device_id: &str, token: &str) -> Result<()> {
let mut userdeviceid = user_id.to_string().as_bytes().to_vec();
userdeviceid.push(0xff);
userdeviceid.extend_from_slice(device_id.as_bytes());

// All devices have metadata
if self.userdeviceid_metadata.get(&userdeviceid)?.is_none() {
return Err(Error::BadRequest(
"Tried to set token for nonexistent device",
));
}
assert!(self.userdeviceid_metadata.get(&userdeviceid)?.is_some());

// Remove old token
if let Some(old_token) = self.userdeviceid_token.get(&userdeviceid)? {
@@ -228,19 +249,23 @@ impl Users {
key.extend_from_slice(device_id.as_bytes());

// All devices have metadata
if self.userdeviceid_metadata.get(&key)?.is_none() {
return Err(Error::BadRequest(
"Tried to set token for nonexistent device",
));
}
// Only existing devices should be able to call this.
assert!(self.userdeviceid_metadata.get(&key)?.is_some());

key.push(0xff);
// TODO: Use AlgorithmAndDeviceId::to_string when it's available (and update everything,
// because there are no wrapping quotation marks anymore)
key.extend_from_slice(&serde_json::to_string(one_time_key_key)?.as_bytes());

self.onetimekeyid_onetimekeys
.insert(&key, &*serde_json::to_string(&one_time_key_value)?)?;
key.extend_from_slice(
&serde_json::to_string(one_time_key_key)
.expect("AlgorithmAndDeviceId::to_string always works")
.as_bytes(),
);

self.onetimekeyid_onetimekeys.insert(
&key,
&*serde_json::to_string(&one_time_key_value)
.expect("OneTimeKey::to_string always works"),
)?;

Ok(())
}
@@ -271,9 +296,11 @@ impl Users {
&*key
.rsplit(|&b| b == 0xff)
.next()
.ok_or(Error::BadDatabase("onetimekeyid is invalid"))?,
)?,
serde_json::from_slice(&*value)?,
.ok_or_else(|| Error::bad_database("OneTimeKeyId in db is invalid."))?,
)
.map_err(|_| Error::bad_database("OneTimeKeyId in db is invalid."))?,
serde_json::from_slice(&*value)
.map_err(|_| Error::bad_database("OneTimeKeys in db are invalid."))?,
))
})
.transpose()
@@ -297,11 +324,11 @@ impl Users {
.map(|bytes| {
Ok::<_, Error>(
serde_json::from_slice::<AlgorithmAndDeviceId>(
&*bytes?
.rsplit(|&b| b == 0xff)
.next()
.ok_or(Error::BadDatabase("onetimekeyid is invalid"))?,
)?
&*bytes?.rsplit(|&b| b == 0xff).next().ok_or_else(|| {
Error::bad_database("OneTimeKey ID in db is invalid.")
})?,
)
.map_err(|_| Error::bad_database("AlgorithmAndDeviceID in db is invalid."))?
.0,
)
})
@@ -323,8 +350,10 @@ impl Users {
userdeviceid.push(0xff);
userdeviceid.extend_from_slice(device_id.as_bytes());

self.userdeviceid_devicekeys
.insert(&userdeviceid, &*serde_json::to_string(&device_keys)?)?;
self.userdeviceid_devicekeys.insert(
&userdeviceid,
&*serde_json::to_string(&device_keys).expect("DeviceKeys::to_string always works"),
)?;

self.devicekeychangeid_userid
.insert(globals.next_count()?.to_be_bytes(), &*user_id.to_string())?;
@@ -344,14 +373,28 @@ impl Users {
self.userdeviceid_devicekeys
.scan_prefix(key)
.values()
.map(|bytes| Ok(serde_json::from_slice(&bytes?)?))
.map(|bytes| {
Ok(serde_json::from_slice(&bytes?)
.map_err(|_| Error::bad_database("DeviceKeys in db are invalid."))?)
})
}

pub fn device_keys_changed(&self, since: u64) -> impl Iterator<Item = Result<UserId>> {
self.devicekeychangeid_userid
.range(since.to_be_bytes()..)
.values()
.map(|bytes| Ok(UserId::try_from(utils::string_from_bytes(&bytes?)?)?))
.map(|bytes| {
Ok(
UserId::try_from(utils::string_from_bytes(&bytes?).map_err(|_| {
Error::bad_database(
"User ID in devicekeychangeid_userid is invalid unicode.",
)
})?)
.map_err(|_| {
Error::bad_database("User ID in devicekeychangeid_userid is invalid.")
})?,
)
})
}

pub fn all_device_keys(
@@ -366,9 +409,14 @@ impl Users {
let userdeviceid = utils::string_from_bytes(
key.rsplit(|&b| b == 0xff)
.next()
.ok_or(Error::BadDatabase("userdeviceid is invalid"))?,
)?;
Ok((userdeviceid, serde_json::from_slice(&*value)?))
.ok_or_else(|| Error::bad_database("UserDeviceID in db is invalid."))?,
)
.map_err(|_| Error::bad_database("UserDeviceId in db is invalid."))?;
Ok((
userdeviceid,
serde_json::from_slice(&*value)
.map_err(|_| Error::bad_database("DeviceKeys in db are invalid."))?,
))
})
}

@@ -392,8 +440,10 @@ impl Users {
json.insert("sender".to_owned(), sender.to_string().into());
json.insert("content".to_owned(), content);

self.todeviceid_events
.insert(&key, &*serde_json::to_string(&json)?)?;
self.todeviceid_events.insert(
&key,
&*serde_json::to_string(&json).expect("Map::to_string always works"),
)?;

Ok(())
}
@@ -413,7 +463,10 @@ impl Users {

for result in self.todeviceid_events.scan_prefix(&prefix).take(max) {
let (key, value) = result?;
events.push(serde_json::from_slice(&*value)?);
events.push(
serde_json::from_slice(&*value)
.map_err(|_| Error::bad_database("Event in todeviceid_events is invalid."))?,
);
self.todeviceid_events.remove(key)?;
}

@@ -430,12 +483,15 @@ impl Users {
userdeviceid.push(0xff);
userdeviceid.extend_from_slice(device_id.as_bytes());

if self.userdeviceid_metadata.get(&userdeviceid)?.is_none() {
return Err(Error::BadRequest("device does not exist"));
}
// Only existing devices should be able to call this.
assert!(self.userdeviceid_metadata.get(&userdeviceid)?.is_some());

self.userdeviceid_metadata
.insert(userdeviceid, serde_json::to_string(device)?.as_bytes())?;
self.userdeviceid_metadata.insert(
userdeviceid,
serde_json::to_string(device)
.expect("Device::to_string always works")
.as_bytes(),
)?;

Ok(())
}
@@ -448,7 +504,11 @@ impl Users {

self.userdeviceid_metadata
.get(&userdeviceid)?
.map_or(Ok(None), |bytes| Ok(Some(serde_json::from_slice(&bytes)?)))
.map_or(Ok(None), |bytes| {
Ok(Some(serde_json::from_slice(&bytes).map_err(|_| {
Error::bad_database("Metadata in userdeviceid_metadata is invalid.")
})?))
})
}

pub fn all_devices_metadata(&self, user_id: &UserId) -> impl Iterator<Item = Result<Device>> {
@@ -458,6 +518,10 @@ impl Users {
self.userdeviceid_metadata
.scan_prefix(key)
.values()
.map(|bytes| Ok(serde_json::from_slice::<Device>(&bytes?)?))
.map(|bytes| {
Ok(serde_json::from_slice::<Device>(&bytes?).map_err(|_| {
Error::bad_database("Device in userdeviceid_metadata is invalid.")
})?)
})
}
}

+ 72
- 25
src/error.rs View File

@@ -1,41 +1,88 @@
use crate::RumaResponse;
use http::StatusCode;
use log::error;
use rocket::{
response::{self, Responder},
Request,
};
use ruma::api::client::{
error::{Error as RumaError, ErrorKind},
r0::uiaa::{UiaaInfo, UiaaResponse},
};
use thiserror::Error;

pub type Result<T> = std::result::Result<T, Error>;

#[derive(Error, Debug)]
pub enum Error {
#[error("problem with the database")]
#[error("There was a problem with the connection to the database.")]
SledError {
#[from]
source: sled::Error,
},
#[error("tried to parse invalid string")]
StringFromBytesError {
#[from]
source: std::string::FromUtf8Error,
},
#[error("tried to parse invalid identifier")]
SerdeJsonError {
#[from]
source: serde_json::Error,
},
#[error("tried to parse invalid identifier")]
RumaIdentifierError {
#[from]
source: ruma::identifiers::Error,
},
#[error("tried to parse invalid event")]
RumaEventError {
#[from]
source: ruma::events::InvalidEvent,
},
#[error("could not generate image")]
#[error("Could not generate an image.")]
ImageError {
#[from]
source: image::error::ImageError,
},
#[error("bad request")]
BadRequest(&'static str),
#[error("problem in that database")]
#[error("{0}")]
BadConfig(&'static str),
#[error("{0}")]
/// Don't create this directly. Use Error::bad_database instead.
BadDatabase(&'static str),
#[error("uiaa")]
Uiaa(UiaaInfo),

#[error("{0}: {1}")]
BadRequest(ErrorKind, &'static str),
#[error("{0}")]
Conflict(&'static str), // This is only needed for when a room alias already exists
}

impl Error {
pub fn bad_database(message: &'static str) -> Self {
error!("BadDatabase: {}", message);
Self::BadDatabase(message)
}
}

#[rocket::async_trait]
impl<'r> Responder<'r> for Error {
async fn respond_to(self, r: &'r Request<'_>) -> response::Result<'r> {
if let Self::Uiaa(uiaainfo) = &self {
return RumaResponse::from(UiaaResponse::AuthResponse(uiaainfo.clone()))
.respond_to(r)
.await;
}

let message = format!("{}", self);

use ErrorKind::*;
let (kind, status_code) = match self {
Self::BadRequest(kind, _) => (
kind,
match kind {
Forbidden | GuestAccessForbidden | ThreepidAuthFailed | ThreepidDenied => {
StatusCode::FORBIDDEN
}
Unauthorized | UnknownToken | MissingToken => StatusCode::UNAUTHORIZED,
NotFound => StatusCode::NOT_FOUND,
LimitExceeded => StatusCode::TOO_MANY_REQUESTS,
UserDeactivated => StatusCode::FORBIDDEN,
TooLarge => StatusCode::PAYLOAD_TOO_LARGE,
_ => StatusCode::BAD_REQUEST,
},
),
Self::Conflict(_) => (Unknown, StatusCode::CONFLICT),
_ => (Unknown, StatusCode::INTERNAL_SERVER_ERROR),
};

RumaResponse::from(RumaError {
kind,
message,
status_code,
})
.respond_to(r)
.await
}
}

+ 2
- 2
src/main.rs View File

@@ -12,7 +12,7 @@ mod utils;
pub use database::Database;
pub use error::{Error, Result};
pub use pdu::PduEvent;
pub use ruma_wrapper::{MatrixResult, Ruma};
pub use ruma_wrapper::{ConduitResult, Ruma, RumaResponse};

use rocket::{fairing::AdHoc, routes};

@@ -95,7 +95,7 @@ fn setup_rocket() -> rocket::Rocket {
],
)
.attach(AdHoc::on_attach("Config", |rocket| {
let data = Database::load_or_create(&rocket.config());
let data = Database::load_or_create(&rocket.config()).expect("valid config");

Ok(rocket.manage(data))
}))


+ 18
- 11
src/pdu.rs View File

@@ -1,3 +1,4 @@
use crate::{Error, Result};
use js_int::UInt;
use ruma::{
api::federation::pdu::EventHash,
@@ -36,7 +37,7 @@ pub struct PduEvent {
}

impl PduEvent {
pub fn redact(&mut self) {
pub fn redact(&mut self) -> Result<()> {
self.unsigned.clear();
let allowed = match self.kind {
EventType::RoomMember => vec!["membership"],
@@ -56,7 +57,11 @@ impl PduEvent {
_ => vec![],
};

let old_content = self.content.as_object_mut().unwrap(); // TODO error
let old_content = self
.content
.as_object_mut()
.ok_or_else(|| Error::bad_database("PDU in db has invalid content."))?;

let mut new_content = serde_json::Map::new();

for key in allowed {
@@ -71,21 +76,23 @@ impl PduEvent {
);

self.content = new_content.into();

Ok(())
}

pub fn to_room_event(&self) -> EventJson<RoomEvent> {
// Can only fail in rare circumstances that won't ever happen here, see
// https://docs.rs/serde_json/1.0.50/serde_json/fn.to_string.html
let json = serde_json::to_string(&self).unwrap();
// EventJson's deserialize implementation always returns `Ok(...)`
serde_json::from_str::<EventJson<RoomEvent>>(&json).unwrap()
let json = serde_json::to_string(&self).expect("PDUs are always valid");
serde_json::from_str::<EventJson<RoomEvent>>(&json)
.expect("EventJson::from_str always works")
}
pub fn to_state_event(&self) -> EventJson<StateEvent> {
let json = serde_json::to_string(&self).unwrap();
serde_json::from_str::<EventJson<StateEvent>>(&json).unwrap()
let json = serde_json::to_string(&self).expect("PDUs are always valid");
serde_json::from_str::<EventJson<StateEvent>>(&json)
.expect("EventJson::from_str always works")
}
pub fn to_stripped_state_event(&self) -> EventJson<AnyStrippedStateEvent> {
let json = serde_json::to_string(&self).unwrap();
serde_json::from_str::<EventJson<AnyStrippedStateEvent>>(&json).unwrap()
let json = serde_json::to_string(&self).expect("PDUs are always valid");
serde_json::from_str::<EventJson<AnyStrippedStateEvent>>(&json)
.expect("EventJson::from_str always works")
}
}

+ 12
- 25
src/ruma_wrapper.rs View File

@@ -1,4 +1,4 @@
use crate::utils;
use crate::{utils, Error};
use log::warn;