0
Fork 0
mirror of https://github.com/logto-io/logto.git synced 2025-02-17 22:04:19 -05:00

fix(core): apply custom domain on SAML SSO and app (#7022)

* fix: apply custom domain on SAML SSO and app

* chore: apply custom domain on SAML SSO guide

* chore: add changeset

* chore: rename input params
This commit is contained in:
Darcy Ye 2025-02-12 14:59:43 +08:00 committed by GitHub
parent bd18da4cfa
commit d44007faab
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
16 changed files with 132 additions and 75 deletions

View file

@ -0,0 +1,6 @@
---
"@logto/console": patch
"@logto/core": patch
---
apply custom domain to SAML SSO and SAML applications

View file

@ -1,9 +1,11 @@
import { conditionalString } from '@silverhand/essentials';
import { useContext, useMemo } from 'react';
import { z } from 'zod';
import { SsoConnectorContext } from '@/contexts/SsoConnectorContextProvider';
import CopyToClipboard from '@/ds-components/CopyToClipboard';
import FormField from '@/ds-components/FormField';
import useCustomDomain from '@/hooks/use-custom-domain';
import styles from './index.module.scss';
@ -20,6 +22,7 @@ const samlProviderConfigGuard = z.object({
function SsoSamlSpMetadata() {
const { ssoConnector } = useContext(SsoConnectorContext);
const { applyDomain: applyCustomDomain } = useCustomDomain();
const serviceProviderMetadata = useMemo(() => {
if (!ssoConnector) {
@ -49,7 +52,9 @@ function SsoSamlSpMetadata() {
<CopyToClipboard
displayType="block"
variant="border"
value={serviceProviderMetadata?.entityId ?? ''}
value={conditionalString(
serviceProviderMetadata?.entityId && applyCustomDomain(serviceProviderMetadata.entityId)
)}
/>
</FormField>
<FormField
@ -59,7 +64,10 @@ function SsoSamlSpMetadata() {
<CopyToClipboard
displayType="block"
variant="border"
value={serviceProviderMetadata?.assertionConsumerServiceUrl ?? ''}
value={conditionalString(
serviceProviderMetadata?.assertionConsumerServiceUrl &&
applyCustomDomain(serviceProviderMetadata.assertionConsumerServiceUrl)
)}
/>
</FormField>
</div>

View file

@ -43,6 +43,7 @@ export class EnvSet {
#pool: Optional<DatabasePool>;
#oidc: Optional<Awaited<ReturnType<typeof loadOidcValues>>>;
#endpoint: Optional<URL>;
constructor(
public readonly tenantId: string,
@ -65,6 +66,14 @@ export class EnvSet {
return this.#oidc;
}
get endpoint() {
if (!this.#endpoint) {
return throwNotLoadedError();
}
return this.#endpoint;
}
async load(customDomain?: string) {
const pool = await createPoolByEnv(
this.databaseUrl,
@ -81,10 +90,10 @@ export class EnvSet {
});
const oidcConfigs = await getOidcConfigs(consoleLog);
const endpoint = customDomain
this.#endpoint = customDomain
? new URL(customDomain)
: getTenantEndpoint(this.tenantId, EnvSet.values);
this.#oidc = await loadOidcValues(appendPath(endpoint, '/oidc').href, oidcConfigs);
this.#oidc = await loadOidcValues(appendPath(this.#endpoint, '/oidc').href, oidcConfigs);
}
async end() {

View file

@ -198,7 +198,7 @@ export default function authnRoutes<T extends AnonymousRouter>(
// Will throw ConnectorError if the config is invalid
const connectorInstance = new ssoConnectorFactories[providerName].constructor(
connectorData,
tenantId
envSet.endpoint
);
assertThat(connectorInstance instanceof SamlConnector, 'connector.unexpected_type');

View file

@ -8,7 +8,7 @@ import {
wellConfiguredSsoConnector,
mockSamlSsoConnector,
} from '#src/__mocks__/sso.js';
import { EnvSet } from '#src/env-set/index.js';
import { EnvSet, getTenantEndpoint } from '#src/env-set/index.js';
import RequestError from '#src/errors/RequestError/index.js';
import { type WithLogContext } from '#src/middleware/koa-audit-log.js';
import { type WithInteractionDetailsContext } from '#src/middleware/koa-interaction-details.js';
@ -74,7 +74,8 @@ jest
jest
.spyOn(ssoConnectorFactories.SAML, 'constructor')
.mockImplementation(
(data: SingleSignOnConnectorData) => new MockSamlSsoConnector(data, 'tenantId')
(data: SingleSignOnConnectorData) =>
new MockSamlSsoConnector(data, getTenantEndpoint('tenantId', EnvSet.values))
);
const {

View file

@ -39,7 +39,7 @@ type AuthorizationUrlPayload = z.infer<typeof authorizationUrlPayloadGuard>;
export const getSsoAuthorizationUrl = async (
ctx: WithLogContext,
{ provider, id: tenantId, queries }: TenantContext,
{ provider, queries, envSet }: TenantContext,
connectorData: SupportedSsoConnector,
payload: AuthorizationUrlPayload
): Promise<string> => {
@ -58,7 +58,7 @@ export const getSsoAuthorizationUrl = async (
// Will throw ConnectorError if the config is invalid
const connectorInstance = new ssoConnectorFactories[providerName].constructor(
connectorData,
tenantId
envSet.endpoint
);
assertThat(payload, 'session.insufficient_info');
@ -143,7 +143,7 @@ type SsoAuthenticationResult = {
*/
export const verifySsoIdentity = async (
ctx: WithLogContext,
{ provider, id: tenantId }: TenantContext,
{ provider, envSet }: TenantContext,
connectorData: SupportedSsoConnector,
data: Record<string, unknown>
): Promise<SsoAuthenticationResult> => {
@ -159,7 +159,7 @@ export const verifySsoIdentity = async (
// Will throw ConnectorError if the config is invalid
const connectorInstance = new ssoConnectorFactories[providerName].constructor(
connectorData,
tenantId
envSet.endpoint
);
const issuer = await connectorInstance.getIssuer();
const userInfo = await connectorInstance.getUserInfo(singleSignOnSession, data);

View file

@ -27,7 +27,7 @@ const samlApplicationSignInCallbackQueryParametersGuard = z
.partial();
export default function samlApplicationAnonymousRoutes<T extends AnonymousRouter>(
...[router, { id: tenantId, libraries, queries, envSet }]: RouterInitArgs<T>
...[router, { queries, envSet }]: RouterInitArgs<T>
) {
const {
samlApplications: { getSamlApplicationDetailsById },
@ -50,7 +50,7 @@ export default function samlApplicationAnonymousRoutes<T extends AnonymousRouter
const { id } = ctx.guard.params;
const details = await getSamlApplicationDetailsById(id);
const samlApplication = new SamlApplication(details, id, envSet.oidc.issuer, tenantId);
const samlApplication = new SamlApplication(details, id, envSet);
ctx.status = 200;
ctx.body = samlApplication.idPMetadata;
@ -134,7 +134,7 @@ export default function samlApplicationAnonymousRoutes<T extends AnonymousRouter
});
const details = await getSamlApplicationDetailsById(id);
const samlApplication = new SamlApplication(details, id, envSet.oidc.issuer, tenantId);
const samlApplication = new SamlApplication(details, id, envSet);
assertThat(
samlApplication.config.redirectUri === samlApplication.samlAppCallbackUrl,
@ -252,7 +252,7 @@ export default function samlApplicationAnonymousRoutes<T extends AnonymousRouter
});
const details = await getSamlApplicationDetailsById(id);
const samlApplication = new SamlApplication(details, id, envSet.oidc.issuer, tenantId);
const samlApplication = new SamlApplication(details, id, envSet);
const octetString = Object.keys(ctx.request.query)
// eslint-disable-next-line no-restricted-syntax
@ -361,7 +361,7 @@ export default function samlApplicationAnonymousRoutes<T extends AnonymousRouter
});
const details = await getSamlApplicationDetailsById(id);
const samlApplication = new SamlApplication(details, id, envSet.oidc.issuer, tenantId);
const samlApplication = new SamlApplication(details, id, envSet);
// Parse login request
try {

View file

@ -10,7 +10,7 @@ import { generateStandardId } from '@logto/shared';
import { removeUndefinedKeys } from '@silverhand/essentials';
import { z } from 'zod';
import { EnvSet, getTenantEndpoint } from '#src/env-set/index.js';
import { EnvSet } from '#src/env-set/index.js';
import RequestError from '#src/errors/RequestError/index.js';
import {
calculateCertificateFingerprints,
@ -27,7 +27,7 @@ import assertThat from '#src/utils/assert-that.js';
import { parseSearchParamsForSearch } from '#src/utils/search.js';
export default function samlApplicationRoutes<T extends ManagementApiRouter>(
...[router, { id: tenantId, queries, libraries }]: RouterInitArgs<T>
...[router, { id: tenantId, queries, libraries, envSet }]: RouterInitArgs<T>
) {
const {
applications: {
@ -92,10 +92,7 @@ export default function samlApplicationRoutes<T extends ManagementApiRouter>(
const id = generateStandardId();
// Set the default redirect URI for SAML apps when creating a new SAML app.
const redirectUri = getSamlAppCallbackUrl(
getTenantEndpoint(tenantId, EnvSet.values),
id
).toString();
const redirectUri = getSamlAppCallbackUrl(envSet.endpoint, id).toString();
const application = await insertApplication(
removeUndefinedKeys({

View file

@ -41,6 +41,7 @@ export default function singleSignOnConnectorsRoutes<T extends ManagementApiRout
quota,
ssoConnectors: { getSsoConnectorById, getSsoConnectors },
},
envSet,
},
] = args;
@ -118,7 +119,7 @@ export default function singleSignOnConnectorsRoutes<T extends ManagementApiRout
providerName,
config: parsedConfig,
},
tenantId
envSet.endpoint
);
}
@ -158,7 +159,7 @@ export default function singleSignOnConnectorsRoutes<T extends ManagementApiRout
// Fetch provider details for each connector
const connectorsWithProviderDetails = await Promise.all(
connectors.map(async (connector) =>
fetchConnectorProviderDetails(connector, tenantId, ctx.locale)
fetchConnectorProviderDetails(connector, envSet.endpoint, ctx.locale)
)
);
@ -189,7 +190,7 @@ export default function singleSignOnConnectorsRoutes<T extends ManagementApiRout
// Fetch provider details for the connector
const connectorWithProviderDetails = await fetchConnectorProviderDetails(
connector,
tenantId,
envSet.endpoint,
locale
);
@ -269,7 +270,7 @@ export default function singleSignOnConnectorsRoutes<T extends ManagementApiRout
providerName,
config: parsedConfig,
},
tenantId
envSet.endpoint
);
}
@ -293,7 +294,7 @@ export default function singleSignOnConnectorsRoutes<T extends ManagementApiRout
const connectorWithProviderDetails = await fetchConnectorProviderDetails(
connector,
tenantId,
envSet.endpoint,
locale
);

View file

@ -2,6 +2,7 @@ import { SsoProviderName } from '@logto/schemas';
import { createMockUtils } from '@logto/shared/esm';
import { mockSsoConnector } from '#src/__mocks__/sso.js';
import { getTenantEndpoint, EnvSet } from '#src/env-set/index.js';
import RequestError from '#src/errors/RequestError/index.js';
const { jest } = import.meta;
@ -55,7 +56,11 @@ describe('parseFactoryDetail', () => {
describe('fetchConnectorProviderDetails', () => {
it('providerConfig should be undefined if connector config is invalid', async () => {
const connector = { ...mockSsoConnector, config: { clientId: 'foo' } };
const result = await fetchConnectorProviderDetails(connector, mockTenantId, 'en');
const result = await fetchConnectorProviderDetails(
connector,
getTenantEndpoint(mockTenantId, EnvSet.values),
'en'
);
expect(result).toMatchObject(
expect.objectContaining({
@ -74,7 +79,11 @@ describe('fetchConnectorProviderDetails', () => {
};
fetchOidcConfig.mockRejectedValueOnce(new Error('mock-error'));
const result = await fetchConnectorProviderDetails(connector, mockTenantId, 'en');
const result = await fetchConnectorProviderDetails(
connector,
getTenantEndpoint(mockTenantId, EnvSet.values),
'en'
);
expect(result).toMatchObject(
expect.objectContaining({
@ -93,7 +102,11 @@ describe('fetchConnectorProviderDetails', () => {
};
fetchOidcConfig.mockResolvedValueOnce({ tokenEndpoint: 'http://example.com/token' });
const result = await fetchConnectorProviderDetails(connector, mockTenantId, 'en');
const result = await fetchConnectorProviderDetails(
connector,
getTenantEndpoint(mockTenantId, EnvSet.values),
'en'
);
expect(result).toMatchObject(
expect.objectContaining({

View file

@ -57,7 +57,7 @@ export const parseConnectorConfig = (providerName: SsoProviderName, config: Json
export const fetchConnectorProviderDetails = async (
connector: SupportedSsoConnector,
tenantId: string,
endpoint: URL,
locale: string
): Promise<SsoConnectorWithProviderConfig> => {
const { providerName } = connector;
@ -69,7 +69,7 @@ export const fetchConnectorProviderDetails = async (
Return undefined if failed to fetch or parse the config.
*/
const providerConfig = await trySafe(async () => {
const instance = new constructor(connector, tenantId);
const instance = new constructor(connector, endpoint);
return instance.getConfig();
});
@ -91,11 +91,11 @@ export const fetchConnectorProviderDetails = async (
*/
export const validateConnectorConfigConnectionStatus = async (
connector: SingleSignOnConnectorData,
tenantId: string
endpoint: URL
) => {
const { providerName } = connector;
const { constructor } = ssoConnectorFactories[providerName];
const instance = new constructor(connector, tenantId);
const instance = new constructor(connector, endpoint);
// SAML connector's idpMetadata is optional (safely catch by the getConfig method), we need to force fetch the IdP metadata here
if (instance instanceof SamlConnector) {

View file

@ -2,6 +2,8 @@ import { UserScope, ReservedScope } from '@logto/core-kit';
import { NameIdFormat } from '@logto/schemas';
import nock from 'nock';
import { EnvSet, getTenantEndpoint } from '#src/env-set/index.js';
import { SamlApplication } from './index.js';
const { jest } = import.meta;
@ -58,7 +60,10 @@ describe('SamlApplication', () => {
beforeEach(() => {
// @ts-expect-error
// eslint-disable-next-line @silverhand/fp/no-mutation
samlApp = new TestSamlApplication(mockDetails, mockSamlApplicationId, mockIssuer, mockTenantId);
samlApp = new TestSamlApplication(mockDetails, mockSamlApplicationId, {
oidc: { issuer: mockIssuer },
endpoint: getTenantEndpoint(mockTenantId, EnvSet.values),
});
nock(mockIssuer).get('/.well-known/openid-configuration').reply(200, {
token_endpoint: mockTokenEndpoint,
@ -188,8 +193,10 @@ describe('SamlApplication', () => {
attributeMapping: {},
},
mockSamlApplicationId,
mockIssuer,
mockTenantId
{
oidc: { issuer: mockIssuer },
endpoint: getTenantEndpoint(mockTenantId, EnvSet.values),
}
);
const scopes = app.exposedGetScopesFromAttributeMapping();
@ -207,8 +214,10 @@ describe('SamlApplication', () => {
attributeMapping: {},
},
mockSamlApplicationId,
mockIssuer,
mockTenantId
{
oidc: { issuer: mockIssuer },
endpoint: getTenantEndpoint(mockTenantId, EnvSet.values),
}
);
const scopes = app.exposedGetScopesFromAttributeMapping();
@ -228,8 +237,10 @@ describe('SamlApplication', () => {
},
},
mockSamlApplicationId,
mockIssuer,
mockTenantId
{
oidc: { issuer: mockIssuer },
endpoint: getTenantEndpoint(mockTenantId, EnvSet.values),
}
);
const scopes = app.exposedGetScopesFromAttributeMapping();
@ -250,8 +261,10 @@ describe('SamlApplication', () => {
},
},
mockSamlApplicationId,
mockIssuer,
mockTenantId
{
oidc: { issuer: mockIssuer },
endpoint: getTenantEndpoint(mockTenantId, EnvSet.values),
}
);
const scopes = app.exposedGetScopesFromAttributeMapping();
@ -277,8 +290,10 @@ describe('SamlApplication', () => {
},
},
mockSamlApplicationId,
mockIssuer,
mockTenantId
{
oidc: { issuer: mockIssuer },
endpoint: getTenantEndpoint(mockTenantId, EnvSet.values),
}
);
const scopes = app.exposedGetScopesFromAttributeMapping();
@ -308,8 +323,10 @@ describe('SamlApplication', () => {
// @ts-expect-error
mockDetailsWithMapping,
mockSamlApplicationId,
mockIssuer,
mockTenantId
{
oidc: { issuer: mockIssuer },
endpoint: getTenantEndpoint(mockTenantId, EnvSet.values),
}
);
const template = samlApp.exposedBuildLoginResponseTemplate();
@ -353,8 +370,10 @@ describe('SamlApplication', () => {
// @ts-expect-error
mockDetailsWithMapping,
mockSamlApplicationId,
mockIssuer,
mockTenantId
{
oidc: { issuer: mockIssuer },
endpoint: getTenantEndpoint(mockTenantId, EnvSet.values),
}
);
const tagValues = samlApp.exposedBuildSamlAttributesTagValues(mockUser);
@ -382,8 +401,10 @@ describe('SamlApplication', () => {
// @ts-expect-error
mockDetailsWithMapping,
mockSamlApplicationId,
mockIssuer,
mockTenantId
{
oidc: { issuer: mockIssuer },
endpoint: getTenantEndpoint(mockTenantId, EnvSet.values),
}
);
const tagValues = samlApp.exposedBuildSamlAttributesTagValues(mockUser);

View file

@ -16,7 +16,7 @@ import { XMLValidator } from 'fast-xml-parser';
import saml from 'samlify';
import { ZodError, z } from 'zod';
import { EnvSet, getTenantEndpoint } from '#src/env-set/index.js';
import { type EnvSet } from '#src/env-set/index.js';
import RequestError from '#src/errors/RequestError/index.js';
import {
buildSingleSignOnUrl,
@ -109,7 +109,8 @@ class SamlApplicationConfig {
export class SamlApplication {
public config: SamlApplicationConfig;
protected tenantEndpoint: URL;
protected endpoint: URL;
protected issuer: string;
protected oidcConfig?: CamelCaseKeys<OidcConfigResponse>;
private _idp?: saml.IdentityProviderInstance;
@ -118,11 +119,11 @@ export class SamlApplication {
constructor(
details: SamlApplicationDetails,
protected samlApplicationId: string,
protected issuer: string,
tenantId: string
protected envSet: EnvSet
) {
this.config = new SamlApplicationConfig(details);
this.tenantEndpoint = getTenantEndpoint(tenantId, EnvSet.values);
this.issuer = envSet.oidc.issuer;
this.endpoint = envSet.endpoint;
}
public get idp(): saml.IdentityProviderInstance {
@ -146,7 +147,7 @@ export class SamlApplication {
}
public get samlAppCallbackUrl() {
return getSamlAppCallbackUrl(this.tenantEndpoint, this.samlApplicationId).toString();
return getSamlAppCallbackUrl(this.endpoint, this.samlApplicationId).toString();
}
public async parseLoginRequest(
@ -484,10 +485,10 @@ export class SamlApplication {
private buildIdpConfig(): SamlIdentityProviderConfig {
return {
entityId: buildSamlIdentityProviderEntityId(this.tenantEndpoint, this.samlApplicationId),
entityId: buildSamlIdentityProviderEntityId(this.endpoint, this.samlApplicationId),
privateKey: this.config.privateKey,
certificate: this.config.certificate,
singleSignOnUrl: buildSingleSignOnUrl(this.tenantEndpoint, this.samlApplicationId),
singleSignOnUrl: buildSingleSignOnUrl(this.endpoint, this.samlApplicationId),
nameIdFormat: this.config.nameIdFormat,
encryptSamlAssertion: this.config.encryption?.encryptAssertion ?? false,
};

View file

@ -3,8 +3,6 @@ import { conditional, type Optional } from '@silverhand/essentials';
import { XMLValidator } from 'fast-xml-parser';
import * as saml from 'samlify';
import { EnvSet, getTenantEndpoint } from '#src/env-set/index.js';
import {
SsoConnectorConfigErrorCodes,
SsoConnectorError,
@ -58,18 +56,13 @@ class SamlConnector {
// Allow _idpConfig input to be undefined when constructing the connector.
constructor(
tenantId: string,
endpoint: URL,
ssoConnectorId: string,
private readonly _idpConfig: SamlConnectorConfig | undefined
) {
const tenantEndpoint = getTenantEndpoint(tenantId, EnvSet.values);
const assertionConsumerServiceUrl = buildAssertionConsumerServiceUrl(endpoint, ssoConnectorId);
const assertionConsumerServiceUrl = buildAssertionConsumerServiceUrl(
tenantEndpoint,
ssoConnectorId
);
const spEntityId = buildSpEntityId(tenantEndpoint, ssoConnectorId);
const spEntityId = buildSpEntityId(endpoint, ssoConnectorId);
this.serviceProviderMetadata = {
entityId: spEntityId,

View file

@ -1,6 +1,7 @@
import { SsoProviderName } from '@logto/schemas';
import { mockSsoConnector as _mockSsoConnector } from '#src/__mocks__/sso.js';
import { getTenantEndpoint, EnvSet } from '#src/env-set/index.js';
import {
SsoConnectorConfigErrorCodes,
@ -17,7 +18,10 @@ describe('SamlSsoConnector', () => {
it('constructor should work properly', () => {
// eslint-disable-next-line unicorn/consistent-function-scoping
const createSamlSsoConnector = () =>
new samlSsoConnectorFactory.constructor(mockSsoConnector, 'default_tenant');
new samlSsoConnectorFactory.constructor(
mockSsoConnector,
getTenantEndpoint('default_tenant', EnvSet.values)
);
expect(createSamlSsoConnector).not.toThrow();
});
@ -26,7 +30,7 @@ describe('SamlSsoConnector', () => {
const temporaryMockSsoConnector = { ...mockSsoConnector, config: { metadata: 123 } };
const connector = new samlSsoConnectorFactory.constructor(
temporaryMockSsoConnector,
'default_tenant'
getTenantEndpoint('default_tenant', EnvSet.values)
);
const { serviceProvider, identityProvider } = await connector.getConfig();
@ -36,7 +40,10 @@ describe('SamlSsoConnector', () => {
});
it('should throw error on calling getIdpMetadata, if the config is invalid', async () => {
const connector = new samlSsoConnectorFactory.constructor(mockSsoConnector, 'default_tenant');
const connector = new samlSsoConnectorFactory.constructor(
mockSsoConnector,
getTenantEndpoint('default_tenant', EnvSet.values)
);
await expect(async () => connector.getSamlIdpMetadata()).rejects.toThrow(
new SsoConnectorError(SsoConnectorErrorCodes.InvalidConfig, {
@ -59,7 +66,7 @@ describe('SamlSsoConnector', () => {
const connector = new samlSsoConnectorFactory.constructor(
temporaryMockSsoConnector,
'default_tenant'
getTenantEndpoint('default_tenant', EnvSet.values)
);
expect(connector.idpConfig).toEqual(config);

View file

@ -32,12 +32,12 @@ import {
export class SamlSsoConnector extends SamlConnector implements SingleSignOn {
constructor(
readonly data: SingleSignOnConnectorData,
tenantId: string
endpoint: URL
) {
const parseConfigResult = samlConnectorConfigGuard.safeParse(data.config);
// Fallback to undefined if config is invalid
super(tenantId, data.id, conditional(parseConfigResult.success && parseConfigResult.data));
super(endpoint, data.id, conditional(parseConfigResult.success && parseConfigResult.data));
}
async getIssuer() {