diff --git a/itou/openid_connect/inclusion_connect/models.py b/itou/openid_connect/inclusion_connect/models.py index 99e7e474e9..e64416e4ef 100644 --- a/itou/openid_connect/inclusion_connect/models.py +++ b/itou/openid_connect/inclusion_connect/models.py @@ -19,13 +19,10 @@ class InclusionConnectState(OIDConnectState): @dataclasses.dataclass -class InclusionConnectPrescriberData(OIDConnectUserData): - kind: UserKind = UserKind.PRESCRIBER - identity_provider: IdentityProvider = IdentityProvider.INCLUSION_CONNECT - login_allowed_user_kinds: ClassVar[tuple[UserKind]] = (UserKind.PRESCRIBER, UserKind.EMPLOYER) - allowed_identity_provider_migration: ClassVar[tuple[IdentityProvider]] = (IdentityProvider.DJANGO,) - +class InclusionConnectUserData(OIDConnectUserData): def join_org(self, user: User, safir: str): + if not user.is_prescriber: + raise ValueError("Invalid user kind: %s", user.kind) try: organization = PrescriberOrganization.objects.get(code_safir_pole_emploi=safir) except PrescriberOrganization.DoesNotExist: @@ -36,7 +33,15 @@ def join_org(self, user: User, safir: str): @dataclasses.dataclass -class InclusionConnectEmployerData(OIDConnectUserData): +class InclusionConnectPrescriberData(InclusionConnectUserData): + kind: UserKind = UserKind.PRESCRIBER + identity_provider: IdentityProvider = IdentityProvider.INCLUSION_CONNECT + login_allowed_user_kinds: ClassVar[tuple[UserKind]] = (UserKind.PRESCRIBER, UserKind.EMPLOYER) + allowed_identity_provider_migration: ClassVar[tuple[IdentityProvider]] = (IdentityProvider.DJANGO,) + + +@dataclasses.dataclass +class InclusionConnectEmployerData(InclusionConnectUserData): kind: UserKind = UserKind.EMPLOYER identity_provider: IdentityProvider = IdentityProvider.INCLUSION_CONNECT login_allowed_user_kinds: ClassVar[tuple[UserKind]] = (UserKind.PRESCRIBER, UserKind.EMPLOYER) diff --git a/itou/openid_connect/inclusion_connect/views.py b/itou/openid_connect/inclusion_connect/views.py index 1df6b0d5d6..f6f7d73b04 100644 --- a/itou/openid_connect/inclusion_connect/views.py +++ b/itou/openid_connect/inclusion_connect/views.py @@ -317,7 +317,7 @@ def inclusion_connect_callback(request): code_safir_pole_emploi = user_data.get("structure_pe") # Only handle user creation for the moment, not updates. - if is_successful and code_safir_pole_emploi: + if is_successful and user.is_prescriber and code_safir_pole_emploi: try: ic_user_data.join_org(user=user, safir=code_safir_pole_emploi) except PrescriberOrganization.DoesNotExist: diff --git a/tests/openid_connect/inclusion_connect/tests.py b/tests/openid_connect/inclusion_connect/tests.py index af5d1b0491..dea5a018ed 100644 --- a/tests/openid_connect/inclusion_connect/tests.py +++ b/tests/openid_connect/inclusion_connect/tests.py @@ -593,6 +593,32 @@ def test_callback_update_FT_organization(self): ) self.assertQuerySetEqual(org.members.all(), [user]) + @respx.mock + def test_callback_update_FT_organization_as_employer_does_not_crash(self): + org = PrescriberPoleEmploiFactory(code_safir_pole_emploi=OIDC_USERINFO_FT_WITH_SAFIR["structure_pe"]) + mock_oauth_dance( + self.client, + UserKind.EMPLOYER, + oidc_userinfo=OIDC_USERINFO_FT_WITH_SAFIR.copy(), + ) + user = get_user(self.client) + assert user.is_authenticated + assert not user.prescribermembership_set.exists() + + # If he's a prescriber and uses the employer login button + user.kind = UserKind.PRESCRIBER + user.save() + self.client.logout() + mock_oauth_dance( + self.client, + UserKind.EMPLOYER, + oidc_userinfo=OIDC_USERINFO_FT_WITH_SAFIR.copy(), + register=False, + ) + user = get_user(self.client) + assert user.is_authenticated + self.assertQuerySetEqual(org.members.all(), [user]) + @respx.mock def test_callback_ft_users_with_no_org(self): PrescriberFactory(**dataclasses.asdict(InclusionConnectPrescriberData.from_user_info(OIDC_USERINFO)))