0
Fork 0
mirror of https://github.com/logto-io/logto.git synced 2024-12-16 20:26:19 -05:00

refactor(core): refactor single sign-on session handle logic (#4871)

* refactor(core): refactor single sign-on session handle logic

refactor single sign-on session handle logic

* refactor(core): remove the OIDC/SAML instance assertion

remove the OIDC/SAML instance assertion

* chore(core): rename guard

rename guard
This commit is contained in:
simeng-li 2023-11-15 10:30:34 +08:00 committed by GitHub
parent 741de8c259
commit 83ba800d0a
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
11 changed files with 161 additions and 66 deletions

View file

@ -15,7 +15,7 @@ import type { WithInteractionDetailsContext } from './middleware/koa-interaction
import koaInteractionHooks from './middleware/koa-interaction-hooks.js'; import koaInteractionHooks from './middleware/koa-interaction-hooks.js';
import { getInteractionStorage, storeInteractionResult } from './utils/interaction.js'; import { getInteractionStorage, storeInteractionResult } from './utils/interaction.js';
import { import {
oidcAuthorizationUrlPayloadGuard, authorizationUrlPayloadGuard,
getSsoAuthorizationUrl, getSsoAuthorizationUrl,
getSsoAuthentication, getSsoAuthentication,
handleSsoAuthentication, handleSsoAuthentication,
@ -36,8 +36,7 @@ export default function singleSignOnRoutes<T extends IRouterParamContext>(
params: z.object({ params: z.object({
connectorId: z.string(), connectorId: z.string(),
}), }),
// Only required for OIDC body: authorizationUrlPayloadGuard,
body: oidcAuthorizationUrlPayloadGuard.optional(),
status: [200, 500, 404], status: [200, 500, 404],
response: z.object({ response: z.object({
redirectTo: z.string(), redirectTo: z.string(),

View file

@ -1,8 +1,15 @@
import { type IdentifierPayload } from '@logto/schemas'; import { type IdentifierPayload } from '@logto/schemas';
import { type Context } from 'koa';
import type Provider from 'oidc-provider';
import { z } from 'zod';
import { EnvSet } from '#src/env-set/index.js'; import { EnvSet } from '#src/env-set/index.js';
import RequestError from '#src/errors/RequestError/index.js'; import RequestError from '#src/errors/RequestError/index.js';
import { type SsoConnectorLibrary } from '#src/libraries/sso-connector.js'; import { type SsoConnectorLibrary } from '#src/libraries/sso-connector.js';
import {
type SingleSignOnConnectorSession,
singleSignOnConnectorSessionGuard,
} from '#src/sso/types/session.js';
import assertThat from '#src/utils/assert-that.js'; import assertThat from '#src/utils/assert-that.js';
// Guard the SSO only email identifier // Guard the SSO only email identifier
@ -41,3 +48,40 @@ export const verifySsoOnlyEmailIdentifier = async (
) )
); );
}; };
/**
* Get the single sign on session data from the oidc provider session storage.
*
* @param ctx
* @param provider
* @param connectorId
* @returns The single sign on session data
*
* @remark Forked from ./social-verification.ts.
* Use SingleSignOnSession guard instead of ConnectorSession guard.
*/
export const getSingleSignOnSessionResult = async (
ctx: Context,
provider: Provider
): Promise<SingleSignOnConnectorSession> => {
const { result } = await provider.interactionDetails(ctx.req, ctx.res);
const singleSignOnSessionResult = z
.object({
connectorSession: singleSignOnConnectorSessionGuard,
})
.safeParse(result);
assertThat(
result && singleSignOnSessionResult.success,
'session.connector_validation_session_not_found'
);
// Clear the session after the session data is retrieved
const { connectorSession, ...rest } = result;
await provider.interactionResult(ctx.req, ctx.res, {
...rest,
});
return singleSignOnSessionResult.data.connectorSession;
};

View file

@ -85,19 +85,17 @@ describe('Single sign on util methods tests', () => {
}); });
describe('getSsoAuthorizationUrl tests', () => { describe('getSsoAuthorizationUrl tests', () => {
it('should throw an error if the connector config is invalid', async () => { const payload = {
await expect(getSsoAuthorizationUrl(mockContext, tenant, mockSsoConnector)).rejects.toThrow( state: 'state',
// eslint-disable-next-line @typescript-eslint/no-unsafe-argument redirectUri: 'https://example.com',
expect.objectContaining({ status: 500, code: `connector.invalid_config` }) };
);
});
it('should throw an error if OIDC connector is used without a proper payload', async () => { it('should throw an error if the connector config is invalid', async () => {
await expect( await expect(
getSsoAuthorizationUrl(mockContext, tenant, wellConfiguredSsoConnector) getSsoAuthorizationUrl(mockContext, tenant, mockSsoConnector, payload)
).rejects.toThrow( ).rejects.toThrow(
// eslint-disable-next-line @typescript-eslint/no-unsafe-argument // eslint-disable-next-line @typescript-eslint/no-unsafe-argument
expect.objectContaining({ status: 400, code: 'session.insufficient_info' }) expect.objectContaining({ status: 500, code: `connector.invalid_config` })
); );
}); });
@ -105,10 +103,7 @@ describe('Single sign on util methods tests', () => {
getAuthorizationUrlMock.mockResolvedValueOnce('https://example.com'); getAuthorizationUrlMock.mockResolvedValueOnce('https://example.com');
await expect( await expect(
getSsoAuthorizationUrl(mockContext, tenant, wellConfiguredSsoConnector, { getSsoAuthorizationUrl(mockContext, tenant, wellConfiguredSsoConnector, payload)
state: 'state',
redirectUri: 'https://example.com',
})
).resolves.toBe('https://example.com'); ).resolves.toBe('https://example.com');
}); });
}); });

View file

@ -1,4 +1,4 @@
import { type ConnectorSession, ConnectorError, type SocialUserInfo } from '@logto/connector-kit'; import { ConnectorError, type SocialUserInfo } from '@logto/connector-kit';
import { validateRedirectUrl } from '@logto/core-kit'; import { validateRedirectUrl } from '@logto/core-kit';
import { InteractionEvent, type User, type UserSsoIdentity } from '@logto/schemas'; import { InteractionEvent, type User, type UserSsoIdentity } from '@logto/schemas';
import { generateStandardId } from '@logto/shared'; import { generateStandardId } from '@logto/shared';
@ -8,29 +8,32 @@ import { z } from 'zod';
import RequestError from '#src/errors/RequestError/index.js'; import RequestError from '#src/errors/RequestError/index.js';
import { type WithLogContext } from '#src/middleware/koa-audit-log.js'; import { type WithLogContext } from '#src/middleware/koa-audit-log.js';
import { type WithInteractionDetailsContext } from '#src/routes/interaction/middleware/koa-interaction-details.js'; import { type WithInteractionDetailsContext } from '#src/routes/interaction/middleware/koa-interaction-details.js';
import OidcConnector from '#src/sso/OidcConnector/index.js';
import { ssoConnectorFactories } from '#src/sso/index.js'; import { ssoConnectorFactories } from '#src/sso/index.js';
import { type SupportedSsoConnector } from '#src/sso/types/index.js'; import {
type SupportedSsoConnector,
type SingleSignOnConnectorSession,
} from '#src/sso/types/index.js';
import type Queries from '#src/tenants/Queries.js'; import type Queries from '#src/tenants/Queries.js';
import type TenantContext from '#src/tenants/TenantContext.js'; import type TenantContext from '#src/tenants/TenantContext.js';
import assertThat from '#src/utils/assert-that.js'; import assertThat from '#src/utils/assert-that.js';
import { storeInteractionResult } from './interaction.js'; import { storeInteractionResult } from './interaction.js';
import { assignConnectorSessionResult, getConnectorSessionResult } from './social-verification.js'; import { getSingleSignOnSessionResult } from './single-sign-on-guard.js';
import { assignConnectorSessionResult } from './social-verification.js';
export const oidcAuthorizationUrlPayloadGuard = z.object({ export const authorizationUrlPayloadGuard = z.object({
state: z.string().min(1), state: z.string().min(1),
redirectUri: z.string().refine((url) => validateRedirectUrl(url, 'web')), redirectUri: z.string().refine((url) => validateRedirectUrl(url, 'web')),
}); });
type OidcAuthorizationUrlPayload = z.infer<typeof oidcAuthorizationUrlPayloadGuard>; type AuthorizationUrlPayload = z.infer<typeof authorizationUrlPayloadGuard>;
// Get the authorization url for the SSO provider // Get the authorization url for the SSO provider
export const getSsoAuthorizationUrl = async ( export const getSsoAuthorizationUrl = async (
ctx: WithLogContext & WithInteractionDetailsContext, ctx: WithLogContext & WithInteractionDetailsContext,
{ provider, id: tenantId }: TenantContext, { provider, id: tenantId }: TenantContext,
connectorData: SupportedSsoConnector, connectorData: SupportedSsoConnector,
payload?: OidcAuthorizationUrlPayload payload: AuthorizationUrlPayload
): Promise<string> => { ): Promise<string> => {
const { id: connectorId, providerName } = connectorData; const { id: connectorId, providerName } = connectorData;
@ -52,21 +55,13 @@ export const getSsoAuthorizationUrl = async (
tenantId tenantId
); );
// OIDC connectors assertThat(payload, 'session.insufficient_info');
if (connectorInstance instanceof OidcConnector) {
// Only required for OIDC
assertThat(payload, 'session.insufficient_info');
// Will throw ConnectorError if failed to fetch the provider's config return await connectorInstance.getAuthorizationUrl(
return await connectorInstance.getAuthorizationUrl( { jti, ...payload, connectorId },
payload, async (connectorSession: SingleSignOnConnectorSession) =>
async (connectorSession: ConnectorSession) => assignConnectorSessionResult(ctx, provider, connectorSession)
assignConnectorSessionResult(ctx, provider, connectorSession) );
);
}
// SAML connectors
return await connectorInstance.getSingleSignOnUrl(jti);
} catch (error: unknown) { } catch (error: unknown) {
// Catch ConnectorError and re-throw as 500 RequestError // Catch ConnectorError and re-throw as 500 RequestError
if (error instanceof ConnectorError) { if (error instanceof ConnectorError) {
@ -104,7 +99,7 @@ export const getSsoAuthentication = async (
const issuer = await connectorInstance.getIssuer(); const issuer = await connectorInstance.getIssuer();
const userInfo = await connectorInstance.getUserInfo(data, async () => const userInfo = await connectorInstance.getUserInfo(data, async () =>
getConnectorSessionResult(ctx, provider) getSingleSignOnSessionResult(ctx, provider)
); );
const result = { const result = {

View file

@ -1,14 +1,10 @@
import { import { ConnectorError, ConnectorErrorCodes } from '@logto/connector-kit';
ConnectorError,
ConnectorErrorCodes,
type GetSession,
type SetSession,
} from '@logto/connector-kit';
import { generateStandardId } from '@logto/shared/universal'; import { generateStandardId } from '@logto/shared/universal';
import { assert, conditional } from '@silverhand/essentials'; import { assert, conditional } from '@silverhand/essentials';
import snakecaseKeys from 'snakecase-keys'; import snakecaseKeys from 'snakecase-keys';
import { type BaseOidcConfig, type BasicOidcConnectorConfig } from '../types/oidc.js'; import { type BaseOidcConfig, type BasicOidcConnectorConfig } from '../types/oidc.js';
import { type CreateSingleSignOnSession, type GetSingleSignOnSession } from '../types/session.js';
import { fetchOidcConfig, fetchToken, getIdTokenClaims } from './utils.js'; import { fetchOidcConfig, fetchToken, getIdTokenClaims } from './utils.js';
@ -51,11 +47,15 @@ class OidcConnector {
* @param oidcQueryParams The query params for the OIDC provider * @param oidcQueryParams The query params for the OIDC provider
* @param oidcQueryParams.state The state generated by Logto experience client * @param oidcQueryParams.state The state generated by Logto experience client
* @param oidcQueryParams.redirectUri The redirect uri for the OIDC provider * @param oidcQueryParams.redirectUri The redirect uri for the OIDC provider
* @param setSession Set the connector session data to the oidc provider session storage. @see @logto/connector-kit * @param setSession Set the connector session data to the oidc provider session storage.
*/ */
getAuthorizationUrl = async ( getAuthorizationUrl = async (
{ state, redirectUri }: { state: string; redirectUri: string }, {
setSession: SetSession state,
redirectUri,
connectorId,
}: { state: string; redirectUri: string; connectorId: string },
setSession: CreateSingleSignOnSession
) => { ) => {
assert( assert(
setSession, setSession,
@ -67,7 +67,7 @@ class OidcConnector {
const oidcConfig = await this.getOidcConfig(); const oidcConfig = await this.getOidcConfig();
const nonce = generateStandardId(); const nonce = generateStandardId();
await setSession({ nonce, redirectUri }); await setSession({ nonce, redirectUri, connectorId, state });
const queryParameters = new URLSearchParams({ const queryParameters = new URLSearchParams({
state, state,
@ -96,7 +96,7 @@ class OidcConnector {
* @remark Forked from @logto/oidc-connector * @remark Forked from @logto/oidc-connector
* *
*/ */
getUserInfo = async (data: unknown, getSession: GetSession) => { getUserInfo = async (data: unknown, getSession: GetSingleSignOnSession) => {
assert( assert(
getSession, getSession,
new ConnectorError(ConnectorErrorCodes.NotImplemented, { new ConnectorError(ConnectorErrorCodes.NotImplemented, {
@ -105,14 +105,7 @@ class OidcConnector {
); );
const oidcConfig = await this.getOidcConfig(); const oidcConfig = await this.getOidcConfig();
const { redirectUri, nonce } = await getSession(); const { nonce, redirectUri } = await getSession();
assert(
redirectUri,
new ConnectorError(ConnectorErrorCodes.General, {
message: "CAN NOT find 'redirectUri' from connector session.",
})
);
// Fetch token from the OIDC provider using authorization code // Fetch token from the OIDC provider using authorization code
const { idToken } = await fetchToken(oidcConfig, data, redirectUri); const { idToken } = await fetchToken(oidcConfig, data, redirectUri);

View file

@ -126,11 +126,10 @@ class SamlConnector {
/** /**
* Get the SSO URL. * Get the SSO URL.
* *
* @param jti The current session id. * @param relayState The relay state to be passed to the SAML identity provider. We use it to pass `jti` to find the connector session.
*
* @returns The SSO URL. * @returns The SSO URL.
*/ */
async getSingleSignOnUrl(jti: string) { async getSingleSignOnUrl(relayState: string) {
const { const {
entityId: entityID, entityId: entityID,
x509Certificate, x509Certificate,
@ -156,7 +155,7 @@ class SamlConnector {
// eslint-disable-next-line new-cap // eslint-disable-next-line new-cap
const serviceProvider = saml.ServiceProvider({ const serviceProvider = saml.ServiceProvider({
entityID, entityID,
relayState: jti, relayState,
nameIDFormat: nameIdFormat, nameIDFormat: nameIdFormat,
signingCert: x509Certificate, signingCert: x509Certificate,
authnRequestsSigned: true, // Sign auth request by default authnRequestsSigned: true, // Sign auth request by default

View file

@ -1,5 +1,5 @@
import * as validator from '@authenio/samlify-node-xmllint'; import * as validator from '@authenio/samlify-node-xmllint';
import { ConnectorError, ConnectorErrorCodes, socialUserInfoGuard } from '@logto/connector-kit'; import { ConnectorError, ConnectorErrorCodes } from '@logto/connector-kit';
import { type Optional, conditional } from '@silverhand/essentials'; import { type Optional, conditional } from '@silverhand/essentials';
import { got } from 'got'; import { got } from 'got';
import * as saml from 'samlify'; import * as saml from 'samlify';
@ -13,14 +13,12 @@ import {
defaultAttributeMapping, defaultAttributeMapping,
type CustomizableAttributeMap, type CustomizableAttributeMap,
type AttributeMap, type AttributeMap,
extendedSocialUserInfoGuard,
type ExtendedSocialUserInfo,
} from '../types/saml.js'; } from '../types/saml.js';
type ESamlHttpRequest = Parameters<saml.ServiceProviderInstance['parseLoginResponse']>[2]; type ESamlHttpRequest = Parameters<saml.ServiceProviderInstance['parseLoginResponse']>[2];
const extendedSocialUserInfoGuard = socialUserInfoGuard.catchall(z.unknown());
type ExtendedSocialUserInfo = z.infer<typeof extendedSocialUserInfoGuard>;
/** /**
* Parse XML-format raw SAML metadata and return the parsed SAML metadata. * Parse XML-format raw SAML metadata and return the parsed SAML metadata.
* *

View file

@ -7,6 +7,7 @@ import SamlConnector from '../SamlConnector/index.js';
import { type SingleSignOnFactory } from '../index.js'; import { type SingleSignOnFactory } from '../index.js';
import { type SingleSignOn } from '../types/index.js'; import { type SingleSignOn } from '../types/index.js';
import { samlConnectorConfigGuard } from '../types/saml.js'; import { samlConnectorConfigGuard } from '../types/saml.js';
import { type CreateSingleSignOnSession } from '../types/session.js';
/** /**
* SAML SSO connector * SAML SSO connector
@ -46,6 +47,31 @@ export class SamlSsoConnector extends SamlConnector implements SingleSignOn {
return this.getSamlConfig(); return this.getSamlConfig();
} }
/**
* Get SAML SSO URL.
* This URL will be used to redirect to the SAML identity provider.
*
* @param jti The unique identifier for the connector session.
* @param redirectUri The redirect uri for the identity provider.
* @param state The state generated by Logto experience client.
* @param setSession Set the connector session data to the oidc provider session storage. @see @logto/connector-kit
*/
async getAuthorizationUrl(
{
jti,
redirectUri,
state,
connectorId,
}: { jti: string; redirectUri: string; state: string; connectorId: string },
setSession: CreateSingleSignOnSession
) {
// We use jti as the value of the RelayState in the SAML request. So we can get it back from the SAML response and retrieve the connector session.
const singleSignOnUrl = await this.getSingleSignOnUrl(jti);
await setSession({ connectorId, redirectUri, state });
return singleSignOnUrl;
}
/** /**
* Get social user info. * Get social user info.
* *

View file

@ -1,5 +1,7 @@
import { type JsonObject, type SsoConnector } from '@logto/schemas'; import { type JsonObject, type SsoConnector } from '@logto/schemas';
export * from './session.js';
/** /**
* Single sign-on connector interface * Single sign-on connector interface
* @interface SingleSignOn * @interface SingleSignOn

View file

@ -40,3 +40,7 @@ export const samlMetadataGuard = z
export type SamlMetadata = z.infer<typeof samlMetadataGuard>; export type SamlMetadata = z.infer<typeof samlMetadataGuard>;
export type SamlConfig = SamlConnectorConfig & SamlMetadata; export type SamlConfig = SamlConnectorConfig & SamlMetadata;
// Saml assertion returned user attribute value
export const extendedSocialUserInfoGuard = socialUserInfoGuard.catchall(z.unknown());
export type ExtendedSocialUserInfo = z.infer<typeof extendedSocialUserInfoGuard>;

View file

@ -0,0 +1,40 @@
import { z } from 'zod';
import { extendedSocialUserInfoGuard } from './saml.js';
/**
* Single sign on connector session
*
* @property state The state generated by Logto experience client.
* @property redirectUri The redirect uri for the identity provider.
* @property nonce OIDC only properties, generated by OIDC connector factory, used to verify the identity provider response.
* @property userInfo The user info returned by the identity provider.
* SAML only properties, parsed from the SAML assertion.
* We store the assertion in the session storage after receiving it from the identity provider.
* So the client authentication handler can get it later.
* @property connectorId The connector id.
*
* @remark this is a forked version of @logto/connector-kit
* Simplified the type definition to only include the properties we need.
* Create additional type guard to validate the session data.
* @see @logto/connector-kit/types/social.ts
*/
export const singleSignOnConnectorSessionGuard = z.object({
state: z.string(),
redirectUri: z.string(),
connectorId: z.string(),
nonce: z.string().optional(),
userInfo: extendedSocialUserInfoGuard.optional(),
});
export type SingleSignOnConnectorSession = z.infer<typeof singleSignOnConnectorSessionGuard>;
export const samlConnectorAssertionSessionGuard = z.object({
state: z.string(),
redirectUri: z.string(),
connectorId: z.string(),
});
export type CreateSingleSignOnSession = (storage: SingleSignOnConnectorSession) => Promise<void>;
export type GetSingleSignOnSession = () => Promise<SingleSignOnConnectorSession>;