diff --git a/packages/core/src/database/utils.ts b/packages/core/src/database/utils.ts index d669d33ac..59a048c38 100644 --- a/packages/core/src/database/utils.ts +++ b/packages/core/src/database/utils.ts @@ -1,10 +1,13 @@ -import { IdentifierSqlTokenType, sql, ValueExpressionType } from 'slonik'; +import { IdentifierSqlTokenType, sql } from 'slonik'; type Table = { table: string; fields: Record }; type FieldIdentifiers = { [key in Key]: IdentifierSqlTokenType; }; +const convertToPrimitive = (value: T) => + value !== null && typeof value === 'object' ? JSON.stringify(value) : value; + export const convertToIdentifiers = ( { table, fields }: T, withPrefix = false @@ -20,14 +23,23 @@ export const convertToIdentifiers = ( ), }); -export const insertInto = ( +export const insertInto = ( table: IdentifierSqlTokenType, - fields: FieldIdentifiers, - fieldKeys: readonly T[], - value: { [key in T]?: ValueExpressionType } + fields: FieldIdentifiers, + fieldKeys: readonly Key[], + value: { [key in Key]?: Type[key] } ) => sql` - insert into ${table} (${sql.join(Object.values(fields), sql`, `)}) + insert into ${table} (${sql.join( + fieldKeys.map((key) => fields[key]), + sql`, ` +)}) values (${sql.join( - fieldKeys.map((key) => value[key] ?? null), + fieldKeys.map((key) => convertToPrimitive(value[key] ?? null)), sql`, ` )})`; + +export const setExcluded = (...fields: IdentifierSqlTokenType[]) => + sql.join( + fields.map((field) => sql`${field}=excluded.${field}`), + sql`, ` + ); diff --git a/packages/core/src/oidc/adapter.ts b/packages/core/src/oidc/adapter.ts index e57b72b7a..8481e4510 100644 --- a/packages/core/src/oidc/adapter.ts +++ b/packages/core/src/oidc/adapter.ts @@ -1,106 +1,21 @@ -import dayjs from 'dayjs'; import { AdapterFactory } from 'oidc-provider'; -import { IdentifierSqlTokenType, sql, ValueExpressionType } from 'slonik'; -import { conditional } from '@logto/essentials'; import { - OidcModelInstances, - OidcModelInstanceDBEntry, - OidcModelInstancePayload, -} from '@logto/schemas'; -import pool from '@/database/pool'; -import { convertToIdentifiers } from '@/database/utils'; + consumeInstanceById, + destoryInstanceById, + findPayloadById, + findPayloadByPayloadField, + revokeInstanceByGrantId, + upsertInstance, +} from '@/queries/oidc-adapter'; -export default function postgresAdapter(modelName: string) { - const { table, fields } = convertToIdentifiers(OidcModelInstances); - - type WithConsumed = T & { consumed?: boolean }; - const withConsumed = (data: T, consumedAt?: number): WithConsumed => ({ - ...data, - ...(consumedAt ? { consumed: true } : undefined), - }); - type QueryResult = Pick; - const convertResult = (result: QueryResult | null) => - conditional(result && withConsumed(result.payload, result.consumedAt)); - const setExcluded = (...fields: IdentifierSqlTokenType[]) => - sql.join( - fields.map((field) => sql`${field}=excluded.${field}`), - sql`, ` - ); - - const findByField = async ( - field: IdentifierSqlTokenType, - value: T - ) => { - const result = await pool.maybeOne(sql` - select ${fields.payload}, ${fields.consumedAt} - from ${table} - where ${fields.modelName}=${modelName} - and ${field}=${value} - `); - - return convertResult(result); +export default function postgresAdapter(modelName: string): ReturnType { + return { + upsert: async (id, payload, expiresIn) => upsertInstance(modelName, id, payload, expiresIn), + find: async (id) => findPayloadById(modelName, id), + findByUserCode: async (userCode) => findPayloadByPayloadField(modelName, 'userCode', userCode), + findByUid: async (uid) => findPayloadByPayloadField(modelName, 'uid', uid), + consume: async (id) => consumeInstanceById(modelName, id), + destroy: async (id) => destoryInstanceById(modelName, id), + revokeByGrantId: async (grantId) => revokeInstanceByGrantId(modelName, grantId), }; - - const findByPayloadField = async < - T extends ValueExpressionType, - Field extends keyof OidcModelInstancePayload - >( - field: Field, - value: T - ) => { - const result = await pool.maybeOne(sql` - select ${fields.payload}, ${fields.consumedAt} - from ${table} - where ${fields.modelName}=${modelName} - and ${fields.payload}->>${field}=${value} - `); - - return convertResult(result); - }; - - const adapter: ReturnType = { - upsert: async (id, payload, expiresIn) => { - await pool.query(sql` - insert into ${table} (${sql.join( - [fields.modelName, fields.id, fields.payload, fields.expiresAt], - sql`, ` - )}) - values ( - ${modelName}, - ${id}, - ${JSON.stringify(payload)}, - ${dayjs().add(expiresIn, 'second').unix()} - ) - on conflict (${fields.modelName}, ${fields.id}) do update - set ${setExcluded(fields.payload, fields.expiresAt)} - `); - }, - find: async (id) => findByField(fields.id, id), - findByUserCode: async (userCode) => findByPayloadField('userCode', userCode), - findByUid: async (uid) => findByPayloadField('uid', uid), - consume: async (id) => { - await pool.query(sql` - update ${table} - set ${fields.consumedAt}=${dayjs().unix()} - where ${fields.modelName}=${modelName} - and ${fields.id}=${id} - `); - }, - destroy: async (id) => { - await pool.query(sql` - delete from ${table} - where ${fields.modelName}=${modelName} - and ${fields.id}=${id} - `); - }, - revokeByGrantId: async (grantId) => { - await pool.query(sql` - delete from ${table} - where ${fields.modelName}=${modelName} - and ${fields.payload}->>'grantId'=${grantId} - `); - }, - }; - - return adapter; } diff --git a/packages/core/src/queries/oidc-adapter.ts b/packages/core/src/queries/oidc-adapter.ts new file mode 100644 index 000000000..35addb5cf --- /dev/null +++ b/packages/core/src/queries/oidc-adapter.ts @@ -0,0 +1,104 @@ +import pool from '@/database/pool'; +import { convertToIdentifiers, insertInto, setExcluded } from '@/database/utils'; +import { conditional } from '@logto/essentials'; +import { + OidcModelInstanceDBEntry, + OidcModelInstancePayload, + OidcModelInstances, +} from '@logto/schemas'; +import dayjs from 'dayjs'; +import { sql, ValueExpressionType } from 'slonik'; + +export type WithConsumed = T & { consumed?: boolean }; +export type QueryResult = Pick; + +const { table, fields } = convertToIdentifiers(OidcModelInstances); + +const withConsumed = (data: T, consumedAt?: number): WithConsumed => ({ + ...data, + ...(consumedAt ? { consumed: true } : undefined), +}); + +const convertResult = (result: QueryResult | null) => + conditional(result && withConsumed(result.payload, result.consumedAt)); + +export const upsertInstance = async ( + modelName: string, + id: string, + payload: OidcModelInstancePayload, + expiresIn: number +) => { + await pool.query( + sql` + ${insertInto( + table, + fields, + ['modelName', 'id', 'payload', 'expiresAt'], + { + modelName, + id, + payload, + expiresAt: dayjs().add(expiresIn, 'second').unix(), + } + )} + on conflict (${fields.modelName}, ${fields.id}) do update + set ${setExcluded(fields.payload, fields.expiresAt)} + ` + ); +}; + +const findByModel = (modelName: string) => sql` + select ${fields.payload}, ${fields.consumedAt} + from ${table} + where ${fields.modelName}=${modelName} +`; + +export const findPayloadById = async (modelName: string, id: string) => { + const result = await pool.maybeOne(sql` + ${findByModel(modelName)} + and ${fields.id}=${id} + `); + + return convertResult(result); +}; + +export const findPayloadByPayloadField = async < + T extends ValueExpressionType, + Field extends keyof OidcModelInstancePayload +>( + modelName: string, + field: Field, + value: T +) => { + const result = await pool.maybeOne(sql` + ${findByModel(modelName)} + and ${fields.payload}->>${field}=${value} + `); + + return convertResult(result); +}; + +export const consumeInstanceById = async (modelName: string, id: string) => { + await pool.query(sql` + update ${table} + set ${fields.consumedAt}=${dayjs().unix()} + where ${fields.modelName}=${modelName} + and ${fields.id}=${id} + `); +}; + +export const destoryInstanceById = async (modelName: string, id: string) => { + await pool.query(sql` + delete from ${table} + where ${fields.modelName}=${modelName} + and ${fields.id}=${id} + `); +}; + +export const revokeInstanceByGrantId = async (modelName: string, grantId: string) => { + await pool.query(sql` + delete from ${table} + where ${fields.modelName}=${modelName} + and ${fields.payload}->>'grantId'=${grantId} + `); +};