Skip to content

Commit

Permalink
Fix report keys import progress
Browse files Browse the repository at this point in the history
  • Loading branch information
BillCarsonFr committed Mar 25, 2024
1 parent d5a35f8 commit d75b11c
Show file tree
Hide file tree
Showing 3 changed files with 104 additions and 16 deletions.
64 changes: 56 additions & 8 deletions spec/integ/crypto/megolm-backup.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ import { KeyBackupInfo, KeyBackupSession } from "../../../src/crypto-api/keyback
import { IKeyBackup } from "../../../src/crypto/backup";
import { flushPromises } from "../../test-utils/flushPromises";
import { defer, IDeferred } from "../../../src/utils";
import { ImportRoomKeysOpts } from "../../../src/crypto-api";

const ROOM_ID = testData.TEST_ROOM_ID;

Expand Down Expand Up @@ -298,6 +299,7 @@ describe.each(Object.entries(CRYPTO_BACKENDS))("megolm-keys backup (%s)", (backe

describe("recover from backup", () => {
let aliceCrypto: CryptoApi;
let importMockImpl: jest.Mock;

beforeEach(async () => {
fetchMock.get("path:/_matrix/client/v3/room_keys/version", testData.SIGNED_BACKUP_DATA);
Expand All @@ -309,6 +311,20 @@ describe.each(Object.entries(CRYPTO_BACKENDS))("megolm-keys backup (%s)", (backe
// tell Alice to trust the dummy device that signed the backup
await waitForDeviceList();
await aliceCrypto.setDeviceVerified(testData.TEST_USER_ID, testData.TEST_DEVICE_ID);

importMockImpl = jest.fn().mockImplementation((keys: IMegolmSessionData[], opts?: ImportRoomKeysOpts) => {
// need to report progress
if (opts?.progressCallback) {
opts.progressCallback({
stage: "load_keys",
successes: keys.length,
failures: 0,
total: keys.length,
});
}
});
// @ts-ignore - mock a private method for testing purpose
aliceCrypto.importBackedUpRoomKeys = importMockImpl;
});

it("can restore from backup (Curve25519 version)", async function () {
Expand Down Expand Up @@ -384,10 +400,6 @@ describe.each(Object.entries(CRYPTO_BACKENDS))("megolm-keys backup (%s)", (backe
}

it("Should import full backup in chunks", async function () {
const importMockImpl = jest.fn();
// @ts-ignore - mock a private method for testing purpose
aliceCrypto.importBackedUpRoomKeys = importMockImpl;

// We need several rooms with several sessions to test chunking
const { response, expectedTotal } = createBackupDownloadResponse([45, 300, 345, 12, 130]);

Expand Down Expand Up @@ -446,7 +458,7 @@ describe.each(Object.entries(CRYPTO_BACKENDS))("megolm-keys backup (%s)", (backe
throw new Error("test error");
})
// Ok for other chunks
.mockResolvedValue(undefined);
.mockImplementation(importMockImpl);

const { response, expectedTotal } = createBackupDownloadResponse([100, 300]);

Expand Down Expand Up @@ -485,9 +497,6 @@ describe.each(Object.entries(CRYPTO_BACKENDS))("megolm-keys backup (%s)", (backe
});

it("Should continue if some keys fails to decrypt", async function () {
// @ts-ignore - mock a private method for testing purpose
aliceCrypto.importBackedUpRoomKeys = jest.fn();

const decryptionFailureCount = 2;

const mockDecryptor = {
Expand Down Expand Up @@ -527,6 +536,45 @@ describe.each(Object.entries(CRYPTO_BACKENDS))("megolm-keys backup (%s)", (backe
expect(result.imported).toStrictEqual(expectedTotal - decryptionFailureCount);
});

it("Should report failures when decryption works but import fails", async function () {
// @ts-ignore - mock a private method for testing purpose
aliceCrypto.importBackedUpRoomKeys = jest
.fn()
.mockImplementationOnce((keys: IMegolmSessionData[], opts?: ImportRoomKeysOpts) => {
// report 10 failures to import
opts!.progressCallback!({
stage: "load_keys",
successes: 20,
failures: 10,
total: 30,
});
return Promise.resolve();
})
// Ok for other chunks
.mockResolvedValue(importMockImpl);

const { response, expectedTotal } = createBackupDownloadResponse([30]);

fetchMock.get("express:/_matrix/client/v3/room_keys/keys", response);

const check = await aliceCrypto.checkKeyBackupAndEnable();

const progressCallback = jest.fn();
const result = await aliceClient.restoreKeyBackupWithRecoveryKey(
testData.BACKUP_DECRYPTION_KEY_BASE58,
undefined,
undefined,
check!.backupInfo!,
{
progressCallback,
},
);

expect(result.total).toStrictEqual(expectedTotal);
// A chunk failed to import
expect(result.imported).toStrictEqual(20);
});

it("recover specific session from backup", async function () {
fetchMock.get(
"express:/_matrix/client/v3/room_keys/keys/:room_id/:session_id",
Expand Down
43 changes: 36 additions & 7 deletions src/client.ts
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,13 @@ import { LocalNotificationSettings } from "./@types/local_notifications";
import { buildFeatureSupportMap, Feature, ServerSupport } from "./feature";
import { BackupDecryptor, CryptoBackend } from "./common-crypto/CryptoBackend";
import { RUST_SDK_STORE_PREFIX } from "./rust-crypto/constants";
import { BootstrapCrossSigningOpts, CrossSigningKeyInfo, CryptoApi, ImportRoomKeysOpts } from "./crypto-api";
import {
BootstrapCrossSigningOpts,
CrossSigningKeyInfo,
CryptoApi,
ImportRoomKeyProgressData,
ImportRoomKeysOpts,
} from "./crypto-api";
import { DeviceInfoMap } from "./crypto/DeviceList";
import {
AddSecretStorageKeyOpts,
Expand Down Expand Up @@ -3923,10 +3929,18 @@ export class MatrixClient extends TypedEventEmitter<EmittedEvents, ClientEventHa
async (chunk) => {
// We have a chunk of decrypted keys: import them
try {
let success = 0;
let failures = 0;
const partialProgress = (stage: ImportRoomKeyProgressData): void => {
success = stage.successes ?? 0;
failures = stage.failures ?? 0;
};
await this.cryptoBackend!.importBackedUpRoomKeys(chunk, {
untrusted,
progressCallback: partialProgress,
});
totalImported += chunk.length;
totalImported += success;
totalFailures += failures;
} catch (e) {
totalFailures += chunk.length;
// We failed to import some keys, but we should still try to import the rest?
Expand All @@ -3953,11 +3967,25 @@ export class MatrixClient extends TypedEventEmitter<EmittedEvents, ClientEventHa
for (const k of keys) {
k.room_id = targetRoomId!;
}
await this.cryptoBackend.importBackedUpRoomKeys(keys, {
progressCallback,
untrusted,
});
totalImported = keys.length;
try {
let success = 0;
let failures = 0;
const partialProgress = (stage: ImportRoomKeyProgressData): void => {
success = stage.successes ?? 0;
failures = stage.failures ?? 0;
};
await this.cryptoBackend!.importBackedUpRoomKeys(chunk, {

Check failure on line 3977 in src/client.ts

View workflow job for this annotation

GitHub Actions / Typescript Syntax Check

Cannot find name 'chunk'.
untrusted,
progressCallback: partialProgress,
});
totalImported += success;
totalFailures += failures;
} catch (e) {
totalFailures += keys.length;
// We failed to import some keys, but we should still try to import the rest?
// Log the error and continue
logger.error("Error importing keys from backup", e);
}
} else {
totalKeyCount = 1;
try {
Expand All @@ -3973,6 +4001,7 @@ export class MatrixClient extends TypedEventEmitter<EmittedEvents, ClientEventHa
});
totalImported = 1;
} catch (e) {
totalFailures = 1;
this.logger.debug("Failed to decrypt megolm session from backup", e);
}
}
Expand Down
13 changes: 12 additions & 1 deletion src/rust-crypto/backup.ts
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,7 @@ export class RustBackupManager extends TypedEventEmitter<RustBackupCryptoEvents,
}
keysByRoom.get(roomId)!.set(key.session_id, key);
}
await this.olmMachine.importBackedUpRoomKeys(
const result: RustSdkCryptoJs.RoomKeyImportResult = await this.olmMachine.importBackedUpRoomKeys(
keysByRoom,
(progress: BigInt, total: BigInt, failures: BigInt): void => {
const importOpt: ImportRoomKeyProgressData = {
Expand All @@ -235,6 +235,17 @@ export class RustBackupManager extends TypedEventEmitter<RustBackupCryptoEvents,
opts?.progressCallback?.(importOpt);
},
);
// call the progress callback one last time with the final state
if (opts?.progressCallback) {
// We use total count here and not imported count.
// Imported count could be 0 if all the keys were already imported.
opts.progressCallback({
total: result.totalCount,
successes: result.totalCount,
stage: "load_keys",
failures: keys.length - result.totalCount,
});
}
}

private keyBackupCheckInProgress: Promise<KeyBackupCheck | null> | null = null;
Expand Down

0 comments on commit d75b11c

Please sign in to comment.