diff --git a/packages/core/src/database/find-all-entities.ts b/packages/core/src/database/find-all-entities.ts index bdf0e1af3..cabcdd77e 100644 --- a/packages/core/src/database/find-all-entities.ts +++ b/packages/core/src/database/find-all-entities.ts @@ -2,6 +2,8 @@ import { type GeneratedSchema, type SchemaLike } from '@logto/schemas'; import { conditionalSql, convertToIdentifiers, manyRows } from '@logto/shared'; import { sql, type CommonQueryMethods } from 'slonik'; +import { buildSearchSql, type SearchOptions } from './utils.js'; + export const buildFindAllEntitiesWithPool = (pool: CommonQueryMethods) => < @@ -17,11 +19,16 @@ export const buildFindAllEntitiesWithPool = ) => { const { table, fields } = convertToIdentifiers(schema); - return async (limit?: number, offset?: number) => + return async ( + limit?: number, + offset?: number, + search?: SearchOptions + ) => manyRows( pool.query(sql` select ${sql.join(Object.values(fields), sql`, `)} from ${table} + ${buildSearchSql(search)} ${conditionalSql(orderBy, (orderBy) => { const orderBySql = orderBy.map(({ field, order }) => // Note: 'desc' and 'asc' are keywords, so we don't pass them as values diff --git a/packages/core/src/database/row-count.ts b/packages/core/src/database/row-count.ts index 399ea3ebe..ceff0437e 100644 --- a/packages/core/src/database/row-count.ts +++ b/packages/core/src/database/row-count.ts @@ -1,13 +1,17 @@ import type { CommonQueryMethods, IdentifierSqlToken } from 'slonik'; import { sql } from 'slonik'; +import { type SearchOptions, buildSearchSql } from './utils.js'; + export const buildGetTotalRowCountWithPool = - (pool: CommonQueryMethods, table: string) => async () => { + (pool: CommonQueryMethods, table: string) => + async (search?: SearchOptions) => { // Postgres returns a bigint for count(*), which is then converted to a string by query library. // We need to convert it to a number. const { count } = await pool.one<{ count: string }>(sql` select count(*) from ${sql.identifier([table])} + ${buildSearchSql(search)} `); return { count: Number(count) }; diff --git a/packages/core/src/database/utils.ts b/packages/core/src/database/utils.ts new file mode 100644 index 000000000..9f8075ee1 --- /dev/null +++ b/packages/core/src/database/utils.ts @@ -0,0 +1,23 @@ +import { conditionalSql } from '@logto/shared'; +import { sql } from 'slonik'; + +/** + * Options for searching for a string within a set of fields (case-insensitive). + * + * Note: `id` is excluded from the fields since it should be unique. + */ +export type SearchOptions = { + fields: ReadonlyArray>; + keyword: string; +}; + +export const buildSearchSql = (search?: SearchOptions) => { + return conditionalSql(search, (search) => { + const { fields: searchFields, keyword } = search; + const searchSql = sql.join( + searchFields.map((field) => sql`${sql.identifier([field])} ilike ${`%${keyword}%`}`), + sql` or ` + ); + return sql`where ${searchSql}`; + }); +}; diff --git a/packages/core/src/queries/organizations.ts b/packages/core/src/queries/organizations.ts index 6de52681a..8b292efd6 100644 --- a/packages/core/src/queries/organizations.ts +++ b/packages/core/src/queries/organizations.ts @@ -101,8 +101,11 @@ class OrganizationRolesQueries extends SchemaQueries< override async findAll( limit: number, offset: number - ): Promise> { - return this.pool.any(this.#findWithScopesSql(undefined, limit, offset)); + ): Promise<[totalNumber: number, rows: Readonly]> { + return Promise.all([ + this.findTotalNumber(), + this.pool.any(this.#findWithScopesSql(undefined, limit, offset)), + ]); } #findWithScopesSql(roleId?: string, limit = 1, offset = 0) { diff --git a/packages/core/src/routes/organization/index.ts b/packages/core/src/routes/organization/index.ts index 9c09d8c8a..560cc0dcd 100644 --- a/packages/core/src/routes/organization/index.ts +++ b/packages/core/src/routes/organization/index.ts @@ -21,6 +21,7 @@ export default function organizationRoutes(...args: Rout ] = args; const router = new SchemaRouter(Organizations, organizations, { errorHandler, + searchFields: ['name'], }); router.addRelationRoutes(organizations.relations.users); diff --git a/packages/core/src/routes/organization/roles.ts b/packages/core/src/routes/organization/roles.ts index eed24b4ce..50b474c4a 100644 --- a/packages/core/src/routes/organization/roles.ts +++ b/packages/core/src/routes/organization/roles.ts @@ -21,8 +21,8 @@ export default function organizationRoleRoutes( ) { const router = new SchemaRouter(OrganizationRoles, roles, { errorHandler, + searchFields: ['name'], }); - router.addRelationRoutes(rolesScopes, 'scopes'); originalRouter.use(router.routes()); diff --git a/packages/core/src/routes/organization/scopes.ts b/packages/core/src/routes/organization/scopes.ts index 12bd1eef3..b1622af3b 100644 --- a/packages/core/src/routes/organization/scopes.ts +++ b/packages/core/src/routes/organization/scopes.ts @@ -16,7 +16,10 @@ export default function organizationScopeRoutes( }, ]: RouterInitArgs ) { - const router = new SchemaRouter(OrganizationScopes, scopes, { errorHandler }); + const router = new SchemaRouter(OrganizationScopes, scopes, { + errorHandler, + searchFields: ['name'], + }); originalRouter.use(router.routes()); } diff --git a/packages/core/src/utils/SchemaQueries.ts b/packages/core/src/utils/SchemaQueries.ts index 78ca5bd53..cdde170ec 100644 --- a/packages/core/src/utils/SchemaQueries.ts +++ b/packages/core/src/utils/SchemaQueries.ts @@ -8,6 +8,7 @@ import { buildFindEntityByIdWithPool } from '#src/database/find-entity-by-id.js' import { buildInsertIntoWithPool } from '#src/database/insert-into.js'; import { buildGetTotalRowCountWithPool } from '#src/database/row-count.js'; import { buildUpdateWhereWithPool } from '#src/database/update-where.js'; +import { type SearchOptions } from '#src/database/utils.js'; /** * Query class that contains all the necessary CRUD queries for a schema. It is @@ -19,8 +20,16 @@ export default class SchemaQueries< CreateSchema extends Partial & { id: string }>, Schema extends SchemaLike & { id: string }, > { - #findTotalNumber: () => Promise<{ count: number }>; - #findAll: (limit: number, offset: number) => Promise; + #findTotalNumber: ( + search?: SearchOptions + ) => Promise<{ count: number }>; + + #findAll: ( + limit: number, + offset: number, + search?: SearchOptions + ) => Promise; + #findById: (id: string) => Promise>; #insert: (data: OmitAutoSetFields) => Promise>; @@ -43,13 +52,12 @@ export default class SchemaQueries< this.#deleteById = buildDeleteByIdWithPool(this.pool, this.schema.table); } - async findTotalNumber(): Promise { - const { count } = await this.#findTotalNumber(); - return count; - } - - async findAll(limit: number, offset: number): Promise { - return this.#findAll(limit, offset); + async findAll( + limit: number, + offset: number, + search?: SearchOptions + ): Promise<[totalNumber: number, rows: readonly Schema[]]> { + return Promise.all([this.findTotalNumber(search), this.#findAll(limit, offset, search)]); } async findById(id: string): Promise> { @@ -71,4 +79,9 @@ export default class SchemaQueries< async deleteById(id: string): Promise { await this.#deleteById(id); } + + protected async findTotalNumber(search?: SearchOptions): Promise { + const { count } = await this.#findTotalNumber(search); + return count; + } } diff --git a/packages/core/src/utils/SchemaRouter.test.ts b/packages/core/src/utils/SchemaRouter.test.ts index a643c4114..2f12a7902 100644 --- a/packages/core/src/utils/SchemaRouter.test.ts +++ b/packages/core/src/utils/SchemaRouter.test.ts @@ -38,8 +38,7 @@ describe('SchemaRouter', () => { ] as const satisfies readonly Schema[]; const queries = new SchemaQueries(createTestPool(undefined, { id: '1' }), schema); - jest.spyOn(queries, 'findTotalNumber').mockResolvedValue(entities.length); - jest.spyOn(queries, 'findAll').mockResolvedValue(entities); + jest.spyOn(queries, 'findAll').mockResolvedValue([entities.length, entities]); jest.spyOn(queries, 'findById').mockImplementation(async (id) => { const entity = entities.find((entity) => entity.id === id); if (!entity) { @@ -67,16 +66,14 @@ describe('SchemaRouter', () => { it('should be able to get all entities', async () => { const response = await request.get(baseRoute); - expect(queries.findAll).toHaveBeenCalledWith(20, 0); - expect(queries.findTotalNumber).toHaveBeenCalled(); + expect(queries.findAll).toHaveBeenCalledWith(20, 0, undefined); expect(response.body).toStrictEqual(entities); }); it('should be able to get all entities with pagination', async () => { const response = await request.get(`${baseRoute}?page=1&page_size=10`); - expect(queries.findAll).toHaveBeenCalledWith(10, 0); - expect(queries.findTotalNumber).toHaveBeenCalled(); + expect(queries.findAll).toHaveBeenCalledWith(10, 0, undefined); expect(response.body).toStrictEqual(entities); expect(response.header).toHaveProperty('total-number', '2'); }); diff --git a/packages/core/src/utils/SchemaRouter.ts b/packages/core/src/utils/SchemaRouter.ts index 2f2b1f9c1..0b2fd0bf4 100644 --- a/packages/core/src/utils/SchemaRouter.ts +++ b/packages/core/src/utils/SchemaRouter.ts @@ -1,11 +1,12 @@ import { type SchemaLike, type GeneratedSchema } from '@logto/schemas'; import { generateStandardId } from '@logto/shared'; -import { type DeepPartial } from '@silverhand/essentials'; +import { cond, type Optional, type DeepPartial } from '@silverhand/essentials'; import camelcase from 'camelcase'; import deepmerge from 'deepmerge'; import Router, { type IRouterParamContext } from 'koa-router'; import { z } from 'zod'; +import { type SearchOptions } from '#src/database/utils.js'; import koaGuard from '#src/middleware/koa-guard.js'; import koaPagination from '#src/middleware/koa-pagination.js'; @@ -34,7 +35,7 @@ const tableToPathname = (tableName: string) => tableName.replaceAll('_', '-'); const camelCaseSchemaId = (schema: T) => `${camelcase(schema.tableSingular)}Id` as const; -type SchemaRouterConfig = { +type SchemaRouterConfig = { /** Disable certain routes for the router. */ disabled: { /** Disable `GET /` route. */ @@ -48,7 +49,10 @@ type SchemaRouterConfig = { /** Disable `DELETE /:id` route. */ deleteById: boolean; }; + /** A custom error handler for the router before throwing the error. */ errorHandler?: (error: unknown) => void; + /** The fields that can be searched for the `GET /` route. */ + searchFields: SearchOptions['fields']; }; /** @@ -63,8 +67,6 @@ type SchemaRouterConfig = { * - `DELETE /:id`: Delete an entity by ID. * * Browse the source code for more details about request/response validation. - * - * @see {@link SchemaActions} for the `actions` configuration. */ export default class SchemaRouter< Key extends string, @@ -73,16 +75,16 @@ export default class SchemaRouter< StateT = unknown, CustomT extends IRouterParamContext = IRouterParamContext, > extends Router { - public readonly config: SchemaRouterConfig; + public readonly config: SchemaRouterConfig; constructor( public readonly schema: GeneratedSchema, public readonly queries: SchemaQueries, - config: DeepPartial = {} + config: DeepPartial> = {} ) { super({ prefix: '/' + tableToPathname(schema.table) }); - this.config = deepmerge>( + this.config = deepmerge>( { disabled: { get: false, @@ -91,6 +93,7 @@ export default class SchemaRouter< patchById: false, deleteById: false, }, + searchFields: [], }, config ); @@ -106,19 +109,29 @@ export default class SchemaRouter< }); } - const { disabled } = this.config; + const { disabled, searchFields } = this.config; if (!disabled.get) { this.get( '/', koaPagination(), - koaGuard({ response: schema.guard.array(), status: [200] }), + koaGuard({ + query: z.object({ q: z.string().optional() }), + response: schema.guard.array(), + status: [200], + }), async (ctx, next) => { + const { q } = ctx.guard.query; + const search: Optional> = cond( + q && + searchFields.length > 0 && { + fields: searchFields, + keyword: q, + } + ); const { limit, offset } = ctx.pagination; - const [count, entities] = await Promise.all([ - queries.findTotalNumber(), - queries.findAll(limit, offset), - ]); + const [count, entities] = await queries.findAll(limit, offset, search); + ctx.pagination.totalCount = count; ctx.body = entities; return next();