diff --git a/packages/core/src/app/init.ts b/packages/core/src/app/init.ts index 4ffdb811c..5ac044aff 100644 --- a/packages/core/src/app/init.ts +++ b/packages/core/src/app/init.ts @@ -29,7 +29,7 @@ export default async function initApp(app: Koa): Promise { return next(); } - const tenantId = await getTenantId(ctx.URL); + const [tenantId, isCustomDomain] = await getTenantId(ctx.URL); if (!tenantId) { ctx.status = 404; @@ -37,7 +37,11 @@ export default async function initApp(app: Koa): Promise { return next(); } - const tenant = await trySafe(tenantPool.get(tenantId), (error) => { + // If the request is a custom domain of the tenant, use the custom endpoint to build "OIDC issuer" + // otherwise, build from the default endpoint (subdomain). + const customEndpoint = isCustomDomain ? ctx.URL.origin : undefined; + + const tenant = await trySafe(tenantPool.get(tenantId, customEndpoint), (error) => { ctx.status = error instanceof TenantNotFoundError ? 404 : 500; void appInsights.trackException(error); }); diff --git a/packages/core/src/env-set/index.ts b/packages/core/src/env-set/index.ts index 7fb599dac..88c14e63d 100644 --- a/packages/core/src/env-set/index.ts +++ b/packages/core/src/env-set/index.ts @@ -63,7 +63,7 @@ export class EnvSet { return this.#oidc; } - async load() { + async load(customDomain?: string) { const pool = await createPoolByEnv( this.databaseUrl, EnvSet.values.isUnitTest, @@ -77,7 +77,9 @@ export class EnvSet { }); const oidcConfigs = await getOidcConfigs(); - const endpoint = getTenantEndpoint(this.tenantId, EnvSet.values); + const endpoint = customDomain + ? new URL(customDomain) + : getTenantEndpoint(this.tenantId, EnvSet.values); this.#oidc = await loadOidcValues(appendPath(endpoint, '/oidc').href, oidcConfigs); } diff --git a/packages/core/src/middleware/koa-spa-session-guard.ts b/packages/core/src/middleware/koa-spa-session-guard.ts index f5c8f01dd..66ebf5623 100644 --- a/packages/core/src/middleware/koa-spa-session-guard.ts +++ b/packages/core/src/middleware/koa-spa-session-guard.ts @@ -52,7 +52,7 @@ export default function koaSpaSessionGuard< return; } - const tenantId = await getTenantId(ctx.URL); + const [tenantId] = await getTenantId(ctx.URL); if (!tenantId) { throw new RequestError({ code: 'session.not_found', status: 404 }); diff --git a/packages/core/src/routes/interaction/actions/submit-interaction.mfa.test.ts b/packages/core/src/routes/interaction/actions/submit-interaction.mfa.test.ts index 0549bb916..9d3fbe9ae 100644 --- a/packages/core/src/routes/interaction/actions/submit-interaction.mfa.test.ts +++ b/packages/core/src/routes/interaction/actions/submit-interaction.mfa.test.ts @@ -36,7 +36,7 @@ mockEsm('@logto/shared', () => ({ })); mockEsm('#src/utils/tenant.js', () => ({ - getTenantId: () => adminTenantId, + getTenantId: () => [adminTenantId], })); const userQueries = { diff --git a/packages/core/src/routes/interaction/actions/submit-interaction.test.ts b/packages/core/src/routes/interaction/actions/submit-interaction.test.ts index 0dea2ab9a..bc9007ac8 100644 --- a/packages/core/src/routes/interaction/actions/submit-interaction.test.ts +++ b/packages/core/src/routes/interaction/actions/submit-interaction.test.ts @@ -37,7 +37,7 @@ mockEsm('@logto/shared', () => ({ })); mockEsm('#src/utils/tenant.js', () => ({ - getTenantId: () => adminTenantId, + getTenantId: () => [adminTenantId], })); const userQueries = { diff --git a/packages/core/src/routes/interaction/actions/submit-interaction.ts b/packages/core/src/routes/interaction/actions/submit-interaction.ts index 861922e42..5ca004ab3 100644 --- a/packages/core/src/routes/interaction/actions/submit-interaction.ts +++ b/packages/core/src/routes/interaction/actions/submit-interaction.ts @@ -105,7 +105,8 @@ async function handleSubmitRegister( const { client_id } = ctx.interactionDetails.params; const { isCloud } = EnvSet.values; - const isInAdminTenant = (await getTenantId(ctx.URL)) === adminTenantId; + const [currentTenantId] = await getTenantId(ctx.URL); + const isInAdminTenant = currentTenantId === adminTenantId; const isCreatingFirstAdminUser = isInAdminTenant && String(client_id) === adminConsoleApplicationId && !(await hasActiveUsers()); diff --git a/packages/core/src/routes/user-assets.ts b/packages/core/src/routes/user-assets.ts index d69841729..fa52bde44 100644 --- a/packages/core/src/routes/user-assets.ts +++ b/packages/core/src/routes/user-assets.ts @@ -63,7 +63,7 @@ export default function userAssetsRoutes(...[router]: Ro 'guard.mime_type_not_allowed' ); - const tenantId = await getTenantId(ctx.URL); + const [tenantId] = await getTenantId(ctx.URL); assertThat(tenantId, 'guard.can_not_get_tenant_id'); const { storageProviderConfig } = SystemContext.shared; diff --git a/packages/core/src/tenants/Tenant.ts b/packages/core/src/tenants/Tenant.ts index 4963febf4..3d24be56d 100644 --- a/packages/core/src/tenants/Tenant.ts +++ b/packages/core/src/tenants/Tenant.ts @@ -36,10 +36,11 @@ import type TenantContext from './TenantContext.js'; import { getTenantDatabaseDsn } from './utils.js'; export default class Tenant implements TenantContext { - static async create(id: string, redisCache: RedisCache): Promise { + static async create(id: string, redisCache: RedisCache, customDomain?: string): Promise { // Treat the default database URL as the management URL const envSet = new EnvSet(id, await getTenantDatabaseDsn(id)); - await envSet.load(); + // Custom endpoint is used for building OIDC issuer URL when the request is a custom domain + await envSet.load(customDomain); return new Tenant(envSet, id, new WellKnownCache(id, redisCache)); } diff --git a/packages/core/src/tenants/index.ts b/packages/core/src/tenants/index.ts index 31d7586e1..5225a2d6c 100644 --- a/packages/core/src/tenants/index.ts +++ b/packages/core/src/tenants/index.ts @@ -15,8 +15,9 @@ export class TenantPool { }, }); - async get(tenantId: string): Promise { - const tenantPromise = this.cache.get(tenantId); + async get(tenantId: string, customDomain?: string): Promise { + const cacheKey = `${tenantId}-${customDomain ?? 'default'}`; + const tenantPromise = this.cache.get(cacheKey); if (tenantPromise) { const tenant = await tenantPromise; @@ -27,9 +28,9 @@ export class TenantPool { // Otherwise, create a new tenant instance and store in LRU cache, using the code below. } - consoleLog.info('Init tenant:', tenantId); - const newTenantPromise = Tenant.create(tenantId, redisCache); - this.cache.set(tenantId, newTenantPromise); + consoleLog.info('Init tenant:', tenantId, customDomain); + const newTenantPromise = Tenant.create(tenantId, redisCache, customDomain); + this.cache.set(cacheKey, newTenantPromise); return newTenantPromise; } diff --git a/packages/core/src/utils/tenant.test.ts b/packages/core/src/utils/tenant.test.ts index fac175546..f91b2771b 100644 --- a/packages/core/src/utils/tenant.test.ts +++ b/packages/core/src/utils/tenant.test.ts @@ -23,6 +23,11 @@ mockEsm('#src/queries/domains.js', () => ({ const { getTenantId } = await import('./tenant.js'); +const getTenantIdFirstElement = async (url: URL) => { + const [tenantId] = await getTenantId(url); + return tenantId; +}; + describe('getTenantId()', () => { const backupEnv = process.env; @@ -37,7 +42,7 @@ describe('getTenantId()', () => { DEVELOPMENT_TENANT_ID: 'foo', }; - await expect(getTenantId(new URL('https://some.random.url'))).resolves.toBe('foo'); + await expect(getTenantIdFirstElement(new URL('https://some.random.url'))).resolves.toBe('foo'); process.env = { ...backupEnv, @@ -46,20 +51,22 @@ describe('getTenantId()', () => { DEVELOPMENT_TENANT_ID: 'bar', }; - await expect(getTenantId(new URL('https://some.random.url'))).resolves.toBe('bar'); + await expect(getTenantIdFirstElement(new URL('https://some.random.url'))).resolves.toBe('bar'); }); it('should resolve proper tenant ID for similar localhost endpoints', async () => { - await expect(getTenantId(new URL('http://localhost:3002/some/path////'))).resolves.toBe( - adminTenantId - ); - await expect(getTenantId(new URL('http://localhost:30021/some/path'))).resolves.toBe( + await expect( + getTenantIdFirstElement(new URL('http://localhost:3002/some/path////')) + ).resolves.toBe(adminTenantId); + await expect( + getTenantIdFirstElement(new URL('http://localhost:30021/some/path')) + ).resolves.toBe(defaultTenantId); + await expect( + getTenantIdFirstElement(new URL('http://localhostt:30021/some/path')) + ).resolves.toBe(defaultTenantId); + await expect(getTenantIdFirstElement(new URL('https://localhost:3002'))).resolves.toBe( defaultTenantId ); - await expect(getTenantId(new URL('http://localhostt:30021/some/path'))).resolves.toBe( - defaultTenantId - ); - await expect(getTenantId(new URL('https://localhost:3002'))).resolves.toBe(defaultTenantId); }); it('should resolve proper tenant ID for similar domain endpoints', async () => { @@ -69,24 +76,30 @@ describe('getTenantId()', () => { ENDPOINT: 'https://foo.*.logto.mock/app', }; - await expect(getTenantId(new URL('https://foo.foo.logto.mock/app///asdasd'))).resolves.toBe( - 'foo' - ); - await expect(getTenantId(new URL('https://foo.*.logto.mock/app'))).resolves.toBe(undefined); - await expect(getTenantId(new URL('https://foo.foo.logto.mockk/app///asdasd'))).resolves.toBe( + await expect( + getTenantIdFirstElement(new URL('https://foo.foo.logto.mock/app///asdasd')) + ).resolves.toBe('foo'); + await expect(getTenantIdFirstElement(new URL('https://foo.*.logto.mock/app'))).resolves.toBe( undefined ); - await expect(getTenantId(new URL('https://foo.foo.logto.mock/appp'))).resolves.toBe(undefined); - await expect(getTenantId(new URL('https://foo.foo.logto.mock:1/app/'))).resolves.toBe( + await expect( + getTenantIdFirstElement(new URL('https://foo.foo.logto.mockk/app///asdasd')) + ).resolves.toBe(undefined); + await expect(getTenantIdFirstElement(new URL('https://foo.foo.logto.mock/appp'))).resolves.toBe( undefined ); - await expect(getTenantId(new URL('http://foo.foo.logto.mock/app'))).resolves.toBe(undefined); - await expect(getTenantId(new URL('https://user.foo.bar.logto.mock/app'))).resolves.toBe( - undefined - ); - await expect(getTenantId(new URL('https://foo.bar.bar.logto.mock/app'))).resolves.toBe( + await expect( + getTenantIdFirstElement(new URL('https://foo.foo.logto.mock:1/app/')) + ).resolves.toBe(undefined); + await expect(getTenantIdFirstElement(new URL('http://foo.foo.logto.mock/app'))).resolves.toBe( undefined ); + await expect( + getTenantIdFirstElement(new URL('https://user.foo.bar.logto.mock/app')) + ).resolves.toBe(undefined); + await expect( + getTenantIdFirstElement(new URL('https://foo.bar.bar.logto.mock/app')) + ).resolves.toBe(undefined); }); it('should resolve proper tenant ID if admin localhost is disabled', async () => { @@ -99,17 +112,21 @@ describe('getTenantId()', () => { ADMIN_DISABLE_LOCALHOST: '1', }; - await expect(getTenantId(new URL('http://localhost:5000/app///asdasd'))).resolves.toBe( - undefined + await expect( + getTenantIdFirstElement(new URL('http://localhost:5000/app///asdasd')) + ).resolves.toBe(undefined); + await expect( + getTenantIdFirstElement(new URL('http://localhost:3002/app///asdasd')) + ).resolves.toBe(undefined); + await expect(getTenantIdFirstElement(new URL('https://user.foo.logto.mock/app'))).resolves.toBe( + 'foo' ); - await expect(getTenantId(new URL('http://localhost:3002/app///asdasd'))).resolves.toBe( - undefined + await expect( + getTenantIdFirstElement(new URL('https://user.admin.logto.mock/app//')) + ).resolves.toBe(undefined); // Admin endpoint is explicitly set + await expect(getTenantIdFirstElement(new URL('https://admin.logto.mock/app'))).resolves.toBe( + adminTenantId ); - await expect(getTenantId(new URL('https://user.foo.logto.mock/app'))).resolves.toBe('foo'); - await expect(getTenantId(new URL('https://user.admin.logto.mock/app//'))).resolves.toBe( - undefined - ); // Admin endpoint is explicitly set - await expect(getTenantId(new URL('https://admin.logto.mock/app'))).resolves.toBe(adminTenantId); process.env = { ...backupEnv, @@ -118,9 +135,9 @@ describe('getTenantId()', () => { ENDPOINT: 'https://user.*.logto.mock/app', ADMIN_DISABLE_LOCALHOST: '1', }; - await expect(getTenantId(new URL('https://user.admin.logto.mock/app//'))).resolves.toBe( - 'admin' - ); + await expect( + getTenantIdFirstElement(new URL('https://user.admin.logto.mock/app//')) + ).resolves.toBe('admin'); }); it('should resolve proper tenant ID for path-based multi-tenancy', async () => { @@ -132,16 +149,24 @@ describe('getTenantId()', () => { PATH_BASED_MULTI_TENANCY: '1', }; - await expect(getTenantId(new URL('http://localhost:5000/app///asdasd'))).resolves.toBe('app'); - await expect(getTenantId(new URL('http://localhost:3002///bar///asdasd'))).resolves.toBe( - adminTenantId - ); - await expect(getTenantId(new URL('https://user.foo.logto.mock/app'))).resolves.toBe(undefined); - await expect(getTenantId(new URL('https://user.admin.logto.mock/app//'))).resolves.toBe( + await expect( + getTenantIdFirstElement(new URL('http://localhost:5000/app///asdasd')) + ).resolves.toBe('app'); + await expect( + getTenantIdFirstElement(new URL('http://localhost:3002///bar///asdasd')) + ).resolves.toBe(adminTenantId); + await expect(getTenantIdFirstElement(new URL('https://user.foo.logto.mock/app'))).resolves.toBe( undefined ); - await expect(getTenantId(new URL('https://user.logto.mock/app'))).resolves.toBe(undefined); - await expect(getTenantId(new URL('https://user.logto.mock/app/admin'))).resolves.toBe('admin'); + await expect( + getTenantIdFirstElement(new URL('https://user.admin.logto.mock/app//')) + ).resolves.toBe(undefined); + await expect(getTenantIdFirstElement(new URL('https://user.logto.mock/app'))).resolves.toBe( + undefined + ); + await expect( + getTenantIdFirstElement(new URL('https://user.logto.mock/app/admin')) + ).resolves.toBe('admin'); }); it('should resolve proper custom domain', async () => { @@ -151,6 +176,6 @@ describe('getTenantId()', () => { NODE_ENV: 'production', }; findActiveDomain.mockResolvedValueOnce({ domain: 'logto.mock.com', tenantId: 'mock' }); - await expect(getTenantId(new URL('https://logto.mock.com'))).resolves.toBe('mock'); + await expect(getTenantIdFirstElement(new URL('https://logto.mock.com'))).resolves.toBe('mock'); }); }); diff --git a/packages/core/src/utils/tenant.ts b/packages/core/src/utils/tenant.ts index d33c66e42..7223265de 100644 --- a/packages/core/src/utils/tenant.ts +++ b/packages/core/src/utils/tenant.ts @@ -53,6 +53,9 @@ export const clearCustomDomainCache = async (url: URL | string) => { await trySafe(async () => redisCache.delete(getDomainCacheKey(url))); }; +/** + * Get tenant ID from the custom domain URL. + */ const getTenantIdFromCustomDomain = async ( url: URL, pool: CommonQueryMethods @@ -74,7 +77,15 @@ const getTenantIdFromCustomDomain = async ( return domain?.tenantId; }; -export const getTenantId = async (url: URL) => { +/** + * Get tenant ID from the current request's URL. + * + * @param url The current request's URL + * @returns The tenant ID and whether the URL is a custom domain + */ +export const getTenantId = async ( + url: URL +): Promise<[tenantId: string | undefined, isCustomDomain: boolean]> => { const { values: { isMultiTenancy, @@ -90,28 +101,28 @@ export const getTenantId = async (url: URL) => { const pool = await sharedPool; if (adminUrlSet.deduplicated().some((endpoint) => isEndpointOf(url, endpoint))) { - return adminTenantId; + return [adminTenantId, false]; } if ((!isProduction || isIntegrationTest) && developmentTenantId) { consoleLog.warn(`Found dev tenant ID ${developmentTenantId}.`); - return developmentTenantId; + return [developmentTenantId, false]; } if (!isMultiTenancy) { - return defaultTenantId; + return [defaultTenantId, false]; } if (isPathBasedMultiTenancy) { - return matchPathBasedTenantId(urlSet, url); + return [matchPathBasedTenantId(urlSet, url), false]; } const customDomainTenantId = await getTenantIdFromCustomDomain(url, pool); if (customDomainTenantId) { - return customDomainTenantId; + return [customDomainTenantId, true]; } - return matchDomainBasedTenantId(urlSet.endpoint, url); + return [matchDomainBasedTenantId(urlSet.endpoint, url), false]; };