Skip to content

Commit

Permalink
feat: maintain mls group list during migration [WPB-1115] (#15318)
Browse files Browse the repository at this point in the history
* feat: wipe mls group if user is removed / leave mls-capable conversation

* feat: add users to mls group when conversation is mixed

* feat: restart periodic key material timers on app reload

* test: adding users to mls/mixed/proteus group

* test: add users to mls group

* runfix: joining mls capable conversations

* test: remove / leave conversation

* runfix: add users to mixed conversation

* runfix: show unestablished mixed conversations

* refactor: test

* refactor: apply cr suggestion

* refactor: add MLSCapableConversation type
  • Loading branch information
PatrykBuniX committed Jun 20, 2023
1 parent 8cf3517 commit c62fa8c
Show file tree
Hide file tree
Showing 7 changed files with 217 additions and 72 deletions.
3 changes: 3 additions & 0 deletions src/__mocks__/@wireapp/core.ts
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,9 @@ export class Account extends EventEmitter {
isMLSConversationEstablished: jest.fn(),
joinByExternalCommit: jest.fn(),
addUsersToMLSConversation: jest.fn(),
removeUserFromConversation: jest.fn(),
removeUsersFromMLSConversation: jest.fn(),
addUsersToProteusConversation: jest.fn(),
messageTimer: {
setConversationLevelTimer: jest.fn(),
},
Expand Down
216 changes: 178 additions & 38 deletions src/script/conversation/ConversationRepository.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -71,11 +71,14 @@ import {LegacyEventRecord, StorageService} from '../storage';

jest.deepUnmock('axios');

const mlsCapableProtocols = [ConversationProtocol.MLS, ConversationProtocol.MIXED];

const _generateConversation = (
conversation_type = CONVERSATION_TYPE.REGULAR,
connection_status = ConnectionStatus.ACCEPTED,
conversationProtocol = ConversationProtocol.PROTEUS,
domain = '',
groupId = 'groupId',
) => {
const conversation = new Conversation(createUuid(), domain, conversationProtocol);
conversation.type(conversation_type);
Expand All @@ -85,8 +88,8 @@ const _generateConversation = (
connectionEntity.status(connection_status);
conversation.connection(connectionEntity);

if (conversationProtocol === ConversationProtocol.MLS) {
conversation.groupId = 'groupId';
if (mlsCapableProtocols.includes(conversationProtocol)) {
conversation.groupId = groupId;
}

return conversation;
Expand Down Expand Up @@ -816,45 +819,41 @@ describe('ConversationRepository', () => {
});
});

it('should add other self clients to mls group if user was event creator', () => {
const mockDomain = 'example.com';
const mockSelfClientId = 'self-client-id';
const selfUser = generateUser({id: createUuid(), domain: mockDomain});
it.each([ConversationProtocol.MIXED, ConversationProtocol.MLS])(
'should add other self clients to mls/mixed conversation MLS group if user was event creator',
protocol => {
const mockDomain = 'example.com';
const mockSelfClientId = 'self-client-id';
const selfUser = generateUser({id: createUuid(), domain: mockDomain});

const conversationEntity = _generateConversation(
CONVERSATION_TYPE.REGULAR,
undefined,
ConversationProtocol.MLS,
mockDomain,
);
testFactory.conversation_repository['saveConversation'](conversationEntity);
const conversationEntity = _generateConversation(CONVERSATION_TYPE.REGULAR, undefined, protocol, mockDomain);
testFactory.conversation_repository['saveConversation'](conversationEntity);

const memberJoinEvent = {
conversation: conversationEntity.id,
data: {
user_ids: [selfUser.id],
},
from: selfUser.id,
time: '2015-04-27T11:42:31.475Z',
type: CONVERSATION_EVENT.MEMBER_JOIN,
} as ConversationMemberJoinEvent;

spyOn(testFactory.conversation_repository['userState'], 'self').and.returnValue(selfUser);

Object.defineProperty(container.resolve(Core), 'clientId', {
get: jest.fn(() => mockSelfClientId),
});

return testFactory.conversation_repository['handleConversationEvent'](memberJoinEvent).then(() => {
expect(testFactory.conversation_repository['onMemberJoin']).toHaveBeenCalled();
expect(testFactory.conversation_repository.updateParticipatingUserEntities).toHaveBeenCalled();
expect(container.resolve(Core).service!.conversation.addUsersToMLSConversation).toHaveBeenCalledWith({
conversationId: conversationEntity.qualifiedId,
groupId: 'groupId',
qualifiedUsers: [{domain: mockDomain, id: selfUser.id, skipOwnClientId: mockSelfClientId}],
const memberJoinEvent = {
conversation: conversationEntity.id,
data: {
user_ids: [selfUser.id],
},
from: selfUser.id,
time: '2015-04-27T11:42:31.475Z',
type: CONVERSATION_EVENT.MEMBER_JOIN,
} as ConversationMemberJoinEvent;

spyOn(testFactory.conversation_repository['userState'], 'self').and.returnValue(selfUser);

container.resolve(Core).clientId = mockSelfClientId;

return testFactory.conversation_repository['handleConversationEvent'](memberJoinEvent).then(() => {
expect(testFactory.conversation_repository['onMemberJoin']).toHaveBeenCalled();
expect(testFactory.conversation_repository.updateParticipatingUserEntities).toHaveBeenCalled();
expect(container.resolve(Core).service!.conversation.addUsersToMLSConversation).toHaveBeenCalledWith({
conversationId: conversationEntity.qualifiedId,
groupId: 'groupId',
qualifiedUsers: [{domain: mockDomain, id: selfUser.id, skipOwnClientId: mockSelfClientId}],
});
});
});
});
},
);

it('should ignore member-join event when joining a 1to1 conversation', () => {
const selfUser = generateUser();
Expand Down Expand Up @@ -1626,4 +1625,145 @@ describe('ConversationRepository', () => {
expect(updatedConversation.epoch).toEqual(newEpoch);
});
});

describe('addUsers', () => {
it('should add users to proteus conversation', async () => {
const conversation = _generateConversation();
const conversationRepository = await testFactory.exposeConversationActors();

const usersToAdd = [generateUser(), generateUser()];

const coreConversationService = container.resolve(Core).service!.conversation;
spyOn(coreConversationService, 'addUsersToProteusConversation');

await conversationRepository.addUsers(conversation, usersToAdd);
expect(coreConversationService.addUsersToProteusConversation).toHaveBeenCalledWith({
conversationId: conversation.qualifiedId,
qualifiedUsers: usersToAdd.map(user => user.qualifiedId),
});
});

it('should add users to mls group of mixed conversation', async () => {
const mockedGroupId = `mockedGroupId`;
const conversation = _generateConversation(undefined, undefined, ConversationProtocol.MIXED, '', mockedGroupId);
const conversationRepository = await testFactory.exposeConversationActors();

const usersToAdd = [generateUser(), generateUser()];

const coreConversationService = container.resolve(Core).service!.conversation;
spyOn(coreConversationService, 'addUsersToMLSConversation');

await conversationRepository.addUsers(conversation, usersToAdd);
expect(coreConversationService.addUsersToProteusConversation).toHaveBeenCalledWith({
conversationId: conversation.qualifiedId,
qualifiedUsers: usersToAdd.map(user => user.qualifiedId),
});
expect(coreConversationService.addUsersToMLSConversation).toHaveBeenCalledWith({
conversationId: conversation.qualifiedId,
qualifiedUsers: usersToAdd.map(user => user.qualifiedId),
groupId: mockedGroupId,
});
});

it('should add users to mls group of mls conversation', async () => {
const mockedGroupId = `mockedGroupId`;
const conversation = _generateConversation(undefined, undefined, ConversationProtocol.MLS, '', mockedGroupId);
const conversationRepository = await testFactory.exposeConversationActors();

const usersToAdd = [generateUser(), generateUser()];

const coreConversationService = container.resolve(Core).service!.conversation;
spyOn(coreConversationService, 'addUsersToMLSConversation');

await conversationRepository.addUsers(conversation, usersToAdd);
expect(coreConversationService.addUsersToMLSConversation).toHaveBeenCalledWith({
conversationId: conversation.qualifiedId,
qualifiedUsers: usersToAdd.map(user => user.qualifiedId),
groupId: mockedGroupId,
});
});
});

describe('removeMember', () => {
it.each([ConversationProtocol.PROTEUS, ConversationProtocol.MIXED])(
'should remove member from %s conversation',
async protocol => {
const conversationRepository = await testFactory.exposeConversationActors();

const conversation = _generateConversation(undefined, undefined, protocol);

const selfUser = generateUser();
conversation.selfUser(selfUser);

const user1 = generateUser();
const user2 = generateUser();

conversation.participating_user_ets([user1, user2]);

const coreConversationService = container.resolve(Core).service!.conversation;

await conversationRepository.removeMember(conversation, user1.qualifiedId);

expect(coreConversationService.removeUserFromConversation).toHaveBeenCalledWith(
conversation.qualifiedId,
user1.qualifiedId,
);
},
);

it('should remove member from mls conversation', async () => {
const conversationRepository = await testFactory.exposeConversationActors();

const mockGroupId = 'mockGroupId';
const conversation = _generateConversation(undefined, undefined, ConversationProtocol.MLS, '', mockGroupId);

const selfUser = generateUser();
conversation.selfUser(selfUser);

const user1 = generateUser();
const user2 = generateUser();

conversation.participating_user_ets([user1, user2]);

const coreConversationService = container.resolve(Core).service!.conversation;

jest
.spyOn(coreConversationService, 'removeUsersFromMLSConversation')
.mockResolvedValueOnce({events: [], conversation: {} as BackendConversation});
await conversationRepository.removeMember(conversation, user1.qualifiedId);

expect(coreConversationService.removeUsersFromMLSConversation).toHaveBeenCalledWith({
conversationId: conversation.qualifiedId,
qualifiedUserIds: [user1.qualifiedId],
groupId: mockGroupId,
});
});

describe('leaveConversation', () => {
it.each([ConversationProtocol.PROTEUS, ConversationProtocol.MIXED, ConversationProtocol.MLS])(
'should leave %s conversation',
async protocol => {
const conversationRepository = await testFactory.exposeConversationActors();

const conversation = _generateConversation(undefined, undefined, protocol);

const selfUser = generateUser();
conversation.selfUser(selfUser);

spyOn(testFactory.conversation_repository['userState'], 'self').and.returnValue(selfUser);

conversation.participating_user_ets([generateUser(), generateUser()]);

const coreConversationService = container.resolve(Core).service!.conversation;

await conversationRepository.removeMember(conversation, selfUser.qualifiedId);

expect(coreConversationService.removeUserFromConversation).toHaveBeenCalledWith(
conversation.qualifiedId,
selfUser.qualifiedId,
);
},
);
});
});
});
40 changes: 18 additions & 22 deletions src/script/conversation/ConversationRepository.ts
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ import {ConversationFilter} from './ConversationFilter';
import {ConversationLabelRepository} from './ConversationLabelRepository';
import {ConversationDatabaseData, ConversationMapper} from './ConversationMapper';
import {ConversationRoleRepository} from './ConversationRoleRepository';
import {isMLSConversation} from './ConversationSelectors';
import {isMLSCapableConversation, isMLSConversation} from './ConversationSelectors';
import {ConversationService} from './ConversationService';
import {ConversationState} from './ConversationState';
import {ConversationStateHandler} from './ConversationStateHandler';
Expand Down Expand Up @@ -952,7 +952,7 @@ export class ConversationRepository {
}
this.deleteConversationFromRepository(conversationId);
await this.conversationService.deleteConversationFromDb(conversationId.id);
if (isMLSConversation(conversationEntity)) {
if (isMLSCapableConversation(conversationEntity)) {
await this.conversationService.wipeMLSConversation(conversationEntity);
}
};
Expand Down Expand Up @@ -1453,25 +1453,25 @@ export class ConversationRepository {

const qualifiedUsers = userEntities.map(userEntity => userEntity.qualifiedId);

const {qualifiedId: conversationId, groupId} = conversation;

try {
if (conversation.isUsingMLSProtocol && groupId) {
const {events} = await this.core.service!.conversation.addUsersToMLSConversation({
conversationId,
groupId,
if (isProteusConversation(conversation) || isMixedConversation(conversation)) {
const conversationMemberJoinEvent = await this.core.service!.conversation.addUsersToProteusConversation({
conversationId: conversation.qualifiedId,
qualifiedUsers,
});
if (!!events.length) {
events.forEach(event => this.eventRepository.injectEvent(event));
if (conversationMemberJoinEvent) {
await this.eventRepository.injectEvent(conversationMemberJoinEvent, EventRepository.SOURCE.BACKEND_RESPONSE);
}
} else {
const conversationMemberJoinEvent = await this.core.service!.conversation.addUsersToProteusConversation({
conversationId,
}

if (isMLSCapableConversation(conversation)) {
const {events} = await this.core.service!.conversation.addUsersToMLSConversation({
conversationId: conversation.qualifiedId,
groupId: conversation.groupId,
qualifiedUsers,
});
if (conversationMemberJoinEvent) {
this.eventRepository.injectEvent(conversationMemberJoinEvent, EventRepository.SOURCE.BACKEND_RESPONSE);
if (!!events.length && isMLSConversation(conversation)) {
events.forEach(event => this.eventRepository.injectEvent(event));
}
}
} catch (error) {
Expand Down Expand Up @@ -2627,7 +2627,7 @@ export class ConversationRepository {
const qualifiedUserIds =
eventData.users?.map(user => user.qualified_id) || eventData.user_ids.map(userId => ({domain: '', id: userId}));

if (conversationEntity.isUsingMLSProtocol) {
if (isMLSCapableConversation(conversationEntity)) {
const isSelfJoin = isFromSelf && selfUserJoins;
await this.handleMLSConversationMemberJoin(conversationEntity, isSelfJoin);
}
Expand All @@ -2647,13 +2647,9 @@ export class ConversationRepository {
* @param conversation Conversation member joined to
* @param isSelfJoin whether user has joined by itself, if so we need to add other self clients to mls group
*/
private async handleMLSConversationMemberJoin(conversation: Conversation, isSelfJoin: boolean) {
private async handleMLSConversationMemberJoin(conversation: MLSCapableConversation, isSelfJoin: boolean) {
const {groupId} = conversation;

if (!groupId) {
throw new Error(`groupId not found for MLS conversation ${conversation.id}`);
}

const isMLSConversationEstablished = await this.core.service!.conversation.isMLSConversationEstablished(groupId);

if (!isMLSConversationEstablished) {
Expand Down Expand Up @@ -2711,7 +2707,7 @@ export class ConversationRepository {
eventJson.from = this.userState.self().id;
}

if (isMLSConversation(conversationEntity)) {
if (isMLSCapableConversation(conversationEntity)) {
await this.conversationService.wipeMLSConversation(conversationEntity);
}
} else {
Expand Down
5 changes: 5 additions & 0 deletions src/script/conversation/ConversationSelectors.ts
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import {Conversation} from '../entity/Conversation';
export type ProteusConversation = Conversation & {protocol: ConversationProtocol.PROTEUS};
export type MixedConversation = Conversation & {groupId: string; protocol: ConversationProtocol.MIXED};
export type MLSConversation = Conversation & {groupId: string; protocol: ConversationProtocol.MLS};
export type MLSCapableConversation = MixedConversation | MLSConversation;

export function isProteusConversation(conversation: Conversation): conversation is ProteusConversation {
return !conversation.groupId && conversation.protocol === ConversationProtocol.PROTEUS;
Expand All @@ -37,6 +38,10 @@ export function isMLSConversation(conversation: Conversation): conversation is M
return !!conversation.groupId && conversation.protocol === ConversationProtocol.MLS;
}

export function isMLSCapableConversation(conversation: Conversation): conversation is MLSCapableConversation {
return isMixedConversation(conversation) || isMLSConversation(conversation);
}

export function isSelfConversation(conversation: Conversation): boolean {
return conversation.type() === CONVERSATION_TYPE.SELF;
}
Expand Down
6 changes: 3 additions & 3 deletions src/script/conversation/ConversationService.ts
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ import {container} from 'tsyringe';

import {getLogger, Logger} from 'Util/Logger';

import {MLSConversation} from './ConversationSelectors';
import {MLSCapableConversation, MLSConversation} from './ConversationSelectors';

import type {Conversation as ConversationEntity} from '../entity/Conversation';
import type {EventService} from '../event/EventService';
Expand Down Expand Up @@ -423,8 +423,8 @@ export class ConversationService {
* Wipes MLS conversation in corecrypto and deletes the conversation state.
* @param mlsConversation mls conversation
*/
async wipeMLSConversation(mlsConversation: MLSConversation) {
const {groupId} = mlsConversation;
async wipeMLSCapableConversation(conversation: MLSCapableConversation) {
const {groupId} = conversation;
await this.core.service!.conversation.wipeMLSConversation(groupId);
return useMLSConversationState.getState().wipeConversationState(groupId);
}
Expand Down
Loading

0 comments on commit c62fa8c

Please sign in to comment.