diff --git a/server/src/services/search.service.spec.ts b/server/src/services/search.service.spec.ts index e0b03f31ae..548a058c79 100644 --- a/server/src/services/search.service.spec.ts +++ b/server/src/services/search.service.spec.ts @@ -1,11 +1,16 @@ +import { BadRequestException } from '@nestjs/common'; import { mapAsset } from 'src/dtos/asset-response.dto'; import { SearchSuggestionType } from 'src/dtos/search.dto'; import { IAssetRepository } from 'src/interfaces/asset.interface'; +import { IMachineLearningRepository } from 'src/interfaces/machine-learning.interface'; +import { IPartnerRepository } from 'src/interfaces/partner.interface'; import { IPersonRepository } from 'src/interfaces/person.interface'; import { ISearchRepository } from 'src/interfaces/search.interface'; +import { ISystemMetadataRepository } from 'src/interfaces/system-metadata.interface'; import { SearchService } from 'src/services/search.service'; import { assetStub } from 'test/fixtures/asset.stub'; import { authStub } from 'test/fixtures/auth.stub'; +import { partnerStub } from 'test/fixtures/partner.stub'; import { personStub } from 'test/fixtures/person.stub'; import { newTestService } from 'test/utils'; import { Mocked, beforeEach, vitest } from 'vitest'; @@ -16,11 +21,15 @@ describe(SearchService.name, () => { let sut: SearchService; let assetMock: Mocked; + let machineLearningMock: Mocked; + let partnerMock: Mocked; let personMock: Mocked; let searchMock: Mocked; + let systemMock: Mocked; beforeEach(() => { - ({ sut, assetMock, personMock, searchMock } = newTestService(SearchService)); + ({ sut, assetMock, machineLearningMock, partnerMock, personMock, searchMock, systemMock } = + newTestService(SearchService)); }); it('should work', () => { @@ -80,4 +89,99 @@ describe(SearchService.name, () => { expect(searchMock.getCountries).toHaveBeenCalledWith([authStub.user1.user.id]); }); }); + + describe('searchSmart', () => { + beforeEach(() => { + searchMock.searchSmart.mockResolvedValue({ hasNextPage: false, items: [] }); + machineLearningMock.encodeText.mockResolvedValue([1, 2, 3]); + }); + + it('should raise a BadRequestException if machine learning is disabled', async () => { + systemMock.get.mockResolvedValue({ + machineLearning: { enabled: false }, + }); + + await expect(sut.searchSmart(authStub.user1, { query: 'test' })).rejects.toThrowError( + new BadRequestException('Smart search is not enabled'), + ); + }); + + it('should raise a BadRequestException if smart search is disabled', async () => { + systemMock.get.mockResolvedValue({ + machineLearning: { clip: { enabled: false } }, + }); + + await expect(sut.searchSmart(authStub.user1, { query: 'test' })).rejects.toThrowError( + new BadRequestException('Smart search is not enabled'), + ); + }); + + it('should work', async () => { + await sut.searchSmart(authStub.user1, { query: 'test' }); + + expect(machineLearningMock.encodeText).toHaveBeenCalledWith( + expect.any(String), + 'test', + expect.objectContaining({ modelName: expect.any(String) }), + ); + expect(searchMock.searchSmart).toHaveBeenCalledWith( + { page: 1, size: 100 }, + { query: 'test', embedding: [1, 2, 3], userIds: [authStub.user1.user.id] }, + ); + }); + + it('should include partner shared assets', async () => { + partnerMock.getAll.mockResolvedValue([partnerStub.adminToUser1]); + + await sut.searchSmart(authStub.user1, { query: 'test' }); + + expect(machineLearningMock.encodeText).toHaveBeenCalledWith( + expect.any(String), + 'test', + expect.objectContaining({ modelName: expect.any(String) }), + ); + expect(searchMock.searchSmart).toHaveBeenCalledWith( + { page: 1, size: 100 }, + { query: 'test', embedding: [1, 2, 3], userIds: [authStub.user1.user.id, authStub.admin.user.id] }, + ); + }); + + it('should consider page and size parameters', async () => { + await sut.searchSmart(authStub.user1, { query: 'test', page: 2, size: 50 }); + + expect(machineLearningMock.encodeText).toHaveBeenCalledWith( + expect.any(String), + 'test', + expect.objectContaining({ modelName: expect.any(String) }), + ); + expect(searchMock.searchSmart).toHaveBeenCalledWith( + { page: 2, size: 50 }, + expect.objectContaining({ query: 'test', embedding: [1, 2, 3], userIds: [authStub.user1.user.id] }), + ); + }); + + it('should use clip model specified in config', async () => { + systemMock.get.mockResolvedValue({ + machineLearning: { clip: { modelName: 'ViT-B-16-SigLIP__webli' } }, + }); + + await sut.searchSmart(authStub.user1, { query: 'test' }); + + expect(machineLearningMock.encodeText).toHaveBeenCalledWith( + expect.any(String), + 'test', + expect.objectContaining({ modelName: 'ViT-B-16-SigLIP__webli' }), + ); + }); + + it('should use language specified in request', async () => { + await sut.searchSmart(authStub.user1, { query: 'test', language: 'de' }); + + expect(machineLearningMock.encodeText).toHaveBeenCalledWith( + expect.any(String), + 'test', + expect.objectContaining({ language: 'de' }), + ); + }); + }); });