diff --git a/packages/core/src/routes/interaction/verifications/mfa-verification.test.ts b/packages/core/src/routes/interaction/verifications/mfa-verification.test.ts index 29d158938..97ab51ae5 100644 --- a/packages/core/src/routes/interaction/verifications/mfa-verification.test.ts +++ b/packages/core/src/routes/interaction/verifications/mfa-verification.test.ts @@ -22,10 +22,12 @@ const { jest } = import.meta; const { mockEsmWithActual } = createMockUtils(jest); const findUserById = jest.fn(); +const updateUserById = jest.fn(); const tenantContext = new MockTenant(undefined, { users: { findUserById, + updateUserById, }, }); @@ -48,7 +50,7 @@ const baseCtx = { signInExperience: { ...mockSignInExperience, mfa: { - factors: [], + factors: [MfaFactor.TOTP], policy: MfaPolicy.UserControlled, }, }, @@ -150,8 +152,35 @@ describe('validateMandatoryBindMfa', () => { ); }); - it('user mfaVerifications and bindMfa missing and not required should pass', async () => { + it('user mfaVerifications and bindMfa missing, and not required should throw (for skip)', async () => { findUserById.mockResolvedValueOnce(mockUser); + await expect( + validateMandatoryBindMfa(tenantContext, baseCtx, signInInteraction) + ).rejects.toMatchError( + new RequestError( + { + code: 'user.missing_mfa', + status: 422, + }, + { availableFactors: [MfaFactor.TOTP], skippable: true } + ) + ); + expect(updateUserById).toHaveBeenCalledWith(signInInteraction.accountId, { + customData: { + mfa: { + skipped: true, + }, + }, + }); + }); + + it('user mfaVerifications and bindMfa missing, mark skipped, and not required should pass', async () => { + findUserById.mockResolvedValueOnce({ + ...mockUser, + customData: { + mfa: { skipped: true }, + }, + }); await expect( validateMandatoryBindMfa(tenantContext, baseCtx, signInInteraction) ).resolves.not.toThrow(); diff --git a/packages/core/src/routes/interaction/verifications/mfa-verification.ts b/packages/core/src/routes/interaction/verifications/mfa-verification.ts index a142a1b1e..079a07a77 100644 --- a/packages/core/src/routes/interaction/verifications/mfa-verification.ts +++ b/packages/core/src/routes/interaction/verifications/mfa-verification.ts @@ -1,7 +1,8 @@ -import { InteractionEvent, MfaFactor, MfaPolicy } from '@logto/schemas'; +import { InteractionEvent, MfaFactor, MfaPolicy, type JsonObject } from '@logto/schemas'; import { deduplicate } from '@silverhand/essentials'; import { type Context } from 'koa'; import type Provider from 'oidc-provider'; +import { z } from 'zod'; import RequestError from '#src/errors/RequestError/index.js'; import type TenantContext from '#src/tenants/TenantContext.js'; @@ -73,6 +74,92 @@ export const verifyMfa = async ( return interaction; }; +const userMfaDataKey = 'mfa'; +/** + * Check if the user has skipped MFA binding + */ +const isMfaSkipped = (customData: JsonObject): boolean => { + const userMfaDataGuard = z.object({ + skipped: z.boolean().optional(), + }); + + const parsed = z.object({ [userMfaDataKey]: userMfaDataGuard }).safeParse(customData); + + return parsed.success ? parsed.data[userMfaDataKey].skipped === true : false; +}; + +const validateMandatoryBindMfaForSignIn = async ( + tenant: TenantContext, + ctx: WithInteractionSieContext & WithInteractionDetailsContext, + interaction: VerifiedSignInInteractionResult +): Promise => { + const { + mfa: { policy, factors }, + } = ctx.signInExperience; + const { bindMfas } = interaction; + const availableFactors = factors.filter((factor) => factor !== MfaFactor.BackupCode); + + // No available MFA, skip check + if (availableFactors.length === 0) { + return interaction; + } + + // If the user has linked new MFA in current interaction + const hasFactorInBindMfas = Boolean( + bindMfas && + availableFactors.some((factor) => bindMfas.some((bindMfa) => bindMfa.type === factor)) + ); + + const { accountId } = interaction; + const { mfaVerifications, customData } = await tenant.queries.users.findUserById(accountId); + + // If the user has linked MFA before + const hasFactorInUser = factors.some((factor) => + mfaVerifications.some(({ type }) => type === factor) + ); + + // MFA is bound in current interaction or MFA is bound before, skip check + if (hasFactorInBindMfas || hasFactorInUser) { + return interaction; + } + + // Mandatory, can not skip, throw error + if (policy === MfaPolicy.Mandatory) { + throw new RequestError( + { + code: 'user.missing_mfa', + status: 422, + }, + { availableFactors } + ); + } + + if (isMfaSkipped(customData)) { + return interaction; + } + + if (!isMfaSkipped(customData)) { + // Update user custom data to skip MFA binding + // that means that this prompt is only shown once + await tenant.queries.users.updateUserById(accountId, { + customData: { + ...customData, + [userMfaDataKey]: { + skipped: true, + }, + }, + }); + } + + throw new RequestError( + { + code: 'user.missing_mfa', + status: 422, + }, + { availableFactors, skippable: true } + ); +}; + export const validateMandatoryBindMfa = async ( tenant: TenantContext, ctx: WithInteractionSieContext & WithInteractionDetailsContext, @@ -84,7 +171,8 @@ export const validateMandatoryBindMfa = async ( const { event, bindMfas } = interaction; const availableFactors = factors.filter((factor) => factor !== MfaFactor.BackupCode); - if (policy !== MfaPolicy.Mandatory) { + // No available MFA, skip check + if (availableFactors.length === 0) { return interaction; } @@ -94,6 +182,10 @@ export const validateMandatoryBindMfa = async ( ); if (event === InteractionEvent.Register) { + if (policy !== MfaPolicy.Mandatory) { + return interaction; + } + assertThat( hasFactorInBind, new RequestError( @@ -107,21 +199,7 @@ export const validateMandatoryBindMfa = async ( } if (event === InteractionEvent.SignIn) { - const { accountId } = interaction; - const { mfaVerifications } = await tenant.queries.users.findUserById(accountId); - const hasFactorInUser = factors.some((factor) => - mfaVerifications.some(({ type }) => type === factor) - ); - assertThat( - hasFactorInBind || hasFactorInUser, - new RequestError( - { - code: 'user.missing_mfa', - status: 422, - }, - { availableFactors } - ) - ); + return validateMandatoryBindMfaForSignIn(tenant, ctx, interaction); } return interaction; diff --git a/packages/integration-tests/src/helpers/sign-in-experience.ts b/packages/integration-tests/src/helpers/sign-in-experience.ts index 0d76fdda9..3337a832c 100644 --- a/packages/integration-tests/src/helpers/sign-in-experience.ts +++ b/packages/integration-tests/src/helpers/sign-in-experience.ts @@ -75,6 +75,14 @@ export const enableAllVerificationCodeSignInMethods = async ( mfa: { factors: [], policy: MfaPolicy.UserControlled }, }); +export const enableUserControlledMfaWithTotp = async () => + updateSignInExperience({ + mfa: { + factors: [MfaFactor.TOTP], + policy: MfaPolicy.UserControlled, + }, + }); + export const enableMandatoryMfaWithTotp = async () => updateSignInExperience({ mfa: { diff --git a/packages/integration-tests/src/tests/api/interaction/mfa/totp.test.ts b/packages/integration-tests/src/tests/api/interaction/mfa/totp.test.ts index 47db8a0a2..6cb710d81 100644 --- a/packages/integration-tests/src/tests/api/interaction/mfa/totp.test.ts +++ b/packages/integration-tests/src/tests/api/interaction/mfa/totp.test.ts @@ -13,6 +13,7 @@ import { expectRejects } from '#src/helpers/index.js'; import { enableAllPasswordSignInMethods, enableMandatoryMfaWithTotp, + enableUserControlledMfaWithTotp, } from '#src/helpers/sign-in-experience.js'; import { generateNewUser, generateNewUserProfile } from '#src/helpers/user.js'; @@ -161,6 +162,59 @@ describe('sign in and fulfill mfa (mandatory TOTP)', () => { }); }); +describe('sign in and fulfill mfa (user-controlled TOTP)', () => { + beforeAll(async () => { + await enableAllPasswordSignInMethods({ + identifiers: [SignInIdentifier.Username], + password: true, + verify: false, + }); + await enableUserControlledMfaWithTotp(); + }); + + it('should fail with missing_mfa error for normal sign in', async () => { + const { userProfile, user } = await generateNewUser({ username: true, password: true }); + const client = await initClient(); + + await client.successSend(putInteraction, { + event: InteractionEvent.SignIn, + identifier: { + username: userProfile.username, + password: userProfile.password, + }, + }); + + await expectRejects(client.submitInteraction(), { + code: 'user.missing_mfa', + statusCode: 422, + }); + + await deleteUser(user.id); + }); + + it('should sign in and skip totp', async () => { + const { userProfile, user } = await generateNewUser({ username: true, password: true }); + const client = await initClient(); + + await client.successSend(putInteraction, { + event: InteractionEvent.SignIn, + identifier: { + username: userProfile.username, + password: userProfile.password, + }, + }); + + await expectRejects(client.submitInteraction(), { + code: 'user.missing_mfa', + statusCode: 422, + }); + + // Try again, should auto skip + await client.submitInteraction(); + await deleteUser(user.id); + }); +}); + describe('sign in and verify mfa (TOTP)', () => { beforeAll(async () => { await enableAllPasswordSignInMethods({