diff --git a/packages/core/src/app/init.ts b/packages/core/src/app/init.ts index bdb5e8aa6..51dac8475 100644 --- a/packages/core/src/app/init.ts +++ b/packages/core/src/app/init.ts @@ -29,7 +29,7 @@ export default async function initApp(app: Koa): Promise { return next(); } - const tenantId = getTenantId(ctx.URL); + const tenantId = await getTenantId(ctx.URL); if (!tenantId) { ctx.status = 404; diff --git a/packages/core/src/caches/index.ts b/packages/core/src/caches/index.ts index 55ec79b02..448557eba 100644 --- a/packages/core/src/caches/index.ts +++ b/packages/core/src/caches/index.ts @@ -23,9 +23,9 @@ export class RedisCache implements CacheStore { } } - async set(key: string, value: string) { + async set(key: string, value: string, expire: number = 30 * 60) { await this.client?.set(key, value, { - EX: 30 * 60 /* 30 minutes */, + EX: expire, }); } diff --git a/packages/core/src/middleware/koa-spa-session-guard.ts b/packages/core/src/middleware/koa-spa-session-guard.ts index bb6784c1d..0d0177bb5 100644 --- a/packages/core/src/middleware/koa-spa-session-guard.ts +++ b/packages/core/src/middleware/koa-spa-session-guard.ts @@ -49,7 +49,7 @@ export default function koaSpaSessionGuard< return; } - const tenantId = getTenantId(ctx.URL); + const tenantId = await getTenantId(ctx.URL); if (!tenantId) { throw new RequestError({ code: 'session.not_found', status: 404 }); diff --git a/packages/core/src/queries/domains.ts b/packages/core/src/queries/domains.ts index ed5af66fc..68875f989 100644 --- a/packages/core/src/queries/domains.ts +++ b/packages/core/src/queries/domains.ts @@ -1,5 +1,4 @@ -import type { CreateDomain, Domain } from '@logto/schemas'; -import { Domains } from '@logto/schemas'; +import { type CreateDomain, type Domain, DomainStatus, Domains } from '@logto/schemas'; import type { OmitAutoSetFields } from '@logto/shared'; import { convertToIdentifiers, manyRows } from '@logto/shared'; import type { CommonQueryMethods } from 'slonik'; @@ -23,6 +22,14 @@ export const createDomainsQueries = (pool: CommonQueryMethods) => { const findDomainById = buildFindEntityByIdWithPool(pool)(Domains); + const findActiveDomain = async (domain: string) => + pool.maybeOne(sql` + select ${sql.join(Object.values(fields), sql`, `)} + from ${table} + where ${fields.domain}=${domain} + and ${fields.status}=${DomainStatus.Active} + `); + const insertDomain = buildInsertIntoWithPool(pool)(Domains, { returning: true, }); @@ -49,6 +56,7 @@ export const createDomainsQueries = (pool: CommonQueryMethods) => { return { findAllDomains, findDomainById, + findActiveDomain, insertDomain, updateDomainById, deleteDomainById, diff --git a/packages/core/src/routes/interaction/actions/submit-interaction.ts b/packages/core/src/routes/interaction/actions/submit-interaction.ts index d89d2f852..99906ae3a 100644 --- a/packages/core/src/routes/interaction/actions/submit-interaction.ts +++ b/packages/core/src/routes/interaction/actions/submit-interaction.ts @@ -185,7 +185,7 @@ export default async function submitInteraction( const { client_id } = ctx.interactionDetails.params; const { isCloud } = EnvSet.values; - const isInAdminTenant = getTenantId(ctx.URL) === adminTenantId; + const isInAdminTenant = (await getTenantId(ctx.URL)) === adminTenantId; const isCreatingFirstAdminUser = isInAdminTenant && String(client_id) === adminConsoleApplicationId && diff --git a/packages/core/src/routes/user-assets.ts b/packages/core/src/routes/user-assets.ts index 03b223519..745d382e9 100644 --- a/packages/core/src/routes/user-assets.ts +++ b/packages/core/src/routes/user-assets.ts @@ -62,7 +62,7 @@ export default function userAssetsRoutes(...[router]: Ro 'guard.mime_type_not_allowed' ); - const tenantId = getTenantId(ctx.URL); + const tenantId = await getTenantId(ctx.URL); assertThat(tenantId, 'guard.can_not_get_tenant_id'); const { storageProviderConfig } = SystemContext.shared; diff --git a/packages/core/src/utils/tenant.test.ts b/packages/core/src/utils/tenant.test.ts index 2eb368980..fac175546 100644 --- a/packages/core/src/utils/tenant.test.ts +++ b/packages/core/src/utils/tenant.test.ts @@ -4,7 +4,7 @@ import { createMockUtils } from '@logto/shared/esm'; const { jest } = import.meta; -const { mockEsmWithActual } = createMockUtils(jest); +const { mockEsmWithActual, mockEsm } = createMockUtils(jest); await mockEsmWithActual('#src/env-set/index.js', () => ({ EnvSet: { @@ -14,6 +14,13 @@ await mockEsmWithActual('#src/env-set/index.js', () => ({ }, })); +const findActiveDomain = jest.fn(); +mockEsm('#src/queries/domains.js', () => ({ + createDomainsQueries: () => ({ + findActiveDomain, + }), +})); + const { getTenantId } = await import('./tenant.js'); describe('getTenantId()', () => { @@ -30,7 +37,7 @@ describe('getTenantId()', () => { DEVELOPMENT_TENANT_ID: 'foo', }; - expect(getTenantId(new URL('https://some.random.url'))).toBe('foo'); + await expect(getTenantId(new URL('https://some.random.url'))).resolves.toBe('foo'); process.env = { ...backupEnv, @@ -39,14 +46,20 @@ describe('getTenantId()', () => { DEVELOPMENT_TENANT_ID: 'bar', }; - expect(getTenantId(new URL('https://some.random.url'))).toBe('bar'); + await expect(getTenantId(new URL('https://some.random.url'))).resolves.toBe('bar'); }); it('should resolve proper tenant ID for similar localhost endpoints', async () => { - expect(getTenantId(new URL('http://localhost:3002/some/path////'))).toBe(adminTenantId); - expect(getTenantId(new URL('http://localhost:30021/some/path'))).toBe(defaultTenantId); - expect(getTenantId(new URL('http://localhostt:30021/some/path'))).toBe(defaultTenantId); - expect(getTenantId(new URL('https://localhost:3002'))).toBe(defaultTenantId); + await expect(getTenantId(new URL('http://localhost:3002/some/path////'))).resolves.toBe( + adminTenantId + ); + await expect(getTenantId(new URL('http://localhost:30021/some/path'))).resolves.toBe( + defaultTenantId + ); + await expect(getTenantId(new URL('http://localhostt:30021/some/path'))).resolves.toBe( + defaultTenantId + ); + await expect(getTenantId(new URL('https://localhost:3002'))).resolves.toBe(defaultTenantId); }); it('should resolve proper tenant ID for similar domain endpoints', async () => { @@ -56,14 +69,24 @@ describe('getTenantId()', () => { ENDPOINT: 'https://foo.*.logto.mock/app', }; - expect(getTenantId(new URL('https://foo.foo.logto.mock/app///asdasd'))).toBe('foo'); - expect(getTenantId(new URL('https://foo.*.logto.mock/app'))).toBe(undefined); - expect(getTenantId(new URL('https://foo.foo.logto.mockk/app///asdasd'))).toBe(undefined); - expect(getTenantId(new URL('https://foo.foo.logto.mock/appp'))).toBe(undefined); - expect(getTenantId(new URL('https://foo.foo.logto.mock:1/app/'))).toBe(undefined); - expect(getTenantId(new URL('http://foo.foo.logto.mock/app'))).toBe(undefined); - expect(getTenantId(new URL('https://user.foo.bar.logto.mock/app'))).toBe(undefined); - expect(getTenantId(new URL('https://foo.bar.bar.logto.mock/app'))).toBe(undefined); + await expect(getTenantId(new URL('https://foo.foo.logto.mock/app///asdasd'))).resolves.toBe( + 'foo' + ); + await expect(getTenantId(new URL('https://foo.*.logto.mock/app'))).resolves.toBe(undefined); + await expect(getTenantId(new URL('https://foo.foo.logto.mockk/app///asdasd'))).resolves.toBe( + undefined + ); + await expect(getTenantId(new URL('https://foo.foo.logto.mock/appp'))).resolves.toBe(undefined); + await expect(getTenantId(new URL('https://foo.foo.logto.mock:1/app/'))).resolves.toBe( + undefined + ); + await expect(getTenantId(new URL('http://foo.foo.logto.mock/app'))).resolves.toBe(undefined); + await expect(getTenantId(new URL('https://user.foo.bar.logto.mock/app'))).resolves.toBe( + undefined + ); + await expect(getTenantId(new URL('https://foo.bar.bar.logto.mock/app'))).resolves.toBe( + undefined + ); }); it('should resolve proper tenant ID if admin localhost is disabled', async () => { @@ -76,11 +99,17 @@ describe('getTenantId()', () => { ADMIN_DISABLE_LOCALHOST: '1', }; - expect(getTenantId(new URL('http://localhost:5000/app///asdasd'))).toBe(undefined); - expect(getTenantId(new URL('http://localhost:3002/app///asdasd'))).toBe(undefined); - expect(getTenantId(new URL('https://user.foo.logto.mock/app'))).toBe('foo'); - expect(getTenantId(new URL('https://user.admin.logto.mock/app//'))).toBe(undefined); // Admin endpoint is explicitly set - expect(getTenantId(new URL('https://admin.logto.mock/app'))).toBe(adminTenantId); + await expect(getTenantId(new URL('http://localhost:5000/app///asdasd'))).resolves.toBe( + undefined + ); + await expect(getTenantId(new URL('http://localhost:3002/app///asdasd'))).resolves.toBe( + undefined + ); + await expect(getTenantId(new URL('https://user.foo.logto.mock/app'))).resolves.toBe('foo'); + await expect(getTenantId(new URL('https://user.admin.logto.mock/app//'))).resolves.toBe( + undefined + ); // Admin endpoint is explicitly set + await expect(getTenantId(new URL('https://admin.logto.mock/app'))).resolves.toBe(adminTenantId); process.env = { ...backupEnv, @@ -89,7 +118,9 @@ describe('getTenantId()', () => { ENDPOINT: 'https://user.*.logto.mock/app', ADMIN_DISABLE_LOCALHOST: '1', }; - expect(getTenantId(new URL('https://user.admin.logto.mock/app//'))).toBe('admin'); + await expect(getTenantId(new URL('https://user.admin.logto.mock/app//'))).resolves.toBe( + 'admin' + ); }); it('should resolve proper tenant ID for path-based multi-tenancy', async () => { @@ -101,11 +132,25 @@ describe('getTenantId()', () => { PATH_BASED_MULTI_TENANCY: '1', }; - expect(getTenantId(new URL('http://localhost:5000/app///asdasd'))).toBe('app'); - expect(getTenantId(new URL('http://localhost:3002///bar///asdasd'))).toBe(adminTenantId); - expect(getTenantId(new URL('https://user.foo.logto.mock/app'))).toBe(undefined); - expect(getTenantId(new URL('https://user.admin.logto.mock/app//'))).toBe(undefined); - expect(getTenantId(new URL('https://user.logto.mock/app'))).toBe(undefined); - expect(getTenantId(new URL('https://user.logto.mock/app/admin'))).toBe('admin'); + await expect(getTenantId(new URL('http://localhost:5000/app///asdasd'))).resolves.toBe('app'); + await expect(getTenantId(new URL('http://localhost:3002///bar///asdasd'))).resolves.toBe( + adminTenantId + ); + await expect(getTenantId(new URL('https://user.foo.logto.mock/app'))).resolves.toBe(undefined); + await expect(getTenantId(new URL('https://user.admin.logto.mock/app//'))).resolves.toBe( + undefined + ); + await expect(getTenantId(new URL('https://user.logto.mock/app'))).resolves.toBe(undefined); + await expect(getTenantId(new URL('https://user.logto.mock/app/admin'))).resolves.toBe('admin'); + }); + + it('should resolve proper custom domain', async () => { + process.env = { + ...backupEnv, + ENDPOINT: 'https://foo.*.logto.mock/app', + NODE_ENV: 'production', + }; + findActiveDomain.mockResolvedValueOnce({ domain: 'logto.mock.com', tenantId: 'mock' }); + await expect(getTenantId(new URL('https://logto.mock.com'))).resolves.toBe('mock'); }); }); diff --git a/packages/core/src/utils/tenant.ts b/packages/core/src/utils/tenant.ts index 8c045ece4..75c790944 100644 --- a/packages/core/src/utils/tenant.ts +++ b/packages/core/src/utils/tenant.ts @@ -1,10 +1,12 @@ import { adminTenantId, defaultTenantId } from '@logto/schemas'; -import type { UrlSet } from '@logto/shared'; -import { conditionalString } from '@silverhand/essentials'; +import { type UrlSet } from '@logto/shared'; +import { conditionalString, trySafe } from '@silverhand/essentials'; +import { type CommonQueryMethods } from 'slonik'; +import { redisCache } from '#src/caches/index.js'; import { EnvSet, getTenantEndpoint } from '#src/env-set/index.js'; - -import { consoleLog } from './console.js'; +import { createDomainsQueries } from '#src/queries/domains.js'; +import { consoleLog } from '#src/utils/console.js'; const normalizePathname = (pathname: string) => pathname + conditionalString(!pathname.endsWith('/') && '/'); @@ -43,16 +45,45 @@ const matchPathBasedTenantId = (urlSet: UrlSet, url: URL) => { return urlSegments[found.pathname === '/' ? 1 : endpointSegments.length]; }; -export const getTenantId = (url: URL) => { +const cacheKey = 'custom-domain'; +const notFoundValue = 'not-found'; +const getDomainCacheKey = (url: URL) => `${cacheKey}:${url.hostname}`; + +const getTenantIdFromCustomDomain = async ( + url: URL, + pool: CommonQueryMethods +): Promise => { + const cachedValue = await trySafe(async () => redisCache.get(getDomainCacheKey(url))); + + if (cachedValue) { + return cachedValue === notFoundValue ? undefined : cachedValue; + } + + const { findActiveDomain } = createDomainsQueries(pool); + + const domain = await findActiveDomain(url.hostname); + + await trySafe(async () => + redisCache.set(getDomainCacheKey(url), domain?.tenantId ?? notFoundValue, 60) + ); + + return domain?.tenantId; +}; + +export const getTenantId = async (url: URL) => { const { - isMultiTenancy, - isPathBasedMultiTenancy, - isProduction, - isIntegrationTest, - developmentTenantId, - urlSet, - adminUrlSet, - } = EnvSet.values; + values: { + isMultiTenancy, + isPathBasedMultiTenancy, + isProduction, + isIntegrationTest, + developmentTenantId, + urlSet, + adminUrlSet, + }, + sharedPool, + } = EnvSet; + const pool = await sharedPool; if (adminUrlSet.deduplicated().some((endpoint) => isEndpointOf(url, endpoint))) { return adminTenantId; @@ -72,5 +103,11 @@ export const getTenantId = (url: URL) => { return matchPathBasedTenantId(urlSet, url); } + const customDomainTenantId = await getTenantIdFromCustomDomain(url, pool); + + if (customDomainTenantId) { + return customDomainTenantId; + } + return matchDomainBasedTenantId(urlSet.endpoint, url); };