Skip to content

Commit

Permalink
Merge pull request #72 from uhoreg/request_transform
Browse files Browse the repository at this point in the history
Correctly encode Duration and optional fields
  • Loading branch information
Hywan authored Jan 22, 2024
2 parents 9606daa + 8cff979 commit 33d531a
Show file tree
Hide file tree
Showing 6 changed files with 147 additions and 9 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
# matrix-sdk-crypto-wasm v4.0.0

- Properly encode missing and `Duration` parameters in requests.
([#72](https://github.com/matrix-org/matrix-rust-sdk-crypto-wasm/pull/72))

**BREAKING CHANGES**

- Rename `OlmMachine.init_from_store` introduced in v3.6.0 to
Expand Down
31 changes: 31 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ tracing-subscriber = { version = "0.3.14", default-features = false, features =
wasm-bindgen = "0.2.89"
wasm-bindgen-futures = "0.4.33"
zeroize = "1.6.0"
wasm-bindgen-test = "0.3.37"

[build-dependencies]
vergen = { version = "8.0.0", features = ["build", "git", "gitcl"] }
Expand Down
2 changes: 1 addition & 1 deletion package.json
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@
"lint:types": "tsc --noEmit",
"build": "WASM_PACK_ARGS=--release ./scripts/build.sh",
"build:dev": "WASM_PACK_ARGS=--dev ./scripts/build.sh",
"test": "jest --verbose",
"test": "jest --verbose && yarn run wasm-pack test --node",
"doc": "typedoc --tsconfig .",
"prepack": "npm run build && npm run test"
}
Expand Down
116 changes: 109 additions & 7 deletions src/requests.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
//! Types to handle requests.

use std::time::Duration;

use js_sys::JsString;
use matrix_sdk_common::ruma::{
api::client::keys::{
Expand All @@ -8,6 +10,7 @@ use matrix_sdk_common::ruma::{
upload_signatures::v3::Request as OriginalSignatureUploadRequest,
},
events::EventContent,
exports::serde::ser::Error,
};
use matrix_sdk_crypto::{
requests::{
Expand Down Expand Up @@ -316,7 +319,7 @@ macro_rules! request {
(
$destination_request:ident from $source_request:ident
$( extracts $( $field_name:ident : $field_type:tt ),+ $(,)? )?
$( $( and )? groups $( $grouped_field_name:ident ),+ $(,)? )?
$( $( and )? groups $( $grouped_field_name:ident $( { $transformation:expr } )? $( $optional:literal )? ),+ $(,)? )?
) => {

impl TryFrom<(String, &$source_request)> for $destination_request {
Expand All @@ -329,7 +332,7 @@ macro_rules! request {
@__try_from $destination_request from $source_request
(request_id = request_id.into(), request = request)
$( extracts [ $( $field_name : $field_type, )+ ] )?
$( groups [ $( $grouped_field_name, )+ ] )?
$( groups [ $( $grouped_field_name $( { $transformation } )? $( $optional )? , )+ ] )?
)
}
}
Expand All @@ -339,7 +342,7 @@ macro_rules! request {
@__try_from $destination_request:ident from $source_request:ident
(request_id = $request_id:expr, request = $request:expr)
$( extracts [ $( $field_name:ident : $field_type:tt ),* $(,)? ] )?
$( groups [ $( $grouped_field_name:ident ),* $(,)? ] )?
$( groups [ $( $grouped_field_name:ident $( { $transformation:expr } )? $( $optional:literal )? ),* $(,)? ] )?
) => {
{
Ok($destination_request {
Expand All @@ -353,7 +356,15 @@ macro_rules! request {
body: {
let mut map = serde_json::Map::new();
$(
map.insert(stringify!($grouped_field_name).to_owned(), serde_json::to_value(&$request.$grouped_field_name).unwrap());
let field = &$request.$grouped_field_name;
$(
let field = {
let $grouped_field_name = field;

$transformation
};
)?
request!(@__set_field $( $optional )? map : $grouped_field_name = field);
)*
let object = serde_json::Value::Object(map);

Expand All @@ -379,15 +390,25 @@ macro_rules! request {
( @__field_type as event_type ; request = $request:expr, field_name = $field_name:ident ) => {
$request.content.event_type().to_string().into()
};

( @__set_field $optional:literal $map:ident : $grouped_field_name:ident = $field:ident) => {
if let Some($field) = $field {
request!(@__set_field $map : $grouped_field_name = $field);
}
};

( @__set_field $map:ident : $grouped_field_name:ident = $field:ident) => {
$map.insert(stringify!($grouped_field_name).to_owned(), serde_json::to_value($field).unwrap());
};
}

// Generate the methods needed to convert rust `OutgoingRequests` into the js
// counterpart. Technically it's converting tuples `(String, &Original)`, where
// the first member is the request ID, into js requests. Used by
// `TryFrom<OutgoingRequest> for JsValue`.
request!(KeysUploadRequest from OriginalKeysUploadRequest groups device_keys, one_time_keys, fallback_keys);
request!(KeysQueryRequest from OriginalKeysQueryRequest groups timeout, device_keys);
request!(KeysClaimRequest from OriginalKeysClaimRequest groups timeout, one_time_keys);
request!(KeysUploadRequest from OriginalKeysUploadRequest groups device_keys "optional", one_time_keys, fallback_keys);
request!(KeysQueryRequest from OriginalKeysQueryRequest groups timeout { timeout.as_ref().map(Duration::as_millis).map(u64::try_from).transpose().map_err(serde_json::Error::custom)? } "optional", device_keys);
request!(KeysClaimRequest from OriginalKeysClaimRequest groups timeout { timeout.as_ref().map(Duration::as_millis).map(u64::try_from).transpose().map_err(serde_json::Error::custom)? } "optional", one_time_keys);
request!(ToDeviceRequest from OriginalToDeviceRequest extracts event_type: string, txn_id: string and groups messages);
request!(RoomMessageRequest from OriginalRoomMessageRequest extracts room_id: string, txn_id: string, event_type: event_type, content: json);
request!(KeysBackupRequest from OriginalKeysBackupRequest extracts version: string and groups rooms);
Expand Down Expand Up @@ -619,3 +640,84 @@ impl TryFrom<OriginalCrossSigningBootstrapRequests> for CrossSigningBootstrapReq
})
}
}

#[cfg(test)]
pub(crate) mod tests {
use std::collections::BTreeMap;

use matrix_sdk_common::ruma::{
api::client::keys::{
claim_keys::v3::Request as OriginalKeysClaimRequest,
upload_keys::v3::Request as OriginalKeysUploadRequest,
},
device_id, user_id, DeviceKeyAlgorithm,
};
use matrix_sdk_crypto::requests::KeysQueryRequest as OriginalKeysQueryRequest;
use serde_json::Value;
use wasm_bindgen_test::wasm_bindgen_test;

use super::{KeysClaimRequest, KeysQueryRequest, KeysUploadRequest};

#[wasm_bindgen_test]
// make sure that the timeout in a /keys/claim request is encoded as a number
fn test_keys_claim_request_with_timeout() {
let rust_request = OriginalKeysClaimRequest::new(BTreeMap::from([(
user_id!("@alice:localhost").to_owned(),
BTreeMap::from([(
device_id!("ABCDEFG").to_owned(),
DeviceKeyAlgorithm::SignedCurve25519,
)]),
)]));
let request = KeysClaimRequest::try_from(("ID".to_string(), &rust_request)).unwrap();
let body: Value = serde_json::from_str(&String::from(request.body)).unwrap();
assert!(body.as_object().unwrap().contains_key("timeout"));
assert!(body["timeout"].is_number());
}

#[wasm_bindgen_test]
// if a /keys/claim request has no timeout, make sure it isn't in the request
fn test_keys_claim_request_without_timeout() {
let mut rust_request = OriginalKeysClaimRequest::new(BTreeMap::from([(
user_id!("@alice:localhost").to_owned(),
BTreeMap::from([(
device_id!("ABCDEFG").to_owned(),
DeviceKeyAlgorithm::SignedCurve25519,
)]),
)]));
rust_request.timeout = None;
let request = KeysClaimRequest::try_from(("ID".to_string(), &rust_request)).unwrap();
let body: Value = serde_json::from_str(&String::from(request.body)).unwrap();
assert!(!body.as_object().unwrap().contains_key("timeout"));
}

#[wasm_bindgen_test]
// make sure that the timeout is encoded as a number in a /keys/query
fn test_keys_query_request_with_timeout() {
let rust_request = OriginalKeysQueryRequest {
timeout: Some(std::time::Duration::from_secs(10)),
device_keys: BTreeMap::new(),
};
let request = KeysQueryRequest::try_from(("ID".to_string(), &rust_request)).unwrap();
let body: Value = serde_json::from_str(&String::from(request.body)).unwrap();
assert!(body.as_object().unwrap().contains_key("timeout"));
assert!(body["timeout"].is_number());
}

#[wasm_bindgen_test]
// if a /keys/query request has no timeout, make sure it isn't in the request
fn test_keys_query_request_without_timeout() {
let rust_request = OriginalKeysQueryRequest { timeout: None, device_keys: BTreeMap::new() };
let request = KeysQueryRequest::try_from(("ID".to_string(), &rust_request)).unwrap();
let body: Value = serde_json::from_str(&String::from(request.body)).unwrap();
assert!(!body.as_object().unwrap().contains_key("timeout"));
}

#[wasm_bindgen_test]
// if a /keys/upload request no device_keys, make sure it isn't in the request
fn test_keys_upload_request_without_devices() {
let request = OriginalKeysUploadRequest::new();
let request = KeysUploadRequest::try_from(("ID".to_string(), &request)).unwrap();
let body: Value = serde_json::from_str(&String::from(request.body)).unwrap();
assert!(!body.as_object().unwrap().contains_key("device_keys"));
}
}
3 changes: 2 additions & 1 deletion tests/machine.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -307,7 +307,8 @@ describe(OlmMachine.name, () => {
expect(outgoingRequests[1].body).toBeDefined();

const body = JSON.parse(outgoingRequests[1].body);
expect(body.timeout).toBeDefined();
// default timeout in Rust is None, so timeout will be omitted
expect(body.timeout).not.toBeDefined();
expect(body.device_keys).toBeDefined();
}
});
Expand Down

0 comments on commit 33d531a

Please sign in to comment.