0
Fork 0
mirror of https://github.com/logto-io/logto.git synced 2025-03-03 22:15:32 -05:00

feat(core): add SAML auth request handling endpoints

This commit is contained in:
Darcy Ye 2024-12-10 18:40:59 +08:00
parent fd2ea4a24e
commit 6222307ce8
No known key found for this signature in database
GPG key ID: B46F4C07EDEFC610
10 changed files with 457 additions and 55 deletions

View file

@ -8,3 +8,4 @@ export const subjectTokenPrefix = 'sub_';
export const defaultIdPInitiatedSamlSsoSessionTtl = 10 * 60 * 1000; // 10 minutes
export const idpInitiatedSamlSsoSessionCookieName = '_logto_idp_saml_sso_session_id';
export const spInitiatedSamlSsoSessionCookieName = '_logto_sp_saml_sso_session_id';

View file

@ -134,8 +134,12 @@ export const buildRouterObjects = <T extends UnknownRouter>(routers: T[], option
// Filter out universal routes (mostly like a proxy route to withtyped)
.filter(({ path }) => !path.includes('.*'))
// TODO: Remove this and bring back `/saml-applications` routes before release.
// Exclude `/saml-applications` routes for now.
.filter(({ path }) => !path.startsWith('/saml-applications'))
// Exclude `/saml-applications` routes and `/saml/:id/authn` for now.
.filter(
({ path }) =>
!path.startsWith('/saml-applications') &&
!(path.startsWith('/saml') && path.endsWith('/authn'))
)
.flatMap<RouteObject>(({ path: routerPath, stack, methods }) =>
methods
.map((method) => method.toLowerCase())

View file

@ -132,6 +132,10 @@ export const createSamlApplicationsLibrary = (queries: Queries) => {
Location: buildSingleSignOnUrl(tenantEndpoint, id),
Binding: BindingType.Redirect,
},
{
Location: buildSingleSignOnUrl(tenantEndpoint, id),
Binding: BindingType.Post,
},
],
});

View file

@ -0,0 +1,82 @@
import { type ToZodObject } from '@logto/connector-kit';
import {
SamlApplicationConfigs,
SamlApplicationSecrets,
Applications,
ApplicationType,
type Application,
type SamlApplicationConfig,
type SamlApplicationSecret,
} from '@logto/schemas';
import type { CommonQueryMethods } from '@silverhand/slonik';
import { sql } from '@silverhand/slonik';
import { z } from 'zod';
import { convertToIdentifiers } from '#src/utils/sql.js';
const { table, fields } = convertToIdentifiers(Applications, true);
const { table: samlApplicationConfigsTable, fields: samlApplicationConfigsFields } =
convertToIdentifiers(SamlApplicationConfigs, true);
const { table: samlApplicationSecretsTable, fields: samlApplicationSecretsFields } =
convertToIdentifiers(SamlApplicationSecrets, true);
type NullableObject<T> = {
// eslint-disable-next-line @typescript-eslint/ban-types
[P in keyof T]: T[P] | null;
};
type SamlApplicationSecretDetails = Pick<
SamlApplicationSecret,
'privateKey' | 'certificate' | 'active' | 'expiresAt'
>;
export type SamlApplicationDetails = Pick<
Application,
'id' | 'secret' | 'name' | 'description' | 'customData' | 'oidcClientMetadata'
> &
Pick<SamlApplicationConfig, 'attributeMapping' | 'entityId' | 'acsUrl'> &
NullableObject<SamlApplicationSecretDetails>;
const samlApplicationDetailsGuard = Applications.guard
.pick({
id: true,
secret: true,
name: true,
description: true,
customData: true,
oidcClientMetadata: true,
})
.merge(
SamlApplicationConfigs.guard.pick({
attributeMapping: true,
entityId: true,
acsUrl: true,
})
)
.merge(
// Zod does not provide a way to convert all fields to nullable, so we need to do it manually. Other implementations seems can not make TypeScript happy.
z.object({
privateKey: SamlApplicationSecrets.guard.shape.privateKey.nullable(),
certificate: SamlApplicationSecrets.guard.shape.certificate.nullable(),
active: SamlApplicationSecrets.guard.shape.active.nullable(),
expiresAt: SamlApplicationSecrets.guard.shape.expiresAt.nullable(),
})
) satisfies ToZodObject<SamlApplicationDetails>;
export const createSamlApplicationQueries = (pool: CommonQueryMethods) => {
const getSamlApplicationDetailsById = async (id: string): Promise<SamlApplicationDetails> => {
const result = await pool.one(sql`
select ${fields.id} as id, ${fields.secret} as secret, ${fields.name} as name, ${fields.description} as description, ${fields.customData} as custom_data, ${fields.oidcClientMetadata} as oidc_client_metadata, ${samlApplicationConfigsFields.attributeMapping} as attribute_mapping, ${samlApplicationConfigsFields.entityId} as entity_id, ${samlApplicationConfigsFields.acsUrl} as acs_url, ${samlApplicationSecretsFields.privateKey} as private_key, ${samlApplicationSecretsFields.certificate} as certificate, ${samlApplicationSecretsFields.active} as active, ${samlApplicationSecretsFields.expiresAt} as expires_at
from ${table}
left join ${samlApplicationConfigsTable} on ${fields.id}=${samlApplicationConfigsFields.applicationId}
left join ${samlApplicationSecretsTable} on ${fields.id}=${samlApplicationSecretsFields.applicationId}
where ${fields.id}=${id} and ${fields.type}=${ApplicationType.SAML} and ${samlApplicationSecretsFields.active}=true
`);
return samlApplicationDetailsGuard.parse(result);
};
return {
getSamlApplicationDetailsById,
};
};

View file

@ -1,5 +1,10 @@
import { authRequestInfoGuard } from '@logto/schemas';
import { generateStandardId, generateStandardShortId } from '@logto/shared';
import { cond, removeUndefinedKeys } from '@silverhand/essentials';
import { addMinutes } from 'date-fns';
import { z } from 'zod';
import { spInitiatedSamlSsoSessionCookieName } from '#src/constants/index.js';
import { EnvSet, getTenantEndpoint } from '#src/env-set/index.js';
import RequestError from '#src/errors/RequestError/index.js';
import koaGuard from '#src/middleware/koa-guard.js';
@ -10,8 +15,10 @@ import {
generateAutoSubmitForm,
createSamlResponse,
handleOidcCallbackAndGetUserInfo,
setupSamlProviders,
getSamlIdpAndSp,
getSignInUrl,
buildSamlAppCallbackUrl,
validateSamlApplicationDetails,
} from './utils.js';
const samlApplicationSignInCallbackQueryParametersGuard = z.union([
@ -30,7 +37,13 @@ export default function samlApplicationAnonymousRoutes<T extends AnonymousRouter
const {
samlApplications: { getSamlIdPMetadataByApplicationId },
} = libraries;
const { applications, samlApplicationSecrets, samlApplicationConfigs } = queries;
const {
applications,
samlApplicationSecrets,
samlApplicationConfigs,
samlApplications: { getSamlApplicationDetailsById },
samlApplicationSessions: { insertSession },
} = queries;
router.get(
'/saml-applications/:id/metadata',
@ -101,7 +114,7 @@ export default function samlApplicationAnonymousRoutes<T extends AnonymousRouter
// TODO: we will refactor the following code later, to reduce the DB query connections.
// Get SAML configuration
const { metadata } = await getSamlIdPMetadataByApplicationId(id);
const { privateKey } =
const { privateKey, certificate } =
await samlApplicationSecrets.findActiveSamlApplicationSecretByApplicationId(id);
const { entityId, acsUrl } =
await samlApplicationConfigs.findSamlApplicationConfigByApplicationId(id);
@ -110,7 +123,10 @@ export default function samlApplicationAnonymousRoutes<T extends AnonymousRouter
assertThat(acsUrl, 'application.saml.acs_url_required');
// Setup SAML providers and create response
const { idp, sp } = setupSamlProviders(metadata, privateKey, entityId, acsUrl);
const { idp, sp } = getSamlIdpAndSp({
idp: { metadata, privateKey, certificate },
sp: { entityId, acsUrl },
});
const { context, entityEndpoint } = await createSamlResponse(idp, sp, userInfo);
// Return auto-submit form
@ -118,4 +134,209 @@ export default function samlApplicationAnonymousRoutes<T extends AnonymousRouter
return next();
}
);
// Redirect binding SAML authentication request endpoint
router.get(
'/saml/:id/authn',
koaGuard({
params: z.object({ id: z.string() }),
query: z
.object({
SAMLRequest: z.string().min(1),
Signature: z.string().optional(),
SigAlg: z.string().optional(),
RelayState: z.string().optional(),
})
.catchall(z.string()),
status: [200, 302, 400, 404],
}),
async (ctx, next) => {
const {
params: { id },
query: { Signature, RelayState, ...rest },
} = ctx.guard;
const [{ metadata }, details] = await Promise.all([
getSamlIdPMetadataByApplicationId(id),
getSamlApplicationDetailsById(id),
]);
const { entityId, acsUrl, redirectUri, certificate, privateKey } =
validateSamlApplicationDetails(details);
const { idp, sp } = getSamlIdpAndSp({
idp: { metadata, certificate, privateKey },
sp: { entityId, acsUrl },
});
const octetString = Object.keys(ctx.request.query)
// eslint-disable-next-line no-restricted-syntax
.map((key) => key + '=' + encodeURIComponent(ctx.request.query[key] as string))
.join('&');
const { SAMLRequest, SigAlg } = rest;
// Parse login request
try {
const loginRequestResult = await idp.parseLoginRequest(sp, 'redirect', {
query: removeUndefinedKeys({
SAMLRequest,
Signature,
SigAlg,
}),
octetString,
});
const extractResult = authRequestInfoGuard.safeParse(loginRequestResult.extract);
if (!extractResult.success) {
throw new RequestError({
code: 'application.saml.invalid_saml_request',
error: extractResult.error.flatten(),
});
}
assertThat(
extractResult.data.issuer === entityId,
'application.saml.auth_request_issuer_not_match'
);
const state = generateStandardId(32);
const signInUrl = await getSignInUrl({
issuer: envSet.oidc.issuer,
applicationId: id,
redirectUri,
state,
});
const currentDate = new Date();
const expiresAt = addMinutes(currentDate, 60); // Lifetime of the session is 60 minutes.
const createSession = {
id: generateStandardId(32),
applicationId: id,
oidcState: state,
samlRequestId: extractResult.data.request.id,
authRequestInfo: extractResult.data,
// Expire the session in 60 minutes.
expiresAt: expiresAt.getTime(),
...cond(RelayState && { relayState: RelayState }),
};
const insertSamlAppSession = await insertSession(createSession);
// Set the session ID to cookie for later use.
ctx.cookies.set(spInitiatedSamlSsoSessionCookieName, insertSamlAppSession.id, {
httpOnly: true,
sameSite: 'strict',
expires: expiresAt,
overwrite: true,
});
ctx.redirect(signInUrl.toString());
} catch (error: unknown) {
if (error instanceof RequestError) {
throw error;
}
throw new RequestError({
code: 'application.saml.invalid_saml_request',
});
}
return next();
}
);
// Post binding SAML authentication request endpoint
router.post(
'/saml/:id/authn',
koaGuard({
params: z.object({ id: z.string() }),
body: z.object({
SAMLRequest: z.string().min(1),
RelayState: z.string().optional(),
}),
status: [200, 302, 400, 404],
}),
async (ctx, next) => {
const {
params: { id },
body: { SAMLRequest, RelayState },
} = ctx.guard;
const [{ metadata }, details] = await Promise.all([
getSamlIdPMetadataByApplicationId(id),
getSamlApplicationDetailsById(id),
]);
const { acsUrl, entityId, redirectUri, privateKey, certificate } =
validateSamlApplicationDetails(details);
const { idp, sp } = getSamlIdpAndSp({
idp: { metadata, privateKey, certificate },
sp: { entityId, acsUrl },
});
// Parse login request
try {
const loginRequestResult = await idp.parseLoginRequest(sp, 'post', {
body: {
SAMLRequest,
},
});
const extractResult = authRequestInfoGuard.safeParse(loginRequestResult.extract);
if (!extractResult.success) {
throw new RequestError({
code: 'application.saml.invalid_saml_request',
error: extractResult.error.flatten(),
});
}
assertThat(
extractResult.data.issuer === entityId,
'application.saml.auth_request_issuer_not_match'
);
const state = generateStandardShortId();
const signInUrl = await getSignInUrl({
issuer: envSet.oidc.issuer,
applicationId: id,
redirectUri,
state,
});
const currentDate = new Date();
const expiresAt = addMinutes(currentDate, 60); // Lifetime of the session is 60 minutes.
const insertSamlAppSession = await insertSession({
id: generateStandardId(),
applicationId: id,
oidcState: state,
samlRequestId: extractResult.data.request.id,
authRequestInfo: extractResult.data,
// Expire the session in 60 minutes.
expiresAt: expiresAt.getTime(),
...cond(RelayState && { relayState: RelayState }),
});
// Set the session ID to cookie for later use.
ctx.cookies.set(spInitiatedSamlSsoSessionCookieName, insertSamlAppSession.id, {
httpOnly: true,
sameSite: 'strict',
expires: expiresAt,
overwrite: true,
});
ctx.redirect(signInUrl.toString());
} catch (error: unknown) {
if (error instanceof RequestError) {
throw error;
}
throw new RequestError({
code: 'application.saml.invalid_saml_request',
});
}
return next();
}
);
}

View file

@ -1,12 +1,7 @@
import nock from 'nock';
import type { IdentityProviderInstance, ServiceProviderInstance } from 'samlify';
import {
createSamlTemplateCallback,
exchangeAuthorizationCode,
getUserInfo,
setupSamlProviders,
} from './utils.js';
import { createSamlTemplateCallback, exchangeAuthorizationCode, getUserInfo } from './utils.js';
const { jest } = import.meta;
@ -158,21 +153,3 @@ describe('getUserInfo', () => {
});
});
});
describe('setupSamlProviders', () => {
it('should setup SAML providers with correct configuration', () => {
const mockMetadata = '<EntityDescriptor>...</EntityDescriptor>';
const mockPrivateKey = '-----BEGIN PRIVATE KEY-----...';
const mockEntityId = 'https://sp.example.com';
const mockAcsUrl = {
binding: 'urn:oasis:names:tc:SAML:2.0:bindings:HTTP-POST',
url: 'https://sp.example.com/acs',
};
const { idp, sp } = setupSamlProviders(mockMetadata, mockPrivateKey, mockEntityId, mockAcsUrl);
expect(idp).toBeDefined();
expect(sp).toBeDefined();
expect(sp.entityMeta.getEntityID()).toBe(mockEntityId);
});
});

View file

@ -1,7 +1,12 @@
/* eslint-disable max-lines */
// TODO: refactor this file to reduce LOC
import { parseJson } from '@logto/connector-kit';
import { Prompt, QueryKey, ReservedScope, UserScope } from '@logto/js';
import { type SamlAcsUrl } from '@logto/schemas';
import { generateStandardId } from '@logto/shared';
import { tryThat, appendPath } from '@silverhand/essentials';
import camelcaseKeys from 'camelcase-keys';
import { tryThat, appendPath, deduplicate } from '@silverhand/essentials';
import camelcaseKeys, { type CamelCaseKeys } from 'camelcase-keys';
import { XMLValidator } from 'fast-xml-parser';
import saml from 'samlify';
import { ZodError, z } from 'zod';
@ -11,8 +16,11 @@ import {
getRawUserInfoResponse,
handleTokenExchange,
} from '#src/sso/OidcConnector/utils.js';
import { idTokenProfileStandardClaimsGuard } from '#src/sso/types/oidc.js';
import { type IdTokenProfileStandardClaims } from '#src/sso/types/oidc.js';
import {
idTokenProfileStandardClaimsGuard,
type OidcConfigResponse,
type IdTokenProfileStandardClaims,
} from '#src/sso/types/oidc.js';
import assertThat from '#src/utils/assert-that.js';
import {
@ -20,6 +28,7 @@ import {
samlAttributeNameFormatBasic,
samlValueXmlnsXsi,
} from '../libraries/consts.js';
import { type SamlApplicationDetails } from '../queries/index.js';
/**
* Determines the SAML NameID format and value based on the user's claims and IdP's NameID format.
@ -223,20 +232,7 @@ export const handleOidcCallbackAndGetUserInfo = async (
issuer: string
) => {
// Get OIDC configuration
const { tokenEndpoint, userinfoEndpoint } = await tryThat(
async () => fetchOidcConfigRaw(issuer),
(error) => {
if (error instanceof ZodError) {
throw new RequestError({
code: 'oidc.invalid_request',
message: error.message,
error: error.flatten(),
});
}
throw error;
}
);
const { tokenEndpoint, userinfoEndpoint } = await getOidcConfig(issuer);
// Exchange authorization code for tokens
const { accessToken } = await exchangeAuthorizationCode(tokenEndpoint, {
@ -252,12 +248,102 @@ export const handleOidcCallbackAndGetUserInfo = async (
return getUserInfo(accessToken, userinfoEndpoint);
};
export const setupSamlProviders = (
metadata: string,
privateKey: string,
entityId: string,
acsUrl: { binding: string; url: string }
) => {
const getOidcConfig = async (issuer: string): Promise<CamelCaseKeys<OidcConfigResponse>> => {
const oidcConfig = await tryThat(
async () => fetchOidcConfigRaw(issuer),
(error) => {
if (error instanceof ZodError) {
throw new RequestError({
code: 'oidc.invalid_request',
message: error.message,
error: error.flatten(),
});
}
throw error;
}
);
return oidcConfig;
};
export const getSignInUrl = async ({
issuer,
applicationId,
redirectUri,
scope,
state,
}: {
issuer: string;
applicationId: string;
redirectUri: string;
scope?: string;
state?: string;
}) => {
const { authorizationEndpoint } = await getOidcConfig(issuer);
const queryParameters = new URLSearchParams({
[QueryKey.ClientId]: applicationId,
[QueryKey.RedirectUri]: redirectUri,
[QueryKey.ResponseType]: 'code',
[QueryKey.Prompt]: Prompt.Login,
});
// TODO: get value of `scope` parameters according to setup in attribute mapping.
queryParameters.append(
QueryKey.Scope,
// For security reasons, DO NOT include the offline_access scope by default.
deduplicate([
ReservedScope.OpenId,
UserScope.Profile,
UserScope.Roles,
UserScope.Organizations,
UserScope.OrganizationRoles,
UserScope.CustomData,
UserScope.Identities,
...(scope?.split(' ') ?? []),
]).join(' ')
);
if (state) {
queryParameters.append(QueryKey.State, state);
}
return new URL(`${authorizationEndpoint}?${queryParameters.toString()}`);
};
export const validateSamlApplicationDetails = (details: SamlApplicationDetails) => {
const {
entityId,
acsUrl,
oidcClientMetadata: { redirectUris },
privateKey,
certificate,
} = details;
assertThat(acsUrl, 'application.saml.acs_url_required');
assertThat(entityId, 'application.saml.entity_id_required');
assertThat(redirectUris[0], 'oidc.invalid_redirect_uri');
assertThat(privateKey, 'application.saml.private_key_required');
assertThat(certificate, 'application.saml.certificate_required');
return {
entityId,
acsUrl,
redirectUri: redirectUris[0],
privateKey,
certificate,
};
};
export const getSamlIdpAndSp = ({
idp: { metadata, privateKey, certificate },
sp: { entityId, acsUrl },
}: {
idp: { metadata: string; privateKey: string; certificate: string };
sp: { entityId: string; acsUrl: SamlAcsUrl };
}): { idp: saml.IdentityProviderInstance; sp: saml.ServiceProviderInstance } => {
// eslint-disable-next-line new-cap
const idp = saml.IdentityProvider({
metadata,
@ -295,6 +381,24 @@ export const setupSamlProviders = (
Location: acsUrl.url,
},
],
signingCert: certificate,
authnRequestsSigned: idp.entityMeta.isWantAuthnRequestsSigned(),
allowCreate: false,
});
// Used to check whether xml content is valid in format.
saml.setSchemaValidator({
validate: async (xmlContent: string) => {
try {
XMLValidator.validate(xmlContent, {
allowBooleanAttributes: true,
});
return true;
} catch {
return false;
}
},
});
return { idp, sp };
@ -302,3 +406,4 @@ export const setupSamlProviders = (
export const buildSamlAppCallbackUrl = (baseUrl: URL, samlApplicationId: string) =>
appendPath(baseUrl, `api/saml-applications/${samlApplicationId}/callback`).toString();
/* eslint-enable max-lines */

View file

@ -29,6 +29,7 @@ import { createUserQueries } from '#src/queries/user.js';
import { createUsersRolesQueries } from '#src/queries/users-roles.js';
import { createVerificationStatusQueries } from '#src/queries/verification-status.js';
import { createSamlApplicationConfigQueries } from '#src/saml-applications/queries/configs.js';
import { createSamlApplicationQueries } from '#src/saml-applications/queries/index.js';
import { createSamlApplicationSecretsQueries } from '#src/saml-applications/queries/secrets.js';
import { createSamlApplicationSessionQueries } from '#src/saml-applications/queries/sessions.js';
@ -66,6 +67,7 @@ export default class Queries {
samlApplicationSecrets = createSamlApplicationSecretsQueries(this.pool);
samlApplicationConfigs = createSamlApplicationConfigQueries(this.pool);
samlApplicationSessions = createSamlApplicationSessionQueries(this.pool);
samlApplications = createSamlApplicationQueries(this.pool);
personalAccessTokens = new PersonalAccessTokensQueries(this.pool);
verificationRecords = new VerificationRecordQueries(this.pool);
accountCenters = new AccountCenterQueries(this.pool);

View file

@ -28,8 +28,13 @@ const application = {
can_not_delete_active_secret: 'Can not delete the active secret.',
no_active_secret: 'No active secret found.',
entity_id_required: 'Entity ID is required to generate metadata.',
acs_url_required: 'Assertion consumer service URL is required to generate metadata.',
invalid_certificate_pem_format: 'Invalid PEM certificate format',
acs_url_required: 'Assertion Consumer Service URL is required.',
private_key_required: 'Private key is required.',
certificate_required: 'Certificate is required.',
invalid_saml_request: 'Invalid SAML authentication request.',
auth_request_issuer_not_match:
'The issuer of the SAML authentication request mismatch with service provider entity ID.',
},
};

View file

@ -6,6 +6,7 @@ const oidc = {
insufficient_scope: 'Token missing scope `{{scope}}`.',
invalid_request: 'Request is invalid.',
invalid_grant: 'Grant request is invalid.',
invalid_issuer: 'Invalid issuer.',
invalid_redirect_uri:
"`redirect_uri` did not match any of the client's registered `redirect_uris`.",
access_denied: 'Access denied.',