0
Fork 0
mirror of https://github.com/logto-io/logto.git synced 2024-12-16 20:26:19 -05:00

feat(core): add token usage guard

add token usage guard
This commit is contained in:
simeng-li 2024-12-11 18:35:33 +08:00
parent c70e6ecfa7
commit 87a8687c98
No known key found for this signature in database
GPG key ID: 14EA7BB1541E8075
11 changed files with 231 additions and 11 deletions

View file

@ -147,7 +147,9 @@ export abstract class BaseCache<CacheMapT extends Record<string, unknown>> {
const cachedValue = await trySafe(kvCache.get(type, promiseKey));
if (cachedValue) {
cacheConsole.info(`${kvCache.name} cache hit for', type, promiseKey`);
cacheConsole.info(
`${kvCache.name} cache hit for, ${kvCache.tenantId}, ${type}, ${promiseKey}`
);
return cachedValue;
}

View file

@ -23,10 +23,8 @@ function getValueGuard(type: SubscriptionCacheType): ZodType<SubscriptionCacheMa
* A local region cache for tenant subscription data.
* We use this cache to reduce the number of requests to the Cloud
* and improve the performance of subscription-related operations.
*
* TODO: Will use the cache for tenant subscription data.
*/
class TenantSubscriptionCache extends BaseCache<SubscriptionCacheMap> {
export class TenantSubscriptionCache extends BaseCache<SubscriptionCacheMap> {
name = 'Tenant Subscription';
getValueGuard = getValueGuard;
}

View file

@ -0,0 +1,97 @@
import { SubscriptionRedisCacheKey } from '@logto/schemas';
import { TtlCache } from '@logto/shared';
import { TenantSubscriptionCache } from '#src/caches/tenant-subscription.js';
import { type CacheStore } from '#src/caches/types.js';
import { cacheConsole } from '#src/caches/utils.js';
import type Queries from '#src/tenants/Queries.js';
import { getTenantSubscription } from '#src/utils/subscription/index.js';
import { type Subscription } from '#src/utils/subscription/types.js';
import { type CloudConnectionLibrary } from './cloud-connection.js';
/**
* Return the expiration time of the subscription cache in seconds.
*
* @param currentPeriodEnd The end date of the current subscription period.
*/
const getSubscriptionCacheExpiration = (currentPeriodEnd: string) => {
const defaultExpiration = 60 * 60 * 24; // 1 day
const expiration = Math.floor((new Date(currentPeriodEnd).getTime() - Date.now()) / 1000);
return expiration > 0 ? expiration : defaultExpiration;
};
const tokenUsageCacheTtl = 60 * 60 * 1000; // 1 hour
export class SubscriptionLibrary {
/**
* Get the subscription data for the tenant with caching.
*
* @remarks
* This method will retrieve the subscription data (without usages) from the Cloud service
* with redis caching.
*
* - The cache will be automatically invalidated when the subscription period ends.
* - Any tenant subscription updates at the Cloud service side will also invalidate the cache.
*/
public readonly getSubscriptionData: () => Promise<Subscription>;
/**
* Tenant subscription data redis cache.
*/
private readonly subscriptionCache;
/**
* Tenant token usage TtlCache
* We use this to reduce the token usage calculation queries.
* Each token request will trigger a token usage validation.
* We don't want to calculate the latest token usage for each request.
* Using this cache, we can reduce the number of queries to the database.
*/
private readonly tokenUsageCache = new TtlCache<string, number>(tokenUsageCacheTtl);
constructor(
public readonly tenantId: string,
public readonly queries: Queries,
public readonly cloudConnection: CloudConnectionLibrary,
cache: CacheStore
) {
this.subscriptionCache = new TenantSubscriptionCache(tenantId, cache);
this.getSubscriptionData = this.subscriptionCache.memoize(
async () => getTenantSubscription(this.cloudConnection),
[SubscriptionRedisCacheKey.Subscription],
({ currentPeriodEnd }) => getSubscriptionCacheExpiration(currentPeriodEnd)
);
}
/**
* Get the tenant token usage for the given period.
* This method will use the local TTL cache to reduce the number of queries to the database.
* The cache will be invalidated every hour.
*/
public async getTenantTokenUsage({ from, to }: { from: Date; to: Date }) {
const cacheKey = this.buildTokenUsageKey({ tenantId: this.tenantId, from, to });
const cachedValue = this.tokenUsageCache.get(cacheKey);
if (cachedValue !== undefined) {
cacheConsole.info(`Tenant token usage TTL cache hit for: ${cacheKey}`);
return cachedValue;
}
const { tokenUsage } = await this.queries.dailyTokenUsage.countTokenUsage({
from,
to,
});
this.tokenUsageCache.set(cacheKey, tokenUsage);
return tokenUsage;
}
private buildTokenUsageKey({ tenantId, from, to }: { tenantId: string; from: Date; to: Date }) {
return `${tenantId}:${from.toISOString().split('T')[0]}:${
to.toISOString().split('T')[0]
}:token-usage`;
}
}

View file

@ -0,0 +1,78 @@
import { appInsights } from '@logto/app-insights/node';
import { adminTenantId, ReservedPlanId } from '@logto/schemas';
import { type Nullable } from '@silverhand/essentials';
import { type MiddlewareType } from 'koa';
import RequestError from '#src/errors/RequestError/index.js';
import { type SubscriptionLibrary } from '#src/libraries/subscription.js';
import assertThat from '#src/utils/assert-that.js';
import { buildAppInsightsTelemetry } from '#src/utils/request.js';
const guardedPlanIds = new Set<string>([ReservedPlanId.Free, ReservedPlanId.Development]);
/**
* This middleware will be applied to the /token endpoint to validate the current tenant's token usage.
* If the tenant has exceeded the token usage, the middleware will reject the request.
*/
export default function koaTokenUsageGuard<StateT, ContextT, ResponseBodyT>(
subscriptionLibrary: SubscriptionLibrary
): MiddlewareType<StateT, ContextT, Nullable<ResponseBodyT>> {
return async (ctx, next) => {
const { path } = ctx;
if (path !== '/token') {
return next();
}
/**
* Skip the token usage guard for the admin tenant.
*
* Notice:
* We need to skip the subscription check for the admin tenant.
* Not only because there is no token usage limit for the admin tenant,
* but also because cloud connection API need to retrieve the access token from the admin tenant,
* in order to make the request to the cloud service.
* Otherwise it will cause an infinite loop.
*/
if (subscriptionLibrary.tenantId === adminTenantId) {
return next();
}
try {
const {
planId,
currentPeriodEnd,
currentPeriodStart,
quota: { tokenLimit },
} = await subscriptionLibrary.getSubscriptionData();
if (!guardedPlanIds.has(planId)) {
await next();
return;
}
const tokenUsage = await subscriptionLibrary.getTenantTokenUsage({
from: new Date(currentPeriodStart),
to: new Date(currentPeriodEnd),
});
assertThat(
tokenLimit === null || tokenUsage < tokenLimit,
new RequestError({
code: 'auth.exceed_token_limit',
status: 429,
})
);
} catch (error: unknown) {
if (error instanceof RequestError) {
throw error;
}
// Incase of any unexpected error, track it to App Insights and continue the request.
// Should not block the end-user's request for any unexpected error.
void appInsights.trackException(error, buildAppInsightsTelemetry(ctx));
}
return next();
};
}

View file

@ -5,10 +5,10 @@ import initOidc from './init.js';
describe('oidc provider init', () => {
it('init should not throw', async () => {
const { queries, libraries, logtoConfigs, cloudConnection } = new MockTenant();
const { queries, libraries, logtoConfigs, cloudConnection, subscription } = new MockTenant();
expect(() =>
initOidc(mockEnvSet, queries, libraries, logtoConfigs, cloudConnection)
initOidc(mockEnvSet, queries, libraries, logtoConfigs, cloudConnection, subscription)
).not.toThrow();
});
});

View file

@ -22,7 +22,7 @@ import { Provider, errors } from 'oidc-provider';
import getRawBody from 'raw-body';
import snakecaseKeys from 'snakecase-keys';
import { type EnvSet } from '#src/env-set/index.js';
import { EnvSet } from '#src/env-set/index.js';
import { addOidcEventListeners } from '#src/event-listeners/index.js';
import { type CloudConnectionLibrary } from '#src/libraries/cloud-connection.js';
import { type LogtoConfigLibrary } from '#src/libraries/logto-config.js';
@ -39,6 +39,9 @@ import {
import type Libraries from '#src/tenants/Libraries.js';
import type Queries from '#src/tenants/Queries.js';
import { type SubscriptionLibrary } from '../libraries/subscription.js';
import koaTokenUsageGuard from '../middleware/koa-token-usage-guard.js';
import defaults from './defaults.js';
import {
getExtraTokenClaimsForJwtCustomization,
@ -63,7 +66,8 @@ export default function initOidc(
queries: Queries,
libraries: Libraries,
logtoConfigs: LogtoConfigLibrary,
cloudConnection: CloudConnectionLibrary
cloudConnection: CloudConnectionLibrary,
subscription: SubscriptionLibrary
): Provider {
const {
resources: { findDefaultResource },
@ -414,6 +418,12 @@ export default function initOidc(
oidc.use(koaAppSecretTranspilation(queries));
oidc.use(koaBodyEtag());
// TODO: Remove the devFeature guard when the implementation is stable
// Only enabled in the cloud environment
if (EnvSet.values.isDevFeaturesEnabled && EnvSet.values.isCloud) {
oidc.use(koaTokenUsageGuard(subscription));
}
return oidc;
}
/* eslint-enable max-lines */

View file

@ -41,7 +41,21 @@ export const createDailyTokenUsageQueries = (pool: CommonQueryMethods) => {
returning ${sql.join(Object.values(fields), sql`, `)}
`);
const countTokenUsage = async ({ from, to }: { from: Date; to: Date }) => {
return pool.one<{ tokenUsage: number }>(sql`
select sum(${fields.usage}) as token_usage
from ${table}
where ${fields.date} >= to_timestamp(${getUtcStartOfTheDay(
from
).getTime()}::double precision / 1000)
and ${fields.date} < to_timestamp(${getUtcStartOfTheDay(
to
).getTime()}::double precision / 1000)
`);
};
return {
recordTokenUsage,
countTokenUsage,
};
};

View file

@ -30,6 +30,9 @@ import initApis from '#src/routes/init.js';
import initMeApis from '#src/routes-me/init.js';
import BasicSentinel from '#src/sentinel/basic-sentinel.js';
import { redisCache } from '../caches/index.js';
import { SubscriptionLibrary } from '../libraries/subscription.js';
import Libraries from './Libraries.js';
import Queries from './Queries.js';
import type TenantContext from './TenantContext.js';
@ -89,7 +92,8 @@ export default class Tenant implements TenantContext {
cloudConnection,
logtoConfigs
),
public readonly sentinel = new BasicSentinel(envSet.pool)
public readonly sentinel = new BasicSentinel(envSet.pool),
public readonly subscription = new SubscriptionLibrary(id, queries, cloudConnection, redisCache)
) {
const isAdminTenant = id === adminTenantId;
const mountedApps = [
@ -111,7 +115,14 @@ export default class Tenant implements TenantContext {
app.use(koaSecurityHeaders(mountedApps, id));
// Mount OIDC
const provider = initOidc(envSet, queries, libraries, logtoConfigs, cloudConnection);
const provider = initOidc(
envSet,
queries,
libraries,
logtoConfigs,
cloudConnection,
subscription
);
app.use(mount('/oidc', provider.app));
const tenantContext: TenantContext = {

View file

@ -2,6 +2,7 @@ import { type Sentinel } from '@logto/schemas';
import { TtlCache } from '@logto/shared';
import { createMockPool, createMockQueryResult } from '@silverhand/slonik';
import { redisCache } from '#src/caches/index.js';
import { WellKnownCache } from '#src/caches/well-known.js';
import type { CloudConnectionLibrary } from '#src/libraries/cloud-connection.js';
import { createCloudConnectionLibrary } from '#src/libraries/cloud-connection.js';
@ -12,6 +13,8 @@ import Libraries from '#src/tenants/Libraries.js';
import Queries from '#src/tenants/Queries.js';
import type TenantContext from '#src/tenants/TenantContext.js';
import { SubscriptionLibrary } from '../libraries/subscription.js';
import { mockEnvSet } from './env-set.js';
import type { GrantMock } from './oidc-provider.js';
import { createMockProvider } from './oidc-provider.js';
@ -67,6 +70,7 @@ export class MockTenant implements TenantContext {
public connectors: ConnectorLibrary;
public libraries: Libraries;
public sentinel: Sentinel;
public readonly subscription: SubscriptionLibrary;
// eslint-disable-next-line max-params
constructor(
@ -93,6 +97,12 @@ export class MockTenant implements TenantContext {
);
this.setPartial('libraries', librariesOverride);
this.sentinel = new MockSentinel();
this.subscription = new SubscriptionLibrary(
this.id,
this.queries,
this.cloudConnection,
redisCache
);
}
public async invalidateCache() {

View file

@ -113,7 +113,6 @@ export const subscriptionCacheGuard = z.object({
currentPeriodStart: z.string(),
currentPeriodEnd: z.string(),
isEnterprisePlan: z.boolean(),
isAddOnAvailable: z.boolean(),
status: subscriptionStatusGuard,
upcomingInvoice: upcomingInvoiceGuard.nullable().optional(),
quota: logtoSkuQuotaGuard,

View file

@ -6,6 +6,7 @@ const auth = {
expected_role_not_found: 'Expected role not found. Please check your user roles and permissions.',
jwt_sub_missing: 'Missing `sub` in JWT.',
require_re_authentication: 'Re-authentication is required to perform a protected action.',
exceed_token_limit: 'Token limit exceeded. Please contact your administrator.',
};
export default Object.freeze(auth);