0
Fork 0
mirror of https://github.com/logto-io/logto.git synced 2025-04-14 23:11:31 -05:00

feat(core): add saml sso class

This commit is contained in:
Darcy Ye 2023-11-02 11:17:34 +08:00
parent e515c04d44
commit b5a9633f03
No known key found for this signature in database
GPG key ID: B46F4C07EDEFC610
11 changed files with 547 additions and 1 deletions

View file

@ -54,6 +54,7 @@
"deepmerge": "^4.2.2",
"dotenv": "^16.0.0",
"etag": "^1.8.1",
"fast-xml-parser": "^4.2.5",
"find-up": "^6.3.0",
"got": "^13.0.0",
"hash-wasm": "^4.9.0",
@ -80,6 +81,7 @@
"redis": "^4.6.5",
"roarr": "^7.11.0",
"semver": "^7.3.8",
"samlify": "2.8.10",
"slonik": "^30.0.0",
"slonik-interceptor-preset": "^1.2.10",
"slonik-sql-tag-raw": "^1.1.4",

View file

@ -0,0 +1,174 @@
import {
ConnectorError,
ConnectorErrorCodes,
type GetSession,
type SetSession,
} from '@logto/connector-kit';
import { assert, appendPath } from '@silverhand/essentials';
import * as saml from 'samlify';
import { z } from 'zod';
import { EnvSet, getTenantEndpoint } from '#src/env-set/index.js';
import {
type BaseSamlConfig,
type BaseSamlConnectorConfig,
attributeMappingPostProcessor,
} from '../types/saml.js';
import {
fetchSamlConfig,
getRawSamlConfig,
getUserInfoFromRawUserProfile,
samlAssertionHandler,
} from './utils.js';
class SamlConnector {
private readonly _acsUrl: string;
constructor(
private readonly config: BaseSamlConnectorConfig,
tenantId: string,
ssoConnectorId: string
) {
this._acsUrl = appendPath(
getTenantEndpoint(tenantId, EnvSet.values),
`api/authn/saml/sso/${ssoConnectorId}`
).toString();
}
get acsUrl() {
return this._acsUrl;
}
/* Fetch SAML config from the metadata XML file or metadata URL. Throws error if config is invalid. */
getSamlConfig = async (): Promise<BaseSamlConfig> => {
const samlConfig = await fetchSamlConfig(this.config);
return {
...samlConfig,
...this.config,
};
};
getIdpMetadata = async () => {
return getRawSamlConfig(this.config);
};
getAuthorizationUrl = async (
{
state,
redirectUri,
jti,
}: {
state: string;
redirectUri: string;
jti: string;
},
setSession: SetSession
) => {
const {
entityId: entityID,
x509Certificate,
nameIdFormat,
signingAlgorithm,
} = await this.getSamlConfig();
assert(
setSession,
new ConnectorError(ConnectorErrorCodes.NotImplemented, {
message: 'Function `setSession()` is not implemented.',
})
);
const storage = { state, redirectUri, jti };
await setSession(storage);
try {
const idpMetadataXml = await getRawSamlConfig(this.config);
// eslint-disable-next-line new-cap
const identityProvider = saml.IdentityProvider({
wantAuthnRequestsSigned: true, // Sign auth request by default
metadata: idpMetadataXml,
});
// eslint-disable-next-line new-cap
const serviceProvider = saml.ServiceProvider({
entityID,
relayState: jti,
nameIDFormat: nameIdFormat,
signingCert: x509Certificate,
authnRequestsSigned: true, // Sign auth request by default
requestSignatureAlgorithm: signingAlgorithm,
assertionConsumerService: [
{
Location: this._acsUrl,
Binding: saml.Constants.BindingNamespace.Post,
},
],
});
const loginRequest = serviceProvider.createLoginRequest(identityProvider, 'redirect');
return loginRequest.context;
} catch (error: unknown) {
throw new ConnectorError(ConnectorErrorCodes.General, error);
}
};
getUserInfo = async (_data: unknown, getSession: GetSession) => {
const parsedConfig = await this.getSamlConfig();
const { attributeMapping } = parsedConfig;
const profileMap = attributeMappingPostProcessor(attributeMapping);
assert(
getSession,
new ConnectorError(ConnectorErrorCodes.NotImplemented, {
message: 'Function `getSession()` is not implemented.',
})
);
const { extractedRawProfile } = await getSession();
const extractedRawProfileGuard = z.record(z.string().or(z.array(z.string())));
const rawProfileParseResult = extractedRawProfileGuard.safeParse(extractedRawProfile);
if (!rawProfileParseResult.success) {
throw new ConnectorError(ConnectorErrorCodes.InvalidResponse, rawProfileParseResult.error);
}
const rawUserProfile = rawProfileParseResult.data;
return getUserInfoFromRawUserProfile(rawUserProfile, profileMap);
};
validateSamlAssertion = async (
assertion: Record<string, unknown>,
getSession: GetSession,
setSession: SetSession
): Promise<string> => {
const parsedConfig = await this.getSamlConfig();
const idpMetadataXml = await this.getIdpMetadata();
const connectorSession = await getSession();
const { redirectUri, state } = connectorSession;
await samlAssertionHandler(assertion, { ...parsedConfig, idpMetadataXml }, setSession);
assert(
state,
new ConnectorError(ConnectorErrorCodes.General, {
message: 'Can not find `state` from connector session.',
})
);
assert(
redirectUri,
new ConnectorError(ConnectorErrorCodes.General, {
message: 'Can not find `redirectUri` from connector session.',
})
);
const queryParameters = new URLSearchParams({ state });
return `${redirectUri}?${queryParameters.toString()}`;
};
}
export default SamlConnector;

View file

@ -0,0 +1,168 @@
import {
ConnectorError,
ConnectorErrorCodes,
socialUserInfoGuard,
type SetSession,
} from '@logto/connector-kit';
import { XMLValidator } from 'fast-xml-parser';
import { got } from 'got';
import * as saml from 'samlify';
import { z } from 'zod';
import {
samlMetadataGuard,
type SamlMetadata,
type BaseSamlConnectorConfig,
type ProfileMap,
MetadataType,
type BaseSamlConfig,
} from '../types/saml.js';
type ESamlHttpRequest = Parameters<saml.ServiceProviderInstance['parseLoginResponse']>[2];
const xmlValidator = (xml: string) => {
try {
XMLValidator.validate(xml, {
allowBooleanAttributes: true,
});
} catch (error: unknown) {
throw new ConnectorError(ConnectorErrorCodes.InvalidConfig, error);
}
};
const parseXmlMetadata = (xml: string): SamlMetadata => {
xmlValidator(xml);
// eslint-disable-next-line new-cap
const idP = saml.IdentityProvider({ metadata: xml });
const rawSingleSignOnService = idP.entityMeta.getSingleSignOnService(
saml.Constants.namespace.binding.redirect
);
// eslint-disable-next-line @typescript-eslint/no-unsafe-assignment
const singleSignOnService =
typeof rawSingleSignOnService === 'string'
? rawSingleSignOnService
: Object.entries(rawSingleSignOnService).find(
([key, value]) => key === saml.Constants.namespace.binding.redirect
)?.[1];
const rawSamlMetadata = {
entityId: idP.entityMeta.getEntityID(),
/**
* See implementation in `samlify` {@link https://github.com/tngan/samlify/blob/55f845da60b18d40668885c7f7e71ed0967ef67f/src/entity.ts#L88}.
*/
nameIdFormat: idP.entitySetting.nameIDFormat,
// eslint-disable-next-line @typescript-eslint/no-unsafe-assignment
signInEndpoint: singleSignOnService,
signingAlgorithm: idP.entitySetting.requestSignatureAlgorithm,
// The type inference of the return type of `getX509Certificate` is any, will be guarded by later zod parser if it is not string-typed.
// eslint-disable-next-line @typescript-eslint/no-unsafe-assignment
x509Certificate: idP.entityMeta.getX509Certificate(saml.Constants.wording.certUse.signing),
};
// The return type of `samlify`
const result = samlMetadataGuard.safeParse(rawSamlMetadata);
if (!result.success) {
throw new ConnectorError(ConnectorErrorCodes.InvalidConfig, result.error);
}
return result.data;
};
export const getRawSamlConfig = async (config: BaseSamlConnectorConfig): Promise<string> => {
if (config.metadataType === MetadataType.URL) {
const { body } = await got.get(config.metadataUrl);
const result = z.string().safeParse(body);
if (!result.success) {
throw new ConnectorError(ConnectorErrorCodes.InvalidConfig, result.error);
}
return result.data;
}
return config.metadataXml;
};
export const fetchSamlConfig = async (config: BaseSamlConnectorConfig) => {
const rawMetadata = await getRawSamlConfig(config);
return parseXmlMetadata(rawMetadata);
};
export const getUserInfoFromRawUserProfile = (
rawUserProfile: Record<string, unknown>,
keyMapping: ProfileMap
) => {
const keyMap = new Map(
Object.entries(keyMapping).map(([destination, source]) => [source, destination])
);
// eslint-disable-next-line @typescript-eslint/no-unsafe-assignment
const mappedUserProfile = Object.fromEntries(
Object.entries(rawUserProfile)
.filter(([key, value]) => keyMap.get(key) && value)
.map(([key, value]) => [keyMap.get(key), value])
);
const result = socialUserInfoGuard.safeParse(mappedUserProfile);
if (!result.success) {
throw new ConnectorError(ConnectorErrorCodes.InvalidResponse, result.error);
}
return result.data;
};
export const samlAssertionHandler = async (
request: ESamlHttpRequest,
options: BaseSamlConfig & { idpMetadataXml: string },
setSession: SetSession
): Promise<void | Record<string, unknown>> => {
const { entityId: entityID, x509Certificate, idpMetadataXml } = options;
// eslint-disable-next-line new-cap
const identityProvider = saml.IdentityProvider({
metadata: idpMetadataXml,
});
// eslint-disable-next-line new-cap
const serviceProvider = saml.ServiceProvider({
entityID,
signingCert: x509Certificate,
});
// Used to check whether xml content is valid in format.
saml.setSchemaValidator({
validate: async (xmlContent: string) => {
try {
XMLValidator.validate(xmlContent, {
allowBooleanAttributes: true,
});
return true;
} catch {
return false;
}
},
});
try {
const assertionResult = await serviceProvider.parseLoginResponse(
identityProvider,
'post',
request
);
await setSession({
extractedRawProfile: {
...(Boolean(assertionResult.extract.nameID) && {
// eslint-disable-next-line @typescript-eslint/no-unsafe-assignment
id: assertionResult.extract.nameID,
}),
...assertionResult.extract.attributes,
},
});
} catch (error: unknown) {
throw new ConnectorError(ConnectorErrorCodes.General, String(error));
}
};

View file

@ -0,0 +1,32 @@
import { ConnectorError, ConnectorErrorCodes } from '@logto/connector-kit';
import { mockSsoConnector as _mockSsoConnector } from '#src/__mocks__/sso.js';
import { SsoProviderName } from '../types/index.js';
import { samlSsoConnectorFactory } from './index.js';
const mockSsoConnector = { ..._mockSsoConnector, providerName: SsoProviderName.SAML };
describe('SamlSsoConnector', () => {
it('SamlSsoConnector should contains static properties', () => {
expect(samlSsoConnectorFactory.providerName).toEqual(SsoProviderName.SAML);
expect(samlSsoConnectorFactory.configGuard).toBeDefined();
});
it('constructor should throw error if config is invalid', () => {
const result = samlSsoConnectorFactory.configGuard.safeParse(mockSsoConnector.config);
if (result.success) {
throw new Error('Invalid config');
}
const createSamlSsoConnector = () => {
return new samlSsoConnectorFactory.constructor(mockSsoConnector, 'http://localhost:3001/api');
};
expect(createSamlSsoConnector).toThrow(
new ConnectorError(ConnectorErrorCodes.InvalidConfig, result.error)
);
});
});

View file

@ -0,0 +1,38 @@
import { ConnectorError, ConnectorErrorCodes } from '@logto/connector-kit';
import { type SsoConnector } from '@logto/schemas';
import SamlConnector from '../SamlConnector/index.js';
import { type SingleSignOnFactory } from '../index.js';
import { type SingleSignOn, SsoProviderName } from '../types/index.js';
import { baseSamlConnectorConfigGuard } from '../types/saml.js';
export class SamlSsoConnector extends SamlConnector implements SingleSignOn {
constructor(
private readonly _data: SsoConnector,
tenantId: string
) {
const parseConfigResult = baseSamlConnectorConfigGuard.safeParse(_data.config);
if (!parseConfigResult.success) {
throw new ConnectorError(ConnectorErrorCodes.InvalidConfig, parseConfigResult.error);
}
super(parseConfigResult.data, tenantId, _data.id);
}
get data() {
return this._data;
}
getConfig = async () => this.getSamlConfig();
}
export const samlSsoConnectorFactory: SingleSignOnFactory<SsoProviderName.SAML> = {
providerName: SsoProviderName.SAML,
logo: 'saml.svg',
description: {
en: ' This connector is used to connect to SAML single sign-on identity provider.',
},
configGuard: baseSamlConnectorConfigGuard,
constructor: SamlSsoConnector,
};

View file

@ -1,15 +1,21 @@
import { type I18nPhrases } from '@logto/connector-kit';
import { oidcSsoConnectorFactory, type OidcSsoConnector } from './OidcSsoConnector/index.js';
import { type SamlSsoConnector, samlSsoConnectorFactory } from './SamlSsoConnector/index.js';
import { SsoProviderName } from './types/index.js';
import { type basicOidcConnectorConfigGuard } from './types/oidc.js';
import { type baseSamlConnectorConfigGuard } from './types/saml.js';
type SingleSignOnConstructor<T extends SsoProviderName> = T extends SsoProviderName.OIDC
? typeof OidcSsoConnector
: T extends SsoProviderName.SAML
? typeof SamlSsoConnector
: never;
type SingleSignOnConnectorConfig<T extends SsoProviderName> = T extends SsoProviderName.OIDC
? typeof basicOidcConnectorConfigGuard
: T extends SsoProviderName.SAML
? typeof baseSamlConnectorConfigGuard
: never;
export type SingleSignOnFactory<T extends SsoProviderName> = {
@ -24,6 +30,10 @@ export const ssoConnectorFactories: {
[key in SsoProviderName]: SingleSignOnFactory<key>;
} = {
[SsoProviderName.OIDC]: oidcSsoConnectorFactory,
[SsoProviderName.SAML]: samlSsoConnectorFactory,
};
export const standardSsoConnectorProviders = Object.freeze([SsoProviderName.OIDC]);
export const standardSsoConnectorProviders = Object.freeze([
SsoProviderName.OIDC,
SsoProviderName.SAML,
]);

View file

@ -14,6 +14,7 @@ export abstract class SingleSignOn {
export enum SsoProviderName {
OIDC = 'OIDC',
SAML = 'SAML',
}
export type SupportedSsoConnector = Omit<SsoConnector, 'providerName'> & {

View file

@ -0,0 +1,24 @@
import { attributeMappingPostProcessor } from './saml.js';
const expectedDefaultAttributeMapping = {
id: 'id',
email: 'email',
phone: 'phone',
name: 'name',
avatar: 'avatar',
};
describe('attributeMappingPostProcessor', () => {
it('should fallback to `expectedDefaultAttributeMapping` if no other attribute mapping is specified', () => {
expect(attributeMappingPostProcessor()).toEqual(expectedDefaultAttributeMapping);
expect(attributeMappingPostProcessor({})).toEqual(expectedDefaultAttributeMapping);
});
it('should overwrite specified attributes of `expectedDefaultAttributeMapping`', () => {
expect(attributeMappingPostProcessor({ id: 'sub', avatar: 'picture' })).toEqual({
...expectedDefaultAttributeMapping,
id: 'sub',
avatar: 'picture',
});
});
});

View file

@ -0,0 +1,83 @@
import { socialUserInfoGuard, socialUserInfoKeys } from '@logto/connector-kit';
import { conditional } from '@silverhand/essentials';
import cleanDeep from 'clean-deep';
import { z } from 'zod';
export enum MetadataType {
XML = 'XML',
URL = 'URL',
}
export type ProfileMap = Required<z.infer<typeof socialUserInfoGuard>>;
const attributeMapGuard = socialUserInfoGuard.partial();
type AttributeMap = z.infer<typeof attributeMapGuard>;
/**
* Get the full attribute mapping using specified attribute mappings with default fallback values.
*
* @param attributeMapping Specified attribute mapping stored in database
* @returns Full attribute mapping with default fallback values
*/
export const attributeMappingPostProcessor = (attributeMapping?: AttributeMap): ProfileMap => {
return {
// eslint-disable-next-line no-restricted-syntax
...(Object.fromEntries(socialUserInfoKeys.map((key) => [key, key])) as ProfileMap),
...conditional(attributeMapping && cleanDeep(attributeMapping)),
};
};
const basicSamlCommonFields = {
attributeMapping: attributeMapGuard.optional(),
signInEndpoint: z.string().optional(),
entityId: z.string().optional(),
x509Certificate: z.string().optional(),
};
export const baseSamlConnectorConfigGuard = z.discriminatedUnion('metadataType', [
z.object({
metadataType: z.literal(MetadataType.URL),
metadataUrl: z.string().url(),
...basicSamlCommonFields,
}),
z.object({
metadataType: z.literal(MetadataType.XML),
metadataXml: z.string(),
...basicSamlCommonFields,
}),
]);
export type BaseSamlConnectorConfig = z.infer<typeof baseSamlConnectorConfigGuard>;
/**
* Zod discriminate union does not support its partial util method, we need to manually implement this.
* This is for guarding the config on creating.
*/
export const basicSamlConnectorConfigPartialGuard = z.discriminatedUnion('metadataType', [
z
.object({
metadataUrl: z.string().url(),
...basicSamlCommonFields,
})
.partial()
.merge(z.object({ metadataType: z.literal(MetadataType.URL) })),
z
.object({
metadataXml: z.string(),
...basicSamlCommonFields,
})
.partial()
.merge(z.object({ metadataType: z.literal(MetadataType.XML) })),
]);
export const samlMetadataGuard = z.object({
entityId: z.string(),
nameIdFormat: z.string().array().optional(),
signInEndpoint: z.string(),
signingAlgorithm: z.string(),
x509Certificate: z.string(),
});
export type SamlMetadata = z.infer<typeof samlMetadataGuard>;
export type BaseSamlConfig = BaseSamlConnectorConfig & SamlMetadata;

View file

@ -32,6 +32,14 @@ export const socialUserInfoGuard = z.object({
export type SocialUserInfo = z.infer<typeof socialUserInfoGuard>;
export const socialUserInfoKeys = Object.freeze([
'id',
'email',
'phone',
'name',
'avatar',
] satisfies Array<keyof SocialUserInfo>);
export type GetUserInfo = (
data: unknown,
getSession: GetSession

6
pnpm-lock.yaml generated
View file

@ -3208,6 +3208,9 @@ importers:
etag:
specifier: ^1.8.1
version: 1.8.1
fast-xml-parser:
specifier: ^4.2.5
version: 4.2.5
find-up:
specifier: ^6.3.0
version: 6.3.0
@ -3283,6 +3286,9 @@ importers:
roarr:
specifier: ^7.11.0
version: 7.11.0
samlify:
specifier: 2.8.10
version: 2.8.10
semver:
specifier: ^7.3.8
version: 7.3.8