From 6222307ce8b292dc3a11b2d31041933be07abe39 Mon Sep 17 00:00:00 2001 From: Darcy Ye Date: Tue, 10 Dec 2024 18:40:59 +0800 Subject: [PATCH] feat(core): add SAML auth request handling endpoints --- packages/core/src/constants/index.ts | 1 + .../src/routes/swagger/utils/operation.ts | 8 +- .../libraries/saml-applications.ts | 4 + .../src/saml-applications/queries/index.ts | 82 +++++++ .../src/saml-applications/routes/anonymous.ts | 229 +++++++++++++++++- .../saml-applications/routes/utils.test.ts | 25 +- .../src/saml-applications/routes/utils.ts | 153 ++++++++++-- packages/core/src/tenants/Queries.ts | 2 + .../src/locales/en/errors/application.ts | 7 +- .../phrases/src/locales/en/errors/oidc.ts | 1 + 10 files changed, 457 insertions(+), 55 deletions(-) create mode 100644 packages/core/src/saml-applications/queries/index.ts diff --git a/packages/core/src/constants/index.ts b/packages/core/src/constants/index.ts index d5e4e0700..259544274 100644 --- a/packages/core/src/constants/index.ts +++ b/packages/core/src/constants/index.ts @@ -8,3 +8,4 @@ export const subjectTokenPrefix = 'sub_'; export const defaultIdPInitiatedSamlSsoSessionTtl = 10 * 60 * 1000; // 10 minutes export const idpInitiatedSamlSsoSessionCookieName = '_logto_idp_saml_sso_session_id'; +export const spInitiatedSamlSsoSessionCookieName = '_logto_sp_saml_sso_session_id'; diff --git a/packages/core/src/routes/swagger/utils/operation.ts b/packages/core/src/routes/swagger/utils/operation.ts index 0a3e2b3e9..0e921a3e2 100644 --- a/packages/core/src/routes/swagger/utils/operation.ts +++ b/packages/core/src/routes/swagger/utils/operation.ts @@ -134,8 +134,12 @@ export const buildRouterObjects = (routers: T[], option // Filter out universal routes (mostly like a proxy route to withtyped) .filter(({ path }) => !path.includes('.*')) // TODO: Remove this and bring back `/saml-applications` routes before release. - // Exclude `/saml-applications` routes for now. - .filter(({ path }) => !path.startsWith('/saml-applications')) + // Exclude `/saml-applications` routes and `/saml/:id/authn` for now. + .filter( + ({ path }) => + !path.startsWith('/saml-applications') && + !(path.startsWith('/saml') && path.endsWith('/authn')) + ) .flatMap(({ path: routerPath, stack, methods }) => methods .map((method) => method.toLowerCase()) diff --git a/packages/core/src/saml-applications/libraries/saml-applications.ts b/packages/core/src/saml-applications/libraries/saml-applications.ts index 5a7dbaca6..0dcf3c422 100644 --- a/packages/core/src/saml-applications/libraries/saml-applications.ts +++ b/packages/core/src/saml-applications/libraries/saml-applications.ts @@ -132,6 +132,10 @@ export const createSamlApplicationsLibrary = (queries: Queries) => { Location: buildSingleSignOnUrl(tenantEndpoint, id), Binding: BindingType.Redirect, }, + { + Location: buildSingleSignOnUrl(tenantEndpoint, id), + Binding: BindingType.Post, + }, ], }); diff --git a/packages/core/src/saml-applications/queries/index.ts b/packages/core/src/saml-applications/queries/index.ts new file mode 100644 index 000000000..2ad6039f1 --- /dev/null +++ b/packages/core/src/saml-applications/queries/index.ts @@ -0,0 +1,82 @@ +import { type ToZodObject } from '@logto/connector-kit'; +import { + SamlApplicationConfigs, + SamlApplicationSecrets, + Applications, + ApplicationType, + type Application, + type SamlApplicationConfig, + type SamlApplicationSecret, +} from '@logto/schemas'; +import type { CommonQueryMethods } from '@silverhand/slonik'; +import { sql } from '@silverhand/slonik'; +import { z } from 'zod'; + +import { convertToIdentifiers } from '#src/utils/sql.js'; + +const { table, fields } = convertToIdentifiers(Applications, true); +const { table: samlApplicationConfigsTable, fields: samlApplicationConfigsFields } = + convertToIdentifiers(SamlApplicationConfigs, true); +const { table: samlApplicationSecretsTable, fields: samlApplicationSecretsFields } = + convertToIdentifiers(SamlApplicationSecrets, true); + +type NullableObject = { + // eslint-disable-next-line @typescript-eslint/ban-types + [P in keyof T]: T[P] | null; +}; + +type SamlApplicationSecretDetails = Pick< + SamlApplicationSecret, + 'privateKey' | 'certificate' | 'active' | 'expiresAt' +>; + +export type SamlApplicationDetails = Pick< + Application, + 'id' | 'secret' | 'name' | 'description' | 'customData' | 'oidcClientMetadata' +> & + Pick & + NullableObject; + +const samlApplicationDetailsGuard = Applications.guard + .pick({ + id: true, + secret: true, + name: true, + description: true, + customData: true, + oidcClientMetadata: true, + }) + .merge( + SamlApplicationConfigs.guard.pick({ + attributeMapping: true, + entityId: true, + acsUrl: true, + }) + ) + .merge( + // Zod does not provide a way to convert all fields to nullable, so we need to do it manually. Other implementations seems can not make TypeScript happy. + z.object({ + privateKey: SamlApplicationSecrets.guard.shape.privateKey.nullable(), + certificate: SamlApplicationSecrets.guard.shape.certificate.nullable(), + active: SamlApplicationSecrets.guard.shape.active.nullable(), + expiresAt: SamlApplicationSecrets.guard.shape.expiresAt.nullable(), + }) + ) satisfies ToZodObject; + +export const createSamlApplicationQueries = (pool: CommonQueryMethods) => { + const getSamlApplicationDetailsById = async (id: string): Promise => { + const result = await pool.one(sql` + select ${fields.id} as id, ${fields.secret} as secret, ${fields.name} as name, ${fields.description} as description, ${fields.customData} as custom_data, ${fields.oidcClientMetadata} as oidc_client_metadata, ${samlApplicationConfigsFields.attributeMapping} as attribute_mapping, ${samlApplicationConfigsFields.entityId} as entity_id, ${samlApplicationConfigsFields.acsUrl} as acs_url, ${samlApplicationSecretsFields.privateKey} as private_key, ${samlApplicationSecretsFields.certificate} as certificate, ${samlApplicationSecretsFields.active} as active, ${samlApplicationSecretsFields.expiresAt} as expires_at + from ${table} + left join ${samlApplicationConfigsTable} on ${fields.id}=${samlApplicationConfigsFields.applicationId} + left join ${samlApplicationSecretsTable} on ${fields.id}=${samlApplicationSecretsFields.applicationId} + where ${fields.id}=${id} and ${fields.type}=${ApplicationType.SAML} and ${samlApplicationSecretsFields.active}=true + `); + + return samlApplicationDetailsGuard.parse(result); + }; + + return { + getSamlApplicationDetailsById, + }; +}; diff --git a/packages/core/src/saml-applications/routes/anonymous.ts b/packages/core/src/saml-applications/routes/anonymous.ts index 15ae85ac1..bb5a5bec3 100644 --- a/packages/core/src/saml-applications/routes/anonymous.ts +++ b/packages/core/src/saml-applications/routes/anonymous.ts @@ -1,5 +1,10 @@ +import { authRequestInfoGuard } from '@logto/schemas'; +import { generateStandardId, generateStandardShortId } from '@logto/shared'; +import { cond, removeUndefinedKeys } from '@silverhand/essentials'; +import { addMinutes } from 'date-fns'; import { z } from 'zod'; +import { spInitiatedSamlSsoSessionCookieName } from '#src/constants/index.js'; import { EnvSet, getTenantEndpoint } from '#src/env-set/index.js'; import RequestError from '#src/errors/RequestError/index.js'; import koaGuard from '#src/middleware/koa-guard.js'; @@ -10,8 +15,10 @@ import { generateAutoSubmitForm, createSamlResponse, handleOidcCallbackAndGetUserInfo, - setupSamlProviders, + getSamlIdpAndSp, + getSignInUrl, buildSamlAppCallbackUrl, + validateSamlApplicationDetails, } from './utils.js'; const samlApplicationSignInCallbackQueryParametersGuard = z.union([ @@ -30,7 +37,13 @@ export default function samlApplicationAnonymousRoutes { + const { + params: { id }, + query: { Signature, RelayState, ...rest }, + } = ctx.guard; + + const [{ metadata }, details] = await Promise.all([ + getSamlIdPMetadataByApplicationId(id), + getSamlApplicationDetailsById(id), + ]); + + const { entityId, acsUrl, redirectUri, certificate, privateKey } = + validateSamlApplicationDetails(details); + + const { idp, sp } = getSamlIdpAndSp({ + idp: { metadata, certificate, privateKey }, + sp: { entityId, acsUrl }, + }); + + const octetString = Object.keys(ctx.request.query) + // eslint-disable-next-line no-restricted-syntax + .map((key) => key + '=' + encodeURIComponent(ctx.request.query[key] as string)) + .join('&'); + const { SAMLRequest, SigAlg } = rest; + + // Parse login request + try { + const loginRequestResult = await idp.parseLoginRequest(sp, 'redirect', { + query: removeUndefinedKeys({ + SAMLRequest, + Signature, + SigAlg, + }), + octetString, + }); + + const extractResult = authRequestInfoGuard.safeParse(loginRequestResult.extract); + + if (!extractResult.success) { + throw new RequestError({ + code: 'application.saml.invalid_saml_request', + error: extractResult.error.flatten(), + }); + } + + assertThat( + extractResult.data.issuer === entityId, + 'application.saml.auth_request_issuer_not_match' + ); + + const state = generateStandardId(32); + const signInUrl = await getSignInUrl({ + issuer: envSet.oidc.issuer, + applicationId: id, + redirectUri, + state, + }); + + const currentDate = new Date(); + const expiresAt = addMinutes(currentDate, 60); // Lifetime of the session is 60 minutes. + const createSession = { + id: generateStandardId(32), + applicationId: id, + oidcState: state, + samlRequestId: extractResult.data.request.id, + authRequestInfo: extractResult.data, + // Expire the session in 60 minutes. + expiresAt: expiresAt.getTime(), + ...cond(RelayState && { relayState: RelayState }), + }; + + const insertSamlAppSession = await insertSession(createSession); + // Set the session ID to cookie for later use. + ctx.cookies.set(spInitiatedSamlSsoSessionCookieName, insertSamlAppSession.id, { + httpOnly: true, + sameSite: 'strict', + expires: expiresAt, + overwrite: true, + }); + + ctx.redirect(signInUrl.toString()); + } catch (error: unknown) { + if (error instanceof RequestError) { + throw error; + } + + throw new RequestError({ + code: 'application.saml.invalid_saml_request', + }); + } + + return next(); + } + ); + + // Post binding SAML authentication request endpoint + router.post( + '/saml/:id/authn', + koaGuard({ + params: z.object({ id: z.string() }), + body: z.object({ + SAMLRequest: z.string().min(1), + RelayState: z.string().optional(), + }), + status: [200, 302, 400, 404], + }), + async (ctx, next) => { + const { + params: { id }, + body: { SAMLRequest, RelayState }, + } = ctx.guard; + + const [{ metadata }, details] = await Promise.all([ + getSamlIdPMetadataByApplicationId(id), + getSamlApplicationDetailsById(id), + ]); + + const { acsUrl, entityId, redirectUri, privateKey, certificate } = + validateSamlApplicationDetails(details); + + const { idp, sp } = getSamlIdpAndSp({ + idp: { metadata, privateKey, certificate }, + sp: { entityId, acsUrl }, + }); + + // Parse login request + try { + const loginRequestResult = await idp.parseLoginRequest(sp, 'post', { + body: { + SAMLRequest, + }, + }); + + const extractResult = authRequestInfoGuard.safeParse(loginRequestResult.extract); + + if (!extractResult.success) { + throw new RequestError({ + code: 'application.saml.invalid_saml_request', + error: extractResult.error.flatten(), + }); + } + + assertThat( + extractResult.data.issuer === entityId, + 'application.saml.auth_request_issuer_not_match' + ); + + const state = generateStandardShortId(); + const signInUrl = await getSignInUrl({ + issuer: envSet.oidc.issuer, + applicationId: id, + redirectUri, + state, + }); + + const currentDate = new Date(); + const expiresAt = addMinutes(currentDate, 60); // Lifetime of the session is 60 minutes. + const insertSamlAppSession = await insertSession({ + id: generateStandardId(), + applicationId: id, + oidcState: state, + samlRequestId: extractResult.data.request.id, + authRequestInfo: extractResult.data, + // Expire the session in 60 minutes. + expiresAt: expiresAt.getTime(), + ...cond(RelayState && { relayState: RelayState }), + }); + // Set the session ID to cookie for later use. + ctx.cookies.set(spInitiatedSamlSsoSessionCookieName, insertSamlAppSession.id, { + httpOnly: true, + sameSite: 'strict', + expires: expiresAt, + overwrite: true, + }); + + ctx.redirect(signInUrl.toString()); + } catch (error: unknown) { + if (error instanceof RequestError) { + throw error; + } + + throw new RequestError({ + code: 'application.saml.invalid_saml_request', + }); + } + + return next(); + } + ); } diff --git a/packages/core/src/saml-applications/routes/utils.test.ts b/packages/core/src/saml-applications/routes/utils.test.ts index 73cbc4af3..62a712750 100644 --- a/packages/core/src/saml-applications/routes/utils.test.ts +++ b/packages/core/src/saml-applications/routes/utils.test.ts @@ -1,12 +1,7 @@ import nock from 'nock'; import type { IdentityProviderInstance, ServiceProviderInstance } from 'samlify'; -import { - createSamlTemplateCallback, - exchangeAuthorizationCode, - getUserInfo, - setupSamlProviders, -} from './utils.js'; +import { createSamlTemplateCallback, exchangeAuthorizationCode, getUserInfo } from './utils.js'; const { jest } = import.meta; @@ -158,21 +153,3 @@ describe('getUserInfo', () => { }); }); }); - -describe('setupSamlProviders', () => { - it('should setup SAML providers with correct configuration', () => { - const mockMetadata = '...'; - const mockPrivateKey = '-----BEGIN PRIVATE KEY-----...'; - const mockEntityId = 'https://sp.example.com'; - const mockAcsUrl = { - binding: 'urn:oasis:names:tc:SAML:2.0:bindings:HTTP-POST', - url: 'https://sp.example.com/acs', - }; - - const { idp, sp } = setupSamlProviders(mockMetadata, mockPrivateKey, mockEntityId, mockAcsUrl); - - expect(idp).toBeDefined(); - expect(sp).toBeDefined(); - expect(sp.entityMeta.getEntityID()).toBe(mockEntityId); - }); -}); diff --git a/packages/core/src/saml-applications/routes/utils.ts b/packages/core/src/saml-applications/routes/utils.ts index 076d127c7..25215bc95 100644 --- a/packages/core/src/saml-applications/routes/utils.ts +++ b/packages/core/src/saml-applications/routes/utils.ts @@ -1,7 +1,12 @@ +/* eslint-disable max-lines */ +// TODO: refactor this file to reduce LOC import { parseJson } from '@logto/connector-kit'; +import { Prompt, QueryKey, ReservedScope, UserScope } from '@logto/js'; +import { type SamlAcsUrl } from '@logto/schemas'; import { generateStandardId } from '@logto/shared'; -import { tryThat, appendPath } from '@silverhand/essentials'; -import camelcaseKeys from 'camelcase-keys'; +import { tryThat, appendPath, deduplicate } from '@silverhand/essentials'; +import camelcaseKeys, { type CamelCaseKeys } from 'camelcase-keys'; +import { XMLValidator } from 'fast-xml-parser'; import saml from 'samlify'; import { ZodError, z } from 'zod'; @@ -11,8 +16,11 @@ import { getRawUserInfoResponse, handleTokenExchange, } from '#src/sso/OidcConnector/utils.js'; -import { idTokenProfileStandardClaimsGuard } from '#src/sso/types/oidc.js'; -import { type IdTokenProfileStandardClaims } from '#src/sso/types/oidc.js'; +import { + idTokenProfileStandardClaimsGuard, + type OidcConfigResponse, + type IdTokenProfileStandardClaims, +} from '#src/sso/types/oidc.js'; import assertThat from '#src/utils/assert-that.js'; import { @@ -20,6 +28,7 @@ import { samlAttributeNameFormatBasic, samlValueXmlnsXsi, } from '../libraries/consts.js'; +import { type SamlApplicationDetails } from '../queries/index.js'; /** * Determines the SAML NameID format and value based on the user's claims and IdP's NameID format. @@ -223,20 +232,7 @@ export const handleOidcCallbackAndGetUserInfo = async ( issuer: string ) => { // Get OIDC configuration - const { tokenEndpoint, userinfoEndpoint } = await tryThat( - async () => fetchOidcConfigRaw(issuer), - (error) => { - if (error instanceof ZodError) { - throw new RequestError({ - code: 'oidc.invalid_request', - message: error.message, - error: error.flatten(), - }); - } - - throw error; - } - ); + const { tokenEndpoint, userinfoEndpoint } = await getOidcConfig(issuer); // Exchange authorization code for tokens const { accessToken } = await exchangeAuthorizationCode(tokenEndpoint, { @@ -252,12 +248,102 @@ export const handleOidcCallbackAndGetUserInfo = async ( return getUserInfo(accessToken, userinfoEndpoint); }; -export const setupSamlProviders = ( - metadata: string, - privateKey: string, - entityId: string, - acsUrl: { binding: string; url: string } -) => { +const getOidcConfig = async (issuer: string): Promise> => { + const oidcConfig = await tryThat( + async () => fetchOidcConfigRaw(issuer), + (error) => { + if (error instanceof ZodError) { + throw new RequestError({ + code: 'oidc.invalid_request', + message: error.message, + error: error.flatten(), + }); + } + + throw error; + } + ); + + return oidcConfig; +}; + +export const getSignInUrl = async ({ + issuer, + applicationId, + redirectUri, + scope, + state, +}: { + issuer: string; + applicationId: string; + redirectUri: string; + scope?: string; + state?: string; +}) => { + const { authorizationEndpoint } = await getOidcConfig(issuer); + + const queryParameters = new URLSearchParams({ + [QueryKey.ClientId]: applicationId, + [QueryKey.RedirectUri]: redirectUri, + [QueryKey.ResponseType]: 'code', + [QueryKey.Prompt]: Prompt.Login, + }); + + // TODO: get value of `scope` parameters according to setup in attribute mapping. + queryParameters.append( + QueryKey.Scope, + // For security reasons, DO NOT include the offline_access scope by default. + deduplicate([ + ReservedScope.OpenId, + UserScope.Profile, + UserScope.Roles, + UserScope.Organizations, + UserScope.OrganizationRoles, + UserScope.CustomData, + UserScope.Identities, + ...(scope?.split(' ') ?? []), + ]).join(' ') + ); + + if (state) { + queryParameters.append(QueryKey.State, state); + } + + return new URL(`${authorizationEndpoint}?${queryParameters.toString()}`); +}; + +export const validateSamlApplicationDetails = (details: SamlApplicationDetails) => { + const { + entityId, + acsUrl, + oidcClientMetadata: { redirectUris }, + privateKey, + certificate, + } = details; + + assertThat(acsUrl, 'application.saml.acs_url_required'); + assertThat(entityId, 'application.saml.entity_id_required'); + assertThat(redirectUris[0], 'oidc.invalid_redirect_uri'); + + assertThat(privateKey, 'application.saml.private_key_required'); + assertThat(certificate, 'application.saml.certificate_required'); + + return { + entityId, + acsUrl, + redirectUri: redirectUris[0], + privateKey, + certificate, + }; +}; + +export const getSamlIdpAndSp = ({ + idp: { metadata, privateKey, certificate }, + sp: { entityId, acsUrl }, +}: { + idp: { metadata: string; privateKey: string; certificate: string }; + sp: { entityId: string; acsUrl: SamlAcsUrl }; +}): { idp: saml.IdentityProviderInstance; sp: saml.ServiceProviderInstance } => { // eslint-disable-next-line new-cap const idp = saml.IdentityProvider({ metadata, @@ -295,6 +381,24 @@ export const setupSamlProviders = ( Location: acsUrl.url, }, ], + signingCert: certificate, + authnRequestsSigned: idp.entityMeta.isWantAuthnRequestsSigned(), + allowCreate: false, + }); + + // Used to check whether xml content is valid in format. + saml.setSchemaValidator({ + validate: async (xmlContent: string) => { + try { + XMLValidator.validate(xmlContent, { + allowBooleanAttributes: true, + }); + + return true; + } catch { + return false; + } + }, }); return { idp, sp }; @@ -302,3 +406,4 @@ export const setupSamlProviders = ( export const buildSamlAppCallbackUrl = (baseUrl: URL, samlApplicationId: string) => appendPath(baseUrl, `api/saml-applications/${samlApplicationId}/callback`).toString(); +/* eslint-enable max-lines */ diff --git a/packages/core/src/tenants/Queries.ts b/packages/core/src/tenants/Queries.ts index 59afe303f..509b7ca1c 100644 --- a/packages/core/src/tenants/Queries.ts +++ b/packages/core/src/tenants/Queries.ts @@ -29,6 +29,7 @@ import { createUserQueries } from '#src/queries/user.js'; import { createUsersRolesQueries } from '#src/queries/users-roles.js'; import { createVerificationStatusQueries } from '#src/queries/verification-status.js'; import { createSamlApplicationConfigQueries } from '#src/saml-applications/queries/configs.js'; +import { createSamlApplicationQueries } from '#src/saml-applications/queries/index.js'; import { createSamlApplicationSecretsQueries } from '#src/saml-applications/queries/secrets.js'; import { createSamlApplicationSessionQueries } from '#src/saml-applications/queries/sessions.js'; @@ -66,6 +67,7 @@ export default class Queries { samlApplicationSecrets = createSamlApplicationSecretsQueries(this.pool); samlApplicationConfigs = createSamlApplicationConfigQueries(this.pool); samlApplicationSessions = createSamlApplicationSessionQueries(this.pool); + samlApplications = createSamlApplicationQueries(this.pool); personalAccessTokens = new PersonalAccessTokensQueries(this.pool); verificationRecords = new VerificationRecordQueries(this.pool); accountCenters = new AccountCenterQueries(this.pool); diff --git a/packages/phrases/src/locales/en/errors/application.ts b/packages/phrases/src/locales/en/errors/application.ts index 9fa9e4d49..89b692332 100644 --- a/packages/phrases/src/locales/en/errors/application.ts +++ b/packages/phrases/src/locales/en/errors/application.ts @@ -28,8 +28,13 @@ const application = { can_not_delete_active_secret: 'Can not delete the active secret.', no_active_secret: 'No active secret found.', entity_id_required: 'Entity ID is required to generate metadata.', - acs_url_required: 'Assertion consumer service URL is required to generate metadata.', invalid_certificate_pem_format: 'Invalid PEM certificate format', + acs_url_required: 'Assertion Consumer Service URL is required.', + private_key_required: 'Private key is required.', + certificate_required: 'Certificate is required.', + invalid_saml_request: 'Invalid SAML authentication request.', + auth_request_issuer_not_match: + 'The issuer of the SAML authentication request mismatch with service provider entity ID.', }, }; diff --git a/packages/phrases/src/locales/en/errors/oidc.ts b/packages/phrases/src/locales/en/errors/oidc.ts index 4f65ef651..73e9c4375 100644 --- a/packages/phrases/src/locales/en/errors/oidc.ts +++ b/packages/phrases/src/locales/en/errors/oidc.ts @@ -6,6 +6,7 @@ const oidc = { insufficient_scope: 'Token missing scope `{{scope}}`.', invalid_request: 'Request is invalid.', invalid_grant: 'Grant request is invalid.', + invalid_issuer: 'Invalid issuer.', invalid_redirect_uri: "`redirect_uri` did not match any of the client's registered `redirect_uris`.", access_denied: 'Access denied.',