0
Fork 0
mirror of https://github.com/logto-io/logto.git synced 2025-01-13 21:30:30 -05:00

feat(core): add token usage guard (#6877)

* feat(core): add token usage guard

add token usage guard

* test(core): add unit test

add unit test

* refactor(core): update the token usage cache strategy

udpate the token usage cache strategy

* fix(core): fix unit test

fix unit test
This commit is contained in:
simeng-li 2024-12-20 13:51:47 +08:00 committed by GitHub
parent 588ed34e12
commit ef795299ce
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
13 changed files with 455 additions and 11 deletions

View file

@ -1,4 +1,7 @@
import { ReservedPlanId } from '@logto/schemas';
import { type CloudConnectionLibrary } from '#src/libraries/cloud-connection.js';
import { type Subscription } from '#src/utils/subscription/types.js';
export const mockGetCloudConnectionData: CloudConnectionLibrary['getCloudConnectionData'] =
async () => ({
@ -8,3 +11,39 @@ export const mockGetCloudConnectionData: CloudConnectionLibrary['getCloudConnect
endpoint: 'https://logto.dev/api',
tokenEndpoint: 'https://logto.dev/oidc/token',
});
export const mockQuota = {
mauLimit: 50_000,
tokenLimit: 10_000,
applicationsLimit: 3,
machineToMachineLimit: 1,
resourcesLimit: 1,
scopesPerResourceLimit: 1,
socialConnectorsLimit: 3,
userRolesLimit: 1,
machineToMachineRolesLimit: 1,
scopesPerRoleLimit: 1,
hooksLimit: 1,
auditLogsRetentionDays: 3,
mfaEnabled: false,
/** @deprecated */
organizationsEnabled: false,
organizationsLimit: 0,
enterpriseSsoLimit: 0,
thirdPartyApplicationsLimit: 0,
tenantMembersLimit: 1,
customJwtEnabled: false,
subjectTokenEnabled: false,
bringYourUiEnabled: false,
idpInitiatedSsoEnabled: false,
};
export const mockSubscriptionData: Subscription = {
id: 'sub_123',
currentPeriodEnd: '2022-01-01T00:00:00Z',
currentPeriodStart: '2021-12-01T00:00:00Z',
planId: ReservedPlanId.Free,
isEnterprisePlan: false,
quota: mockQuota,
status: 'active',
};

View file

@ -147,7 +147,9 @@ export abstract class BaseCache<CacheMapT extends Record<string, unknown>> {
const cachedValue = await trySafe(kvCache.get(type, promiseKey));
if (cachedValue) {
cacheConsole.info(`${kvCache.name} cache hit for', type, promiseKey`);
cacheConsole.info(
`${kvCache.name} cache hit for, ${kvCache.tenantId}, ${type}, ${promiseKey}`
);
return cachedValue;
}

View file

@ -23,10 +23,8 @@ function getValueGuard(type: SubscriptionCacheType): ZodType<SubscriptionCacheMa
* A local region cache for tenant subscription data.
* We use this cache to reduce the number of requests to the Cloud
* and improve the performance of subscription-related operations.
*
* TODO: Will use the cache for tenant subscription data.
*/
class TenantSubscriptionCache extends BaseCache<SubscriptionCacheMap> {
export class TenantSubscriptionCache extends BaseCache<SubscriptionCacheMap> {
name = 'Tenant Subscription';
getValueGuard = getValueGuard;
}

View file

@ -0,0 +1,172 @@
import { ReservedPlanId } from '@logto/schemas';
import { createMockUtils } from '@logto/shared/esm';
import { mockSubscriptionData } from '#src/__mocks__/cloud-connection.js';
const { jest } = import.meta;
const { mockEsmWithActual } = createMockUtils(jest);
const mockGetTenantSubscription = jest.fn();
const mockCountTokenUsage = jest.fn();
const now = new Date();
// Set the current period end to 1 day from now
const currentPeriodEnd = new Date(now.getTime() + 1000 * 60 * 60 * 24);
const mockSubscription = {
...mockSubscriptionData,
currentPeriodEnd: currentPeriodEnd.toISOString(),
};
await mockEsmWithActual('#src/utils/subscription/index.js', () => ({
getTenantSubscription: mockGetTenantSubscription,
}));
const { MockTenant } = await import('#src/test-utils/tenant.js');
describe('get subscription data', () => {
const { subscription } = new MockTenant(undefined);
it('should get subscription data', async () => {
mockGetTenantSubscription.mockResolvedValueOnce(mockSubscription);
const subscriptionData = await subscription.getSubscriptionData();
expect(subscriptionData).toEqual(mockSubscription);
});
it('should get subscription data from cache', async () => {
mockGetTenantSubscription.mockClear();
const subscriptionDataFromCache = await subscription.getSubscriptionData();
expect(subscriptionDataFromCache).toEqual(mockSubscription);
expect(mockGetTenantSubscription).not.toHaveBeenCalled();
});
});
describe('get subscription data with cache expiration', () => {
const { subscription } = new MockTenant(undefined);
beforeAll(() => {
jest.useFakeTimers();
});
afterAll(() => {
jest.useRealTimers();
});
it('should get new subscription data if cache is expired', async () => {
mockGetTenantSubscription.mockResolvedValueOnce(mockSubscription);
const subscriptionData = await subscription.getSubscriptionData();
expect(subscriptionData).toEqual(mockSubscription);
// Move the time to 1 hour later
// In Unit test we use ttlCache instead of redis cache
// The ttl time unit is in milliseconds instead of seconds, so we do not need to multiply by 1000
jest.advanceTimersByTime(60 * 60);
mockGetTenantSubscription.mockClear();
// Should hit the cache
const subscriptionDataFromCache = await subscription.getSubscriptionData();
expect(subscriptionDataFromCache).toEqual(mockSubscription);
// Move the time to 1 day later
jest.advanceTimersByTime(60 * 60 * 24);
mockGetTenantSubscription.mockResolvedValueOnce({
...mockSubscriptionData,
planId: ReservedPlanId.Pro202411,
});
// Should get new subscription data
const refreshedSubscriptionData = await subscription.getSubscriptionData();
expect(refreshedSubscriptionData).toEqual({
...mockSubscriptionData,
planId: ReservedPlanId.Pro202411,
});
expect(mockGetTenantSubscription).toHaveBeenCalled();
});
});
describe('get tenant token usage', () => {
const { subscription } = new MockTenant(undefined, {
dailyTokenUsage: {
countTokenUsage: mockCountTokenUsage,
},
});
const from = new Date();
const to = new Date(from.valueOf() + 1000 * 60 * 60 * 24);
it('should get tenant token usage without cache', async () => {
mockCountTokenUsage.mockResolvedValueOnce({ tokenUsage: 100 });
const tokenUsage = await subscription.getTenantTokenUsage({
from,
to,
});
expect(tokenUsage).toBe(100);
});
it('should get tenant token usage from cache', async () => {
mockCountTokenUsage.mockClear();
const tokenUsageFromCache = await subscription.getTenantTokenUsage({
from,
to,
});
expect(tokenUsageFromCache).toBe(100);
expect(mockCountTokenUsage).not.toHaveBeenCalled();
});
it('should get new tenant token usage if the period is different', async () => {
mockCountTokenUsage.mockResolvedValueOnce({ tokenUsage: 200 });
const tokenUsage = await subscription.getTenantTokenUsage({
from,
to: new Date(to.valueOf() + 1000 * 60 * 60 * 24),
});
expect(tokenUsage).toBe(200);
expect(mockCountTokenUsage).toHaveBeenCalled();
});
});
describe('get tenant token usage with cache expiration', () => {
beforeAll(() => {
jest.useFakeTimers();
});
afterAll(() => {
jest.useRealTimers();
});
const tokenUsageCacheTtl = 60 * 60 * 1000; // 1 hour
const from = new Date();
const to = new Date(from.valueOf() + 1000 * 60 * 60 * 24);
it('should get new tenant token usage if cache is expired', async () => {
const { subscription } = new MockTenant(undefined, {
dailyTokenUsage: {
countTokenUsage: mockCountTokenUsage,
},
});
mockCountTokenUsage.mockResolvedValueOnce({ tokenUsage: 100 });
const tokenUsage = await subscription.getTenantTokenUsage({
from,
to,
});
expect(tokenUsage).toBe(100);
// Move the time to 30 minutes later
mockCountTokenUsage.mockClear();
jest.advanceTimersByTime(tokenUsageCacheTtl / 2);
const tokenUsageFromCache = await subscription.getTenantTokenUsage({
from,
to,
});
expect(tokenUsageFromCache).toBe(100);
expect(mockCountTokenUsage).not.toHaveBeenCalled();
// Move the time to 1 hour later
mockCountTokenUsage.mockResolvedValueOnce({ tokenUsage: 200 });
jest.advanceTimersByTime(tokenUsageCacheTtl / 2 + 1);
const refreshedTokenUsage = await subscription.getTenantTokenUsage({
from,
to,
});
expect(refreshedTokenUsage).toBe(200);
expect(mockCountTokenUsage).toHaveBeenCalled();
});
});

View file

@ -0,0 +1,111 @@
import { SubscriptionRedisCacheKey } from '@logto/schemas';
import { TtlCache } from '@logto/shared';
import { TenantSubscriptionCache } from '#src/caches/tenant-subscription.js';
import { type CacheStore } from '#src/caches/types.js';
import { cacheConsole } from '#src/caches/utils.js';
import type Queries from '#src/tenants/Queries.js';
import { getTenantSubscription } from '#src/utils/subscription/index.js';
import { type Subscription } from '#src/utils/subscription/types.js';
import { type CloudConnectionLibrary } from './cloud-connection.js';
/**
* Return the expiration time of the subscription cache in seconds.
*
* @param currentPeriodEnd The end date of the current subscription period.
*/
const getSubscriptionCacheExpiration = (currentPeriodEnd: string) => {
const expiration = Math.floor((new Date(currentPeriodEnd).getTime() - Date.now()) / 1000);
return Math.max(expiration, 0);
};
const tokenUsageCacheTtl = 60 * 60 * 1000; // 1 hour
/**
*
* @param to The end date of the token usage period.
*
* @returns The TTL for the token usage cache in milliseconds.
*
* @remarks
* - A maximum TTL of 1 hour is set for the token usage cache.
* - If the token usage period ends is more than an hour from now, the TTL will be 1 hour.
* - If the token usage period ends is less than an hour from now, the TTL will be the difference between the end date and now.
* - This is to ensure that the cache is invalidated immediately after the token usage period ends.
*/
const getTokenUsageCacheTtl = (to: Date) => {
const expiration = Math.floor(to.getTime() - Date.now());
return Math.min(expiration, tokenUsageCacheTtl);
};
export class SubscriptionLibrary {
/**
* Get the subscription data for the tenant with caching.
*
* @remarks
* This method will retrieve the subscription data (without usages) from the Cloud service
* with redis caching.
*
* - The cache will be automatically invalidated when the subscription period ends.
* - Any tenant subscription updates at the Cloud service side will also invalidate the cache.
*/
public readonly getSubscriptionData: () => Promise<Subscription>;
/**
* Tenant subscription data redis cache.
*/
private readonly subscriptionCache;
/**
* Tenant token usage TtlCache
* We use this to reduce the token usage calculation queries.
* Each token request will trigger a token usage validation.
* We don't want to calculate the latest token usage for each request.
* Using this cache, we can reduce the number of queries to the database.
*/
private readonly tokenUsageCache = new TtlCache<string, number>(tokenUsageCacheTtl);
constructor(
public readonly tenantId: string,
public readonly queries: Queries,
public readonly cloudConnection: CloudConnectionLibrary,
cache: CacheStore
) {
this.subscriptionCache = new TenantSubscriptionCache(tenantId, cache);
this.getSubscriptionData = this.subscriptionCache.memoize(
async () => getTenantSubscription(this.cloudConnection),
[SubscriptionRedisCacheKey.Subscription],
({ currentPeriodEnd }) => getSubscriptionCacheExpiration(currentPeriodEnd)
);
}
/**
* Get the tenant token usage for the given period.
* This method will use the local TTL cache to reduce the number of queries to the database.
* The cache will be invalidated every hour.
*/
public async getTenantTokenUsage({ from, to }: { from: Date; to: Date }) {
const cacheKey = this.buildTokenUsageKey({ tenantId: this.tenantId, from, to });
const cachedValue = this.tokenUsageCache.get(cacheKey);
if (cachedValue !== undefined) {
cacheConsole.info(`Tenant token usage TTL cache hit for: ${cacheKey}`);
return cachedValue;
}
const { tokenUsage } = await this.queries.dailyTokenUsage.countTokenUsage({
from,
to,
});
this.tokenUsageCache.set(cacheKey, tokenUsage, getTokenUsageCacheTtl(to));
return tokenUsage;
}
private buildTokenUsageKey({ tenantId, from, to }: { tenantId: string; from: Date; to: Date }) {
return `${tenantId}:${from.toISOString().split('T')[0]}:${
to.toISOString().split('T')[0]
}:token-usage`;
}
}

View file

@ -0,0 +1,78 @@
import { appInsights } from '@logto/app-insights/node';
import { adminTenantId, ReservedPlanId } from '@logto/schemas';
import { type Nullable } from '@silverhand/essentials';
import { type MiddlewareType } from 'koa';
import RequestError from '#src/errors/RequestError/index.js';
import { type SubscriptionLibrary } from '#src/libraries/subscription.js';
import assertThat from '#src/utils/assert-that.js';
import { buildAppInsightsTelemetry } from '#src/utils/request.js';
const guardedPlanIds = new Set<string>([ReservedPlanId.Free, ReservedPlanId.Development]);
/**
* This middleware will be applied to the /token endpoint to validate the current tenant's token usage.
* If the tenant has exceeded the token usage, the middleware will reject the request.
*/
export default function koaTokenUsageGuard<StateT, ContextT, ResponseBodyT>(
subscriptionLibrary: SubscriptionLibrary
): MiddlewareType<StateT, ContextT, Nullable<ResponseBodyT>> {
return async (ctx, next) => {
const { path } = ctx;
if (path !== '/token') {
return next();
}
/**
* Skip the token usage guard for the admin tenant.
*
* Notice:
* The token usage guard is skipped for the admin tenant.
* This is because the admin tenant has no token limit,
* and the cloud connection API needs to retrieve the access token for the admin tenant,
* to make requests to the cloud service. Checking the token usage for the admin tenant
* will result in an infinite loop.
*/
if (subscriptionLibrary.tenantId === adminTenantId) {
return next();
}
try {
const {
planId,
currentPeriodEnd,
currentPeriodStart,
quota: { tokenLimit },
} = await subscriptionLibrary.getSubscriptionData();
if (!guardedPlanIds.has(planId)) {
await next();
return;
}
const tokenUsage = await subscriptionLibrary.getTenantTokenUsage({
from: new Date(currentPeriodStart),
to: new Date(currentPeriodEnd),
});
assertThat(
tokenLimit === null || tokenUsage < tokenLimit,
new RequestError({
code: 'auth.exceed_token_limit',
status: 429,
})
);
} catch (error: unknown) {
if (error instanceof RequestError) {
throw error;
}
// Incase of any unexpected error, track it to App Insights and continue the request.
// Should not block the end-user's request for any unexpected error.
void appInsights.trackException(error, buildAppInsightsTelemetry(ctx));
}
return next();
};
}

View file

@ -5,10 +5,10 @@ import initOidc from './init.js';
describe('oidc provider init', () => {
it('init should not throw', async () => {
const { queries, libraries, logtoConfigs, cloudConnection } = new MockTenant();
const { queries, libraries, logtoConfigs, cloudConnection, subscription } = new MockTenant();
expect(() =>
initOidc(mockEnvSet, queries, libraries, logtoConfigs, cloudConnection)
initOidc(mockEnvSet, queries, libraries, logtoConfigs, cloudConnection, subscription)
).not.toThrow();
});
});

View file

@ -22,7 +22,7 @@ import { Provider, errors } from 'oidc-provider';
import getRawBody from 'raw-body';
import snakecaseKeys from 'snakecase-keys';
import { type EnvSet } from '#src/env-set/index.js';
import { EnvSet } from '#src/env-set/index.js';
import { addOidcEventListeners } from '#src/event-listeners/index.js';
import { type CloudConnectionLibrary } from '#src/libraries/cloud-connection.js';
import { type LogtoConfigLibrary } from '#src/libraries/logto-config.js';
@ -39,6 +39,9 @@ import {
import type Libraries from '#src/tenants/Libraries.js';
import type Queries from '#src/tenants/Queries.js';
import { type SubscriptionLibrary } from '../libraries/subscription.js';
import koaTokenUsageGuard from '../middleware/koa-token-usage-guard.js';
import defaults from './defaults.js';
import {
getExtraTokenClaimsForJwtCustomization,
@ -63,7 +66,8 @@ export default function initOidc(
queries: Queries,
libraries: Libraries,
logtoConfigs: LogtoConfigLibrary,
cloudConnection: CloudConnectionLibrary
cloudConnection: CloudConnectionLibrary,
subscription: SubscriptionLibrary
): Provider {
const {
resources: { findDefaultResource },
@ -414,6 +418,12 @@ export default function initOidc(
oidc.use(koaAppSecretTranspilation(queries));
oidc.use(koaBodyEtag());
// TODO: Remove the devFeature guard when the implementation is stable
// Only enabled in the cloud environment
if (EnvSet.values.isDevFeaturesEnabled && EnvSet.values.isCloud) {
oidc.use(koaTokenUsageGuard(subscription));
}
return oidc;
}
/* eslint-enable max-lines */

View file

@ -41,7 +41,21 @@ export const createDailyTokenUsageQueries = (pool: CommonQueryMethods) => {
returning ${sql.join(Object.values(fields), sql`, `)}
`);
const countTokenUsage = async ({ from, to }: { from: Date; to: Date }) => {
return pool.one<{ tokenUsage: number }>(sql`
select sum(${fields.usage}) as token_usage
from ${table}
where ${fields.date} >= to_timestamp(${getUtcStartOfTheDay(
from
).getTime()}::double precision / 1000)
and ${fields.date} < to_timestamp(${getUtcStartOfTheDay(
to
).getTime()}::double precision / 1000)
`);
};
return {
recordTokenUsage,
countTokenUsage,
};
};

View file

@ -30,6 +30,9 @@ import initApis from '#src/routes/init.js';
import initMeApis from '#src/routes-me/init.js';
import BasicSentinel from '#src/sentinel/basic-sentinel.js';
import { redisCache } from '../caches/index.js';
import { SubscriptionLibrary } from '../libraries/subscription.js';
import Libraries from './Libraries.js';
import Queries from './Queries.js';
import type TenantContext from './TenantContext.js';
@ -89,7 +92,8 @@ export default class Tenant implements TenantContext {
cloudConnection,
logtoConfigs
),
public readonly sentinel = new BasicSentinel(envSet.pool)
public readonly sentinel = new BasicSentinel(envSet.pool),
public readonly subscription = new SubscriptionLibrary(id, queries, cloudConnection, redisCache)
) {
const isAdminTenant = id === adminTenantId;
const mountedApps = [
@ -111,7 +115,14 @@ export default class Tenant implements TenantContext {
app.use(koaSecurityHeaders(mountedApps, id));
// Mount OIDC
const provider = initOidc(envSet, queries, libraries, logtoConfigs, cloudConnection);
const provider = initOidc(
envSet,
queries,
libraries,
logtoConfigs,
cloudConnection,
subscription
);
app.use(mount('/oidc', provider.app));
const tenantContext: TenantContext = {

View file

@ -12,6 +12,8 @@ import Libraries from '#src/tenants/Libraries.js';
import Queries from '#src/tenants/Queries.js';
import type TenantContext from '#src/tenants/TenantContext.js';
import { SubscriptionLibrary } from '../libraries/subscription.js';
import { mockEnvSet } from './env-set.js';
import type { GrantMock } from './oidc-provider.js';
import { createMockProvider } from './oidc-provider.js';
@ -67,6 +69,7 @@ export class MockTenant implements TenantContext {
public connectors: ConnectorLibrary;
public libraries: Libraries;
public sentinel: Sentinel;
public readonly subscription: SubscriptionLibrary;
// eslint-disable-next-line max-params
constructor(
@ -93,6 +96,12 @@ export class MockTenant implements TenantContext {
);
this.setPartial('libraries', librariesOverride);
this.sentinel = new MockSentinel();
this.subscription = new SubscriptionLibrary(
this.id,
this.queries,
this.cloudConnection,
new TtlCache<string, string>(60_000)
);
}
public async invalidateCache() {

View file

@ -113,7 +113,6 @@ export const subscriptionCacheGuard = z.object({
currentPeriodStart: z.string(),
currentPeriodEnd: z.string(),
isEnterprisePlan: z.boolean(),
isAddOnAvailable: z.boolean(),
status: subscriptionStatusGuard,
upcomingInvoice: upcomingInvoiceGuard.nullable().optional(),
quota: logtoSkuQuotaGuard,

View file

@ -6,6 +6,7 @@ const auth = {
expected_role_not_found: 'Expected role not found. Please check your user roles and permissions.',
jwt_sub_missing: 'Missing `sub` in JWT.',
require_re_authentication: 'Re-authentication is required to perform a protected action.',
exceed_token_limit: 'Token limit exceeded. Please contact your administrator.',
};
export default Object.freeze(auth);