mirror of
https://github.com/immich-app/immich.git
synced 2025-01-21 00:52:43 -05:00
chore: migrate oauth to repo (#13211)
This commit is contained in:
parent
9d9bf1c88d
commit
a5e9adb593
8 changed files with 171 additions and 121 deletions
22
server/src/interfaces/oauth.interface.ts
Normal file
22
server/src/interfaces/oauth.interface.ts
Normal file
|
@ -0,0 +1,22 @@
|
||||||
|
import { UserinfoResponse } from 'openid-client';
|
||||||
|
|
||||||
|
export const IOAuthRepository = 'IOAuthRepository';
|
||||||
|
|
||||||
|
export type OAuthConfig = {
|
||||||
|
clientId: string;
|
||||||
|
clientSecret: string;
|
||||||
|
issuerUrl: string;
|
||||||
|
mobileOverrideEnabled: boolean;
|
||||||
|
mobileRedirectUri: string;
|
||||||
|
profileSigningAlgorithm: string;
|
||||||
|
scope: string;
|
||||||
|
signingAlgorithm: string;
|
||||||
|
};
|
||||||
|
export type OAuthProfile = UserinfoResponse;
|
||||||
|
|
||||||
|
export interface IOAuthRepository {
|
||||||
|
init(): void;
|
||||||
|
authorize(config: OAuthConfig, redirectUrl: string): Promise<string>;
|
||||||
|
getLogoutEndpoint(config: OAuthConfig): Promise<string | undefined>;
|
||||||
|
getProfile(config: OAuthConfig, url: string, redirectUrl: string): Promise<OAuthProfile>;
|
||||||
|
}
|
|
@ -20,6 +20,7 @@ import { IMetadataRepository } from 'src/interfaces/metadata.interface';
|
||||||
import { IMetricRepository } from 'src/interfaces/metric.interface';
|
import { IMetricRepository } from 'src/interfaces/metric.interface';
|
||||||
import { IMoveRepository } from 'src/interfaces/move.interface';
|
import { IMoveRepository } from 'src/interfaces/move.interface';
|
||||||
import { INotificationRepository } from 'src/interfaces/notification.interface';
|
import { INotificationRepository } from 'src/interfaces/notification.interface';
|
||||||
|
import { IOAuthRepository } from 'src/interfaces/oauth.interface';
|
||||||
import { IPartnerRepository } from 'src/interfaces/partner.interface';
|
import { IPartnerRepository } from 'src/interfaces/partner.interface';
|
||||||
import { IPersonRepository } from 'src/interfaces/person.interface';
|
import { IPersonRepository } from 'src/interfaces/person.interface';
|
||||||
import { ISearchRepository } from 'src/interfaces/search.interface';
|
import { ISearchRepository } from 'src/interfaces/search.interface';
|
||||||
|
@ -56,6 +57,7 @@ import { MetadataRepository } from 'src/repositories/metadata.repository';
|
||||||
import { MetricRepository } from 'src/repositories/metric.repository';
|
import { MetricRepository } from 'src/repositories/metric.repository';
|
||||||
import { MoveRepository } from 'src/repositories/move.repository';
|
import { MoveRepository } from 'src/repositories/move.repository';
|
||||||
import { NotificationRepository } from 'src/repositories/notification.repository';
|
import { NotificationRepository } from 'src/repositories/notification.repository';
|
||||||
|
import { OAuthRepository } from 'src/repositories/oauth.repository';
|
||||||
import { PartnerRepository } from 'src/repositories/partner.repository';
|
import { PartnerRepository } from 'src/repositories/partner.repository';
|
||||||
import { PersonRepository } from 'src/repositories/person.repository';
|
import { PersonRepository } from 'src/repositories/person.repository';
|
||||||
import { SearchRepository } from 'src/repositories/search.repository';
|
import { SearchRepository } from 'src/repositories/search.repository';
|
||||||
|
@ -94,6 +96,7 @@ export const repositories = [
|
||||||
{ provide: IMetricRepository, useClass: MetricRepository },
|
{ provide: IMetricRepository, useClass: MetricRepository },
|
||||||
{ provide: IMoveRepository, useClass: MoveRepository },
|
{ provide: IMoveRepository, useClass: MoveRepository },
|
||||||
{ provide: INotificationRepository, useClass: NotificationRepository },
|
{ provide: INotificationRepository, useClass: NotificationRepository },
|
||||||
|
{ provide: IOAuthRepository, useClass: OAuthRepository },
|
||||||
{ provide: IPartnerRepository, useClass: PartnerRepository },
|
{ provide: IPartnerRepository, useClass: PartnerRepository },
|
||||||
{ provide: IPersonRepository, useClass: PersonRepository },
|
{ provide: IPersonRepository, useClass: PersonRepository },
|
||||||
{ provide: ISearchRepository, useClass: SearchRepository },
|
{ provide: ISearchRepository, useClass: SearchRepository },
|
||||||
|
|
73
server/src/repositories/oauth.repository.ts
Normal file
73
server/src/repositories/oauth.repository.ts
Normal file
|
@ -0,0 +1,73 @@
|
||||||
|
import { Inject, Injectable, InternalServerErrorException } from '@nestjs/common';
|
||||||
|
import { custom, generators, Issuer } from 'openid-client';
|
||||||
|
import { ILoggerRepository } from 'src/interfaces/logger.interface';
|
||||||
|
import { IOAuthRepository, OAuthConfig, OAuthProfile } from 'src/interfaces/oauth.interface';
|
||||||
|
import { Instrumentation } from 'src/utils/instrumentation';
|
||||||
|
|
||||||
|
@Instrumentation()
|
||||||
|
@Injectable()
|
||||||
|
export class OAuthRepository implements IOAuthRepository {
|
||||||
|
constructor(@Inject(ILoggerRepository) private logger: ILoggerRepository) {
|
||||||
|
this.logger.setContext(OAuthRepository.name);
|
||||||
|
}
|
||||||
|
|
||||||
|
init() {
|
||||||
|
custom.setHttpOptionsDefaults({ timeout: 30_000 });
|
||||||
|
}
|
||||||
|
|
||||||
|
async authorize(config: OAuthConfig, redirectUrl: string) {
|
||||||
|
const client = await this.getClient(config);
|
||||||
|
return client.authorizationUrl({
|
||||||
|
redirect_uri: redirectUrl,
|
||||||
|
scope: config.scope,
|
||||||
|
state: generators.state(),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
async getLogoutEndpoint(config: OAuthConfig) {
|
||||||
|
const client = await this.getClient(config);
|
||||||
|
return client.issuer.metadata.end_session_endpoint;
|
||||||
|
}
|
||||||
|
|
||||||
|
async getProfile(config: OAuthConfig, url: string, redirectUrl: string): Promise<OAuthProfile> {
|
||||||
|
const client = await this.getClient(config);
|
||||||
|
const params = client.callbackParams(url);
|
||||||
|
try {
|
||||||
|
const tokens = await client.callback(redirectUrl, params, { state: params.state });
|
||||||
|
return await client.userinfo<OAuthProfile>(tokens.access_token || '');
|
||||||
|
} catch (error: Error | any) {
|
||||||
|
if (error.message.includes('unexpected JWT alg received')) {
|
||||||
|
this.logger.warn(
|
||||||
|
[
|
||||||
|
'Algorithm mismatch. Make sure the signing algorithm is set correctly in the OAuth settings.',
|
||||||
|
'Or, that you have specified a signing key in your OAuth provider.',
|
||||||
|
].join(' '),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
throw error;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private async getClient({
|
||||||
|
issuerUrl,
|
||||||
|
clientId,
|
||||||
|
clientSecret,
|
||||||
|
profileSigningAlgorithm,
|
||||||
|
signingAlgorithm,
|
||||||
|
}: OAuthConfig) {
|
||||||
|
try {
|
||||||
|
const issuer = await Issuer.discover(issuerUrl);
|
||||||
|
return new issuer.Client({
|
||||||
|
client_id: clientId,
|
||||||
|
client_secret: clientSecret,
|
||||||
|
response_types: ['code'],
|
||||||
|
userinfo_signed_response_alg: profileSigningAlgorithm === 'none' ? undefined : profileSigningAlgorithm,
|
||||||
|
id_token_signed_response_alg: signingAlgorithm,
|
||||||
|
});
|
||||||
|
} catch (error: any | AggregateError) {
|
||||||
|
this.logger.error(`Error in OAuth discovery: ${error}`, error?.stack, error?.errors);
|
||||||
|
throw new InternalServerErrorException(`Error in OAuth discovery: ${error}`, { cause: error });
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -1,5 +1,4 @@
|
||||||
import { BadRequestException, ForbiddenException, UnauthorizedException } from '@nestjs/common';
|
import { BadRequestException, ForbiddenException, UnauthorizedException } from '@nestjs/common';
|
||||||
import { Issuer, generators } from 'openid-client';
|
|
||||||
import { AuthDto, SignUpDto } from 'src/dtos/auth.dto';
|
import { AuthDto, SignUpDto } from 'src/dtos/auth.dto';
|
||||||
import { UserMetadataEntity } from 'src/entities/user-metadata.entity';
|
import { UserMetadataEntity } from 'src/entities/user-metadata.entity';
|
||||||
import { UserEntity } from 'src/entities/user.entity';
|
import { UserEntity } from 'src/entities/user.entity';
|
||||||
|
@ -7,6 +6,7 @@ import { AuthType, Permission } from 'src/enum';
|
||||||
import { IKeyRepository } from 'src/interfaces/api-key.interface';
|
import { IKeyRepository } from 'src/interfaces/api-key.interface';
|
||||||
import { ICryptoRepository } from 'src/interfaces/crypto.interface';
|
import { ICryptoRepository } from 'src/interfaces/crypto.interface';
|
||||||
import { IEventRepository } from 'src/interfaces/event.interface';
|
import { IEventRepository } from 'src/interfaces/event.interface';
|
||||||
|
import { IOAuthRepository } from 'src/interfaces/oauth.interface';
|
||||||
import { ISessionRepository } from 'src/interfaces/session.interface';
|
import { ISessionRepository } from 'src/interfaces/session.interface';
|
||||||
import { ISharedLinkRepository } from 'src/interfaces/shared-link.interface';
|
import { ISharedLinkRepository } from 'src/interfaces/shared-link.interface';
|
||||||
import { ISystemMetadataRepository } from 'src/interfaces/system-metadata.interface';
|
import { ISystemMetadataRepository } from 'src/interfaces/system-metadata.interface';
|
||||||
|
@ -19,7 +19,7 @@ import { sharedLinkStub } from 'test/fixtures/shared-link.stub';
|
||||||
import { systemConfigStub } from 'test/fixtures/system-config.stub';
|
import { systemConfigStub } from 'test/fixtures/system-config.stub';
|
||||||
import { userStub } from 'test/fixtures/user.stub';
|
import { userStub } from 'test/fixtures/user.stub';
|
||||||
import { newTestService } from 'test/utils';
|
import { newTestService } from 'test/utils';
|
||||||
import { Mock, Mocked, vitest } from 'vitest';
|
import { Mocked } from 'vitest';
|
||||||
|
|
||||||
// const token = Buffer.from('my-api-key', 'utf8').toString('base64');
|
// const token = Buffer.from('my-api-key', 'utf8').toString('base64');
|
||||||
|
|
||||||
|
@ -53,36 +53,19 @@ describe('AuthService', () => {
|
||||||
let cryptoMock: Mocked<ICryptoRepository>;
|
let cryptoMock: Mocked<ICryptoRepository>;
|
||||||
let eventMock: Mocked<IEventRepository>;
|
let eventMock: Mocked<IEventRepository>;
|
||||||
let keyMock: Mocked<IKeyRepository>;
|
let keyMock: Mocked<IKeyRepository>;
|
||||||
|
let oauthMock: Mocked<IOAuthRepository>;
|
||||||
let sessionMock: Mocked<ISessionRepository>;
|
let sessionMock: Mocked<ISessionRepository>;
|
||||||
let sharedLinkMock: Mocked<ISharedLinkRepository>;
|
let sharedLinkMock: Mocked<ISharedLinkRepository>;
|
||||||
let systemMock: Mocked<ISystemMetadataRepository>;
|
let systemMock: Mocked<ISystemMetadataRepository>;
|
||||||
let userMock: Mocked<IUserRepository>;
|
let userMock: Mocked<IUserRepository>;
|
||||||
|
|
||||||
let callbackMock: Mock;
|
|
||||||
let userinfoMock: Mock;
|
|
||||||
|
|
||||||
beforeEach(() => {
|
beforeEach(() => {
|
||||||
callbackMock = vitest.fn().mockReturnValue({ access_token: 'access-token' });
|
({ sut, cryptoMock, eventMock, keyMock, oauthMock, sessionMock, sharedLinkMock, systemMock, userMock } =
|
||||||
userinfoMock = vitest.fn().mockResolvedValue({ sub, email });
|
|
||||||
|
|
||||||
vitest.spyOn(generators, 'state').mockReturnValue('state');
|
|
||||||
vitest.spyOn(Issuer, 'discover').mockResolvedValue({
|
|
||||||
id_token_signing_alg_values_supported: ['RS256'],
|
|
||||||
Client: vitest.fn().mockResolvedValue({
|
|
||||||
issuer: {
|
|
||||||
metadata: {
|
|
||||||
end_session_endpoint: 'http://end-session-endpoint',
|
|
||||||
},
|
|
||||||
},
|
|
||||||
authorizationUrl: vitest.fn().mockReturnValue('http://authorization-url'),
|
|
||||||
callbackParams: vitest.fn().mockReturnValue({ state: 'state' }),
|
|
||||||
callback: callbackMock,
|
|
||||||
userinfo: userinfoMock,
|
|
||||||
}),
|
|
||||||
} as any);
|
|
||||||
|
|
||||||
({ sut, cryptoMock, eventMock, keyMock, sessionMock, sharedLinkMock, systemMock, userMock } =
|
|
||||||
newTestService(AuthService));
|
newTestService(AuthService));
|
||||||
|
|
||||||
|
oauthMock.authorize.mockResolvedValue('access-token');
|
||||||
|
oauthMock.getProfile.mockResolvedValue({ sub, email });
|
||||||
|
oauthMock.getLogoutEndpoint.mockResolvedValue('http://end-session-endpoint');
|
||||||
});
|
});
|
||||||
|
|
||||||
it('should be defined', () => {
|
it('should be defined', () => {
|
||||||
|
@ -515,21 +498,21 @@ describe('AuthService', () => {
|
||||||
expect(userMock.create).toHaveBeenCalledTimes(1);
|
expect(userMock.create).toHaveBeenCalledTimes(1);
|
||||||
});
|
});
|
||||||
|
|
||||||
// TODO write once oidc has been moved to a repo and can be mocked.
|
it('should throw an error if user should be auto registered but the email claim does not exist', async () => {
|
||||||
// it('should throw an error if user should be auto registered but the email claim does not exist', async () => {
|
systemMock.get.mockResolvedValue(systemConfigStub.enabled);
|
||||||
// systemMock.get.mockResolvedValue(systemConfigStub.enabled);
|
userMock.getByEmail.mockResolvedValue(null);
|
||||||
// userMock.getByEmail.mockResolvedValue(null);
|
userMock.getAdmin.mockResolvedValue(userStub.user1);
|
||||||
// userMock.getAdmin.mockResolvedValue(userStub.user1);
|
userMock.create.mockResolvedValue(userStub.user1);
|
||||||
// userMock.create.mockResolvedValue(userStub.user1);
|
sessionMock.create.mockResolvedValue(sessionStub.valid);
|
||||||
// sessionMock.create.mockResolvedValue(sessionStub.valid);
|
oauthMock.getProfile.mockResolvedValue({ sub, email: undefined });
|
||||||
|
|
||||||
// await expect(sut.callback({ url: 'http://immich/auth/login?code=abc123' }, loginDetails)).rejects.toBeInstanceOf(
|
await expect(sut.callback({ url: 'http://immich/auth/login?code=abc123' }, loginDetails)).rejects.toBeInstanceOf(
|
||||||
// BadRequestException,
|
BadRequestException,
|
||||||
// );
|
);
|
||||||
|
|
||||||
// expect(userMock.getByEmail).toHaveBeenCalledTimes(1);
|
expect(userMock.getByEmail).not.toHaveBeenCalled();
|
||||||
// expect(userMock.create).toHaveBeenCalledTimes(1);
|
expect(userMock.create).not.toHaveBeenCalled();
|
||||||
// });
|
});
|
||||||
|
|
||||||
for (const url of [
|
for (const url of [
|
||||||
'app.immich:/',
|
'app.immich:/',
|
||||||
|
@ -545,7 +528,7 @@ describe('AuthService', () => {
|
||||||
sessionMock.create.mockResolvedValue(sessionStub.valid);
|
sessionMock.create.mockResolvedValue(sessionStub.valid);
|
||||||
|
|
||||||
await sut.callback({ url }, loginDetails);
|
await sut.callback({ url }, loginDetails);
|
||||||
expect(callbackMock).toHaveBeenCalledWith('http://mobile-redirect', { state: 'state' }, { state: 'state' });
|
expect(oauthMock.getProfile).toHaveBeenCalledWith(expect.objectContaining({}), url, 'http://mobile-redirect');
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -567,7 +550,7 @@ describe('AuthService', () => {
|
||||||
userMock.getByEmail.mockResolvedValue(null);
|
userMock.getByEmail.mockResolvedValue(null);
|
||||||
userMock.getAdmin.mockResolvedValue(userStub.user1);
|
userMock.getAdmin.mockResolvedValue(userStub.user1);
|
||||||
userMock.create.mockResolvedValue(userStub.user1);
|
userMock.create.mockResolvedValue(userStub.user1);
|
||||||
userinfoMock.mockResolvedValue({ sub, email, immich_quota: 'abc' });
|
oauthMock.getProfile.mockResolvedValue({ sub, email, immich_quota: 'abc' });
|
||||||
|
|
||||||
await expect(sut.callback({ url: 'http://immich/auth/login?code=abc123' }, loginDetails)).resolves.toEqual(
|
await expect(sut.callback({ url: 'http://immich/auth/login?code=abc123' }, loginDetails)).resolves.toEqual(
|
||||||
loginResponseStub.user1oauth,
|
loginResponseStub.user1oauth,
|
||||||
|
@ -581,7 +564,7 @@ describe('AuthService', () => {
|
||||||
userMock.getByEmail.mockResolvedValue(null);
|
userMock.getByEmail.mockResolvedValue(null);
|
||||||
userMock.getAdmin.mockResolvedValue(userStub.user1);
|
userMock.getAdmin.mockResolvedValue(userStub.user1);
|
||||||
userMock.create.mockResolvedValue(userStub.user1);
|
userMock.create.mockResolvedValue(userStub.user1);
|
||||||
userinfoMock.mockResolvedValue({ sub, email, immich_quota: -5 });
|
oauthMock.getProfile.mockResolvedValue({ sub, email, immich_quota: -5 });
|
||||||
|
|
||||||
await expect(sut.callback({ url: 'http://immich/auth/login?code=abc123' }, loginDetails)).resolves.toEqual(
|
await expect(sut.callback({ url: 'http://immich/auth/login?code=abc123' }, loginDetails)).resolves.toEqual(
|
||||||
loginResponseStub.user1oauth,
|
loginResponseStub.user1oauth,
|
||||||
|
@ -595,7 +578,7 @@ describe('AuthService', () => {
|
||||||
userMock.getByEmail.mockResolvedValue(null);
|
userMock.getByEmail.mockResolvedValue(null);
|
||||||
userMock.getAdmin.mockResolvedValue(userStub.user1);
|
userMock.getAdmin.mockResolvedValue(userStub.user1);
|
||||||
userMock.create.mockResolvedValue(userStub.user1);
|
userMock.create.mockResolvedValue(userStub.user1);
|
||||||
userinfoMock.mockResolvedValue({ sub, email, immich_quota: 0 });
|
oauthMock.getProfile.mockResolvedValue({ sub, email, immich_quota: 0 });
|
||||||
|
|
||||||
await expect(sut.callback({ url: 'http://immich/auth/login?code=abc123' }, loginDetails)).resolves.toEqual(
|
await expect(sut.callback({ url: 'http://immich/auth/login?code=abc123' }, loginDetails)).resolves.toEqual(
|
||||||
loginResponseStub.user1oauth,
|
loginResponseStub.user1oauth,
|
||||||
|
@ -615,7 +598,7 @@ describe('AuthService', () => {
|
||||||
userMock.getByEmail.mockResolvedValue(null);
|
userMock.getByEmail.mockResolvedValue(null);
|
||||||
userMock.getAdmin.mockResolvedValue(userStub.user1);
|
userMock.getAdmin.mockResolvedValue(userStub.user1);
|
||||||
userMock.create.mockResolvedValue(userStub.user1);
|
userMock.create.mockResolvedValue(userStub.user1);
|
||||||
userinfoMock.mockResolvedValue({ sub, email, immich_quota: 5 });
|
oauthMock.getProfile.mockResolvedValue({ sub, email, immich_quota: 5 });
|
||||||
|
|
||||||
await expect(sut.callback({ url: 'http://immich/auth/login?code=abc123' }, loginDetails)).resolves.toEqual(
|
await expect(sut.callback({ url: 'http://immich/auth/login?code=abc123' }, loginDetails)).resolves.toEqual(
|
||||||
loginResponseStub.user1oauth,
|
loginResponseStub.user1oauth,
|
||||||
|
|
|
@ -1,16 +1,8 @@
|
||||||
import {
|
import { BadRequestException, ForbiddenException, Injectable, UnauthorizedException } from '@nestjs/common';
|
||||||
BadRequestException,
|
|
||||||
ForbiddenException,
|
|
||||||
Injectable,
|
|
||||||
InternalServerErrorException,
|
|
||||||
UnauthorizedException,
|
|
||||||
} from '@nestjs/common';
|
|
||||||
import { isNumber, isString } from 'class-validator';
|
import { isNumber, isString } from 'class-validator';
|
||||||
import cookieParser from 'cookie';
|
import cookieParser from 'cookie';
|
||||||
import { DateTime } from 'luxon';
|
import { DateTime } from 'luxon';
|
||||||
import { IncomingHttpHeaders } from 'node:http';
|
import { IncomingHttpHeaders } from 'node:http';
|
||||||
import { Issuer, UserinfoResponse, custom, generators } from 'openid-client';
|
|
||||||
import { SystemConfig } from 'src/config';
|
|
||||||
import { LOGIN_URL, MOBILE_REDIRECT, SALT_ROUNDS } from 'src/constants';
|
import { LOGIN_URL, MOBILE_REDIRECT, SALT_ROUNDS } from 'src/constants';
|
||||||
import { OnEvent } from 'src/decorators';
|
import { OnEvent } from 'src/decorators';
|
||||||
import {
|
import {
|
||||||
|
@ -30,6 +22,7 @@ import {
|
||||||
import { UserAdminResponseDto, mapUserAdmin } from 'src/dtos/user.dto';
|
import { UserAdminResponseDto, mapUserAdmin } from 'src/dtos/user.dto';
|
||||||
import { UserEntity } from 'src/entities/user.entity';
|
import { UserEntity } from 'src/entities/user.entity';
|
||||||
import { AuthType, Permission } from 'src/enum';
|
import { AuthType, Permission } from 'src/enum';
|
||||||
|
import { OAuthProfile } from 'src/interfaces/oauth.interface';
|
||||||
import { BaseService } from 'src/services/base.service';
|
import { BaseService } from 'src/services/base.service';
|
||||||
import { isGranted } from 'src/utils/access';
|
import { isGranted } from 'src/utils/access';
|
||||||
import { HumanReadableSize } from 'src/utils/bytes';
|
import { HumanReadableSize } from 'src/utils/bytes';
|
||||||
|
@ -42,8 +35,6 @@ export interface LoginDetails {
|
||||||
deviceOS: string;
|
deviceOS: string;
|
||||||
}
|
}
|
||||||
|
|
||||||
type OAuthProfile = UserinfoResponse;
|
|
||||||
|
|
||||||
interface ClaimOptions<T> {
|
interface ClaimOptions<T> {
|
||||||
key: string;
|
key: string;
|
||||||
default: T;
|
default: T;
|
||||||
|
@ -65,7 +56,7 @@ export type ValidateRequest = {
|
||||||
export class AuthService extends BaseService {
|
export class AuthService extends BaseService {
|
||||||
@OnEvent({ name: 'app.bootstrap' })
|
@OnEvent({ name: 'app.bootstrap' })
|
||||||
onBootstrap() {
|
onBootstrap() {
|
||||||
custom.setHttpOptionsDefaults({ timeout: 30_000 });
|
this.oauthRepository.init();
|
||||||
}
|
}
|
||||||
|
|
||||||
async login(dto: LoginCredentialDto, details: LoginDetails) {
|
async login(dto: LoginCredentialDto, details: LoginDetails) {
|
||||||
|
@ -191,21 +182,20 @@ export class AuthService extends BaseService {
|
||||||
}
|
}
|
||||||
|
|
||||||
async authorize(dto: OAuthConfigDto): Promise<OAuthAuthorizeResponseDto> {
|
async authorize(dto: OAuthConfigDto): Promise<OAuthAuthorizeResponseDto> {
|
||||||
const config = await this.getConfig({ withCache: false });
|
const { oauth } = await this.getConfig({ withCache: false });
|
||||||
const client = await this.getOAuthClient(config);
|
|
||||||
const url = client.authorizationUrl({
|
|
||||||
redirect_uri: this.normalize(config, dto.redirectUri),
|
|
||||||
scope: config.oauth.scope,
|
|
||||||
state: generators.state(),
|
|
||||||
});
|
|
||||||
|
|
||||||
|
if (!oauth.enabled) {
|
||||||
|
throw new BadRequestException('OAuth is not enabled');
|
||||||
|
}
|
||||||
|
|
||||||
|
const url = await this.oauthRepository.authorize(oauth, dto.redirectUri);
|
||||||
return { url };
|
return { url };
|
||||||
}
|
}
|
||||||
|
|
||||||
async callback(dto: OAuthCallbackDto, loginDetails: LoginDetails) {
|
async callback(dto: OAuthCallbackDto, loginDetails: LoginDetails) {
|
||||||
const config = await this.getConfig({ withCache: false });
|
const { oauth } = await this.getConfig({ withCache: false });
|
||||||
const profile = await this.getOAuthProfile(config, dto.url);
|
const profile = await this.oauthRepository.getProfile(oauth, dto.url, this.normalize(oauth, dto.url.split('?')[0]));
|
||||||
const { autoRegister, defaultStorageQuota, storageLabelClaim, storageQuotaClaim } = config.oauth;
|
const { autoRegister, defaultStorageQuota, storageLabelClaim, storageQuotaClaim } = oauth;
|
||||||
this.logger.debug(`Logging in with OAuth: ${JSON.stringify(profile)}`);
|
this.logger.debug(`Logging in with OAuth: ${JSON.stringify(profile)}`);
|
||||||
let user = await this.userRepository.getByOAuthId(profile.sub);
|
let user = await this.userRepository.getByOAuthId(profile.sub);
|
||||||
|
|
||||||
|
@ -263,8 +253,12 @@ export class AuthService extends BaseService {
|
||||||
}
|
}
|
||||||
|
|
||||||
async link(auth: AuthDto, dto: OAuthCallbackDto): Promise<UserAdminResponseDto> {
|
async link(auth: AuthDto, dto: OAuthCallbackDto): Promise<UserAdminResponseDto> {
|
||||||
const config = await this.getConfig({ withCache: false });
|
const { oauth } = await this.getConfig({ withCache: false });
|
||||||
const { sub: oauthId } = await this.getOAuthProfile(config, dto.url);
|
const { sub: oauthId } = await this.oauthRepository.getProfile(
|
||||||
|
oauth,
|
||||||
|
dto.url,
|
||||||
|
this.normalize(oauth, dto.url.split('?')[0]),
|
||||||
|
);
|
||||||
const duplicate = await this.userRepository.getByOAuthId(oauthId);
|
const duplicate = await this.userRepository.getByOAuthId(oauthId);
|
||||||
if (duplicate && duplicate.id !== auth.user.id) {
|
if (duplicate && duplicate.id !== auth.user.id) {
|
||||||
this.logger.warn(`OAuth link account failed: sub is already linked to another user (${duplicate.email}).`);
|
this.logger.warn(`OAuth link account failed: sub is already linked to another user (${duplicate.email}).`);
|
||||||
|
@ -290,60 +284,7 @@ export class AuthService extends BaseService {
|
||||||
return LOGIN_URL;
|
return LOGIN_URL;
|
||||||
}
|
}
|
||||||
|
|
||||||
const client = await this.getOAuthClient(config);
|
return (await this.oauthRepository.getLogoutEndpoint(config.oauth)) || LOGIN_URL;
|
||||||
return client.issuer.metadata.end_session_endpoint || LOGIN_URL;
|
|
||||||
}
|
|
||||||
|
|
||||||
private async getOAuthProfile(config: SystemConfig, url: string): Promise<OAuthProfile> {
|
|
||||||
const redirectUri = this.normalize(config, url.split('?')[0]);
|
|
||||||
const client = await this.getOAuthClient(config);
|
|
||||||
const params = client.callbackParams(url);
|
|
||||||
try {
|
|
||||||
const tokens = await client.callback(redirectUri, params, { state: params.state });
|
|
||||||
return client.userinfo<OAuthProfile>(tokens.access_token || '');
|
|
||||||
} catch (error: Error | any) {
|
|
||||||
if (error.message.includes('unexpected JWT alg received')) {
|
|
||||||
this.logger.warn(
|
|
||||||
[
|
|
||||||
'Algorithm mismatch. Make sure the signing algorithm is set correctly in the OAuth settings.',
|
|
||||||
'Or, that you have specified a signing key in your OAuth provider.',
|
|
||||||
].join(' '),
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
throw error;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
private async getOAuthClient(config: SystemConfig) {
|
|
||||||
const { enabled, clientId, clientSecret, issuerUrl, signingAlgorithm, profileSigningAlgorithm } = config.oauth;
|
|
||||||
|
|
||||||
if (!enabled) {
|
|
||||||
throw new BadRequestException('OAuth2 is not enabled');
|
|
||||||
}
|
|
||||||
|
|
||||||
try {
|
|
||||||
const issuer = await Issuer.discover(issuerUrl);
|
|
||||||
return new issuer.Client({
|
|
||||||
client_id: clientId,
|
|
||||||
client_secret: clientSecret,
|
|
||||||
response_types: ['code'],
|
|
||||||
userinfo_signed_response_alg: profileSigningAlgorithm === 'none' ? undefined : profileSigningAlgorithm,
|
|
||||||
id_token_signed_response_alg: signingAlgorithm,
|
|
||||||
});
|
|
||||||
} catch (error: any | AggregateError) {
|
|
||||||
this.logger.error(`Error in OAuth discovery: ${error}`, error?.stack, error?.errors);
|
|
||||||
throw new InternalServerErrorException(`Error in OAuth discovery: ${error}`, { cause: error });
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
private normalize(config: SystemConfig, redirectUri: string) {
|
|
||||||
const isMobile = redirectUri.startsWith('app.immich:/');
|
|
||||||
const { mobileRedirectUri, mobileOverrideEnabled } = config.oauth;
|
|
||||||
if (isMobile && mobileOverrideEnabled && mobileRedirectUri) {
|
|
||||||
return mobileRedirectUri;
|
|
||||||
}
|
|
||||||
return redirectUri;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
private getBearerToken(headers: IncomingHttpHeaders): string | null {
|
private getBearerToken(headers: IncomingHttpHeaders): string | null {
|
||||||
|
@ -427,4 +368,15 @@ export class AuthService extends BaseService {
|
||||||
const value = profile[options.key as keyof OAuthProfile];
|
const value = profile[options.key as keyof OAuthProfile];
|
||||||
return options.isValid(value) ? (value as T) : options.default;
|
return options.isValid(value) ? (value as T) : options.default;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private normalize(
|
||||||
|
{ mobileRedirectUri, mobileOverrideEnabled }: { mobileRedirectUri: string; mobileOverrideEnabled: boolean },
|
||||||
|
redirectUri: string,
|
||||||
|
) {
|
||||||
|
const isMobile = redirectUri.startsWith('app.immich:/');
|
||||||
|
if (isMobile && mobileOverrideEnabled && mobileRedirectUri) {
|
||||||
|
return mobileRedirectUri;
|
||||||
|
}
|
||||||
|
return redirectUri;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -23,6 +23,7 @@ import { IMetadataRepository } from 'src/interfaces/metadata.interface';
|
||||||
import { IMetricRepository } from 'src/interfaces/metric.interface';
|
import { IMetricRepository } from 'src/interfaces/metric.interface';
|
||||||
import { IMoveRepository } from 'src/interfaces/move.interface';
|
import { IMoveRepository } from 'src/interfaces/move.interface';
|
||||||
import { INotificationRepository } from 'src/interfaces/notification.interface';
|
import { INotificationRepository } from 'src/interfaces/notification.interface';
|
||||||
|
import { IOAuthRepository } from 'src/interfaces/oauth.interface';
|
||||||
import { IPartnerRepository } from 'src/interfaces/partner.interface';
|
import { IPartnerRepository } from 'src/interfaces/partner.interface';
|
||||||
import { IPersonRepository } from 'src/interfaces/person.interface';
|
import { IPersonRepository } from 'src/interfaces/person.interface';
|
||||||
import { ISearchRepository } from 'src/interfaces/search.interface';
|
import { ISearchRepository } from 'src/interfaces/search.interface';
|
||||||
|
@ -65,6 +66,7 @@ export class BaseService {
|
||||||
@Inject(IMetricRepository) protected metricRepository: IMetricRepository,
|
@Inject(IMetricRepository) protected metricRepository: IMetricRepository,
|
||||||
@Inject(IMoveRepository) protected moveRepository: IMoveRepository,
|
@Inject(IMoveRepository) protected moveRepository: IMoveRepository,
|
||||||
@Inject(INotificationRepository) protected notificationRepository: INotificationRepository,
|
@Inject(INotificationRepository) protected notificationRepository: INotificationRepository,
|
||||||
|
@Inject(IOAuthRepository) protected oauthRepository: IOAuthRepository,
|
||||||
@Inject(IPartnerRepository) protected partnerRepository: IPartnerRepository,
|
@Inject(IPartnerRepository) protected partnerRepository: IPartnerRepository,
|
||||||
@Inject(IPersonRepository) protected personRepository: IPersonRepository,
|
@Inject(IPersonRepository) protected personRepository: IPersonRepository,
|
||||||
@Inject(ISearchRepository) protected searchRepository: ISearchRepository,
|
@Inject(ISearchRepository) protected searchRepository: ISearchRepository,
|
||||||
|
|
11
server/test/repositories/oauth.repository.mock.ts
Normal file
11
server/test/repositories/oauth.repository.mock.ts
Normal file
|
@ -0,0 +1,11 @@
|
||||||
|
import { IOAuthRepository } from 'src/interfaces/oauth.interface';
|
||||||
|
import { Mocked } from 'vitest';
|
||||||
|
|
||||||
|
export const newOAuthRepositoryMock = (): Mocked<IOAuthRepository> => {
|
||||||
|
return {
|
||||||
|
init: vitest.fn(),
|
||||||
|
authorize: vitest.fn(),
|
||||||
|
getLogoutEndpoint: vitest.fn(),
|
||||||
|
getProfile: vitest.fn(),
|
||||||
|
};
|
||||||
|
};
|
|
@ -21,6 +21,7 @@ import { newMetadataRepositoryMock } from 'test/repositories/metadata.repository
|
||||||
import { newMetricRepositoryMock } from 'test/repositories/metric.repository.mock';
|
import { newMetricRepositoryMock } from 'test/repositories/metric.repository.mock';
|
||||||
import { newMoveRepositoryMock } from 'test/repositories/move.repository.mock';
|
import { newMoveRepositoryMock } from 'test/repositories/move.repository.mock';
|
||||||
import { newNotificationRepositoryMock } from 'test/repositories/notification.repository.mock';
|
import { newNotificationRepositoryMock } from 'test/repositories/notification.repository.mock';
|
||||||
|
import { newOAuthRepositoryMock } from 'test/repositories/oauth.repository.mock';
|
||||||
import { newPartnerRepositoryMock } from 'test/repositories/partner.repository.mock';
|
import { newPartnerRepositoryMock } from 'test/repositories/partner.repository.mock';
|
||||||
import { newPersonRepositoryMock } from 'test/repositories/person.repository.mock';
|
import { newPersonRepositoryMock } from 'test/repositories/person.repository.mock';
|
||||||
import { newSearchRepositoryMock } from 'test/repositories/search.repository.mock';
|
import { newSearchRepositoryMock } from 'test/repositories/search.repository.mock';
|
||||||
|
@ -64,6 +65,7 @@ export const newTestService = <T extends BaseService>(Service: Constructor<T, Ba
|
||||||
const metricMock = newMetricRepositoryMock();
|
const metricMock = newMetricRepositoryMock();
|
||||||
const moveMock = newMoveRepositoryMock();
|
const moveMock = newMoveRepositoryMock();
|
||||||
const notificationMock = newNotificationRepositoryMock();
|
const notificationMock = newNotificationRepositoryMock();
|
||||||
|
const oauthMock = newOAuthRepositoryMock();
|
||||||
const partnerMock = newPartnerRepositoryMock();
|
const partnerMock = newPartnerRepositoryMock();
|
||||||
const personMock = newPersonRepositoryMock();
|
const personMock = newPersonRepositoryMock();
|
||||||
const searchMock = newSearchRepositoryMock();
|
const searchMock = newSearchRepositoryMock();
|
||||||
|
@ -102,6 +104,7 @@ export const newTestService = <T extends BaseService>(Service: Constructor<T, Ba
|
||||||
metricMock,
|
metricMock,
|
||||||
moveMock,
|
moveMock,
|
||||||
notificationMock,
|
notificationMock,
|
||||||
|
oauthMock,
|
||||||
partnerMock,
|
partnerMock,
|
||||||
personMock,
|
personMock,
|
||||||
searchMock,
|
searchMock,
|
||||||
|
@ -142,6 +145,7 @@ export const newTestService = <T extends BaseService>(Service: Constructor<T, Ba
|
||||||
metricMock,
|
metricMock,
|
||||||
moveMock,
|
moveMock,
|
||||||
notificationMock,
|
notificationMock,
|
||||||
|
oauthMock,
|
||||||
partnerMock,
|
partnerMock,
|
||||||
personMock,
|
personMock,
|
||||||
searchMock,
|
searchMock,
|
||||||
|
|
Loading…
Add table
Reference in a new issue