ws auth
This commit is contained in:
@@ -40,6 +40,7 @@
|
|||||||
"class-transformer": "^0.5.1",
|
"class-transformer": "^0.5.1",
|
||||||
"class-validator": "^0.14.2",
|
"class-validator": "^0.14.2",
|
||||||
"dotenv": "^17.2.3",
|
"dotenv": "^17.2.3",
|
||||||
|
"jsonwebtoken": "^9.0.2",
|
||||||
"jwks-rsa": "^3.2.0",
|
"jwks-rsa": "^3.2.0",
|
||||||
"passport": "^0.7.0",
|
"passport": "^0.7.0",
|
||||||
"passport-jwt": "^4.0.1",
|
"passport-jwt": "^4.0.1",
|
||||||
@@ -56,6 +57,7 @@
|
|||||||
"@nestjs/testing": "^11.0.1",
|
"@nestjs/testing": "^11.0.1",
|
||||||
"@types/express": "^5.0.0",
|
"@types/express": "^5.0.0",
|
||||||
"@types/jest": "^30.0.0",
|
"@types/jest": "^30.0.0",
|
||||||
|
"@types/jsonwebtoken": "^9.0.7",
|
||||||
"@types/node": "^22.10.7",
|
"@types/node": "^22.10.7",
|
||||||
"@types/passport-jwt": "^4.0.1",
|
"@types/passport-jwt": "^4.0.1",
|
||||||
"@types/pg": "^8.15.6",
|
"@types/pg": "^8.15.6",
|
||||||
|
|||||||
14
pnpm-lock.yaml
generated
14
pnpm-lock.yaml
generated
@@ -28,7 +28,7 @@ importers:
|
|||||||
version: 11.1.9(@nestjs/common@11.1.9(class-transformer@0.5.1)(class-validator@0.14.2)(reflect-metadata@0.2.2)(rxjs@7.8.2))(@nestjs/websockets@11.1.9)(rxjs@7.8.2)
|
version: 11.1.9(@nestjs/common@11.1.9(class-transformer@0.5.1)(class-validator@0.14.2)(reflect-metadata@0.2.2)(rxjs@7.8.2))(@nestjs/websockets@11.1.9)(rxjs@7.8.2)
|
||||||
'@nestjs/swagger':
|
'@nestjs/swagger':
|
||||||
specifier: ^11.2.3
|
specifier: ^11.2.3
|
||||||
version: 11.2.3(@nestjs/common@11.1.9(class-transformer@0.5.1)(class-validator@0.14.2)(reflect-metadata@0.2.2)(rxjs@7.8.2))(@nestjs/core@11.1.9)(class-transformer@0.5.1)(class-validator@0.14.2)(reflect-metadata@0.2.2)
|
version: 11.2.3(@nestjs/common@11.1.9(class-transformer@0.5.1)(class-validator@0.14.2)(reflect-metadata@0.2.2)(rxjs@7.8.2))(@nestjs/core@11.1.9(@nestjs/common@11.1.9(class-transformer@0.5.1)(class-validator@0.14.2)(reflect-metadata@0.2.2)(rxjs@7.8.2))(@nestjs/platform-express@11.1.9)(@nestjs/websockets@11.1.9)(reflect-metadata@0.2.2)(rxjs@7.8.2))(class-transformer@0.5.1)(class-validator@0.14.2)(reflect-metadata@0.2.2)
|
||||||
'@nestjs/websockets':
|
'@nestjs/websockets':
|
||||||
specifier: ^11.1.9
|
specifier: ^11.1.9
|
||||||
version: 11.1.9(@nestjs/common@11.1.9(class-transformer@0.5.1)(class-validator@0.14.2)(reflect-metadata@0.2.2)(rxjs@7.8.2))(@nestjs/core@11.1.9)(@nestjs/platform-socket.io@11.1.9)(reflect-metadata@0.2.2)(rxjs@7.8.2)
|
version: 11.1.9(@nestjs/common@11.1.9(class-transformer@0.5.1)(class-validator@0.14.2)(reflect-metadata@0.2.2)(rxjs@7.8.2))(@nestjs/core@11.1.9)(@nestjs/platform-socket.io@11.1.9)(reflect-metadata@0.2.2)(rxjs@7.8.2)
|
||||||
@@ -47,6 +47,9 @@ importers:
|
|||||||
dotenv:
|
dotenv:
|
||||||
specifier: ^17.2.3
|
specifier: ^17.2.3
|
||||||
version: 17.2.3
|
version: 17.2.3
|
||||||
|
jsonwebtoken:
|
||||||
|
specifier: ^9.0.2
|
||||||
|
version: 9.0.2
|
||||||
jwks-rsa:
|
jwks-rsa:
|
||||||
specifier: ^3.2.0
|
specifier: ^3.2.0
|
||||||
version: 3.2.0
|
version: 3.2.0
|
||||||
@@ -83,13 +86,16 @@ importers:
|
|||||||
version: 11.0.9(chokidar@4.0.3)(typescript@5.9.3)
|
version: 11.0.9(chokidar@4.0.3)(typescript@5.9.3)
|
||||||
'@nestjs/testing':
|
'@nestjs/testing':
|
||||||
specifier: ^11.0.1
|
specifier: ^11.0.1
|
||||||
version: 11.1.9(@nestjs/common@11.1.9(class-transformer@0.5.1)(class-validator@0.14.2)(reflect-metadata@0.2.2)(rxjs@7.8.2))(@nestjs/core@11.1.9)(@nestjs/platform-express@11.1.9)
|
version: 11.1.9(@nestjs/common@11.1.9(class-transformer@0.5.1)(class-validator@0.14.2)(reflect-metadata@0.2.2)(rxjs@7.8.2))(@nestjs/core@11.1.9(@nestjs/common@11.1.9(class-transformer@0.5.1)(class-validator@0.14.2)(reflect-metadata@0.2.2)(rxjs@7.8.2))(@nestjs/platform-express@11.1.9)(@nestjs/websockets@11.1.9)(reflect-metadata@0.2.2)(rxjs@7.8.2))(@nestjs/platform-express@11.1.9(@nestjs/common@11.1.9(class-transformer@0.5.1)(class-validator@0.14.2)(reflect-metadata@0.2.2)(rxjs@7.8.2))(@nestjs/core@11.1.9))
|
||||||
'@types/express':
|
'@types/express':
|
||||||
specifier: ^5.0.0
|
specifier: ^5.0.0
|
||||||
version: 5.0.5
|
version: 5.0.5
|
||||||
'@types/jest':
|
'@types/jest':
|
||||||
specifier: ^30.0.0
|
specifier: ^30.0.0
|
||||||
version: 30.0.0
|
version: 30.0.0
|
||||||
|
'@types/jsonwebtoken':
|
||||||
|
specifier: ^9.0.7
|
||||||
|
version: 9.0.10
|
||||||
'@types/node':
|
'@types/node':
|
||||||
specifier: ^22.10.7
|
specifier: ^22.10.7
|
||||||
version: 22.19.1
|
version: 22.19.1
|
||||||
@@ -4532,7 +4538,7 @@ snapshots:
|
|||||||
transitivePeerDependencies:
|
transitivePeerDependencies:
|
||||||
- chokidar
|
- chokidar
|
||||||
|
|
||||||
'@nestjs/swagger@11.2.3(@nestjs/common@11.1.9(class-transformer@0.5.1)(class-validator@0.14.2)(reflect-metadata@0.2.2)(rxjs@7.8.2))(@nestjs/core@11.1.9)(class-transformer@0.5.1)(class-validator@0.14.2)(reflect-metadata@0.2.2)':
|
'@nestjs/swagger@11.2.3(@nestjs/common@11.1.9(class-transformer@0.5.1)(class-validator@0.14.2)(reflect-metadata@0.2.2)(rxjs@7.8.2))(@nestjs/core@11.1.9(@nestjs/common@11.1.9(class-transformer@0.5.1)(class-validator@0.14.2)(reflect-metadata@0.2.2)(rxjs@7.8.2))(@nestjs/platform-express@11.1.9)(@nestjs/websockets@11.1.9)(reflect-metadata@0.2.2)(rxjs@7.8.2))(class-transformer@0.5.1)(class-validator@0.14.2)(reflect-metadata@0.2.2)':
|
||||||
dependencies:
|
dependencies:
|
||||||
'@microsoft/tsdoc': 0.16.0
|
'@microsoft/tsdoc': 0.16.0
|
||||||
'@nestjs/common': 11.1.9(class-transformer@0.5.1)(class-validator@0.14.2)(reflect-metadata@0.2.2)(rxjs@7.8.2)
|
'@nestjs/common': 11.1.9(class-transformer@0.5.1)(class-validator@0.14.2)(reflect-metadata@0.2.2)(rxjs@7.8.2)
|
||||||
@@ -4547,7 +4553,7 @@ snapshots:
|
|||||||
class-transformer: 0.5.1
|
class-transformer: 0.5.1
|
||||||
class-validator: 0.14.2
|
class-validator: 0.14.2
|
||||||
|
|
||||||
'@nestjs/testing@11.1.9(@nestjs/common@11.1.9(class-transformer@0.5.1)(class-validator@0.14.2)(reflect-metadata@0.2.2)(rxjs@7.8.2))(@nestjs/core@11.1.9)(@nestjs/platform-express@11.1.9)':
|
'@nestjs/testing@11.1.9(@nestjs/common@11.1.9(class-transformer@0.5.1)(class-validator@0.14.2)(reflect-metadata@0.2.2)(rxjs@7.8.2))(@nestjs/core@11.1.9(@nestjs/common@11.1.9(class-transformer@0.5.1)(class-validator@0.14.2)(reflect-metadata@0.2.2)(rxjs@7.8.2))(@nestjs/platform-express@11.1.9)(@nestjs/websockets@11.1.9)(reflect-metadata@0.2.2)(rxjs@7.8.2))(@nestjs/platform-express@11.1.9(@nestjs/common@11.1.9(class-transformer@0.5.1)(class-validator@0.14.2)(reflect-metadata@0.2.2)(rxjs@7.8.2))(@nestjs/core@11.1.9))':
|
||||||
dependencies:
|
dependencies:
|
||||||
'@nestjs/common': 11.1.9(class-transformer@0.5.1)(class-validator@0.14.2)(reflect-metadata@0.2.2)(rxjs@7.8.2)
|
'@nestjs/common': 11.1.9(class-transformer@0.5.1)(class-validator@0.14.2)(reflect-metadata@0.2.2)(rxjs@7.8.2)
|
||||||
'@nestjs/core': 11.1.9(@nestjs/common@11.1.9(class-transformer@0.5.1)(class-validator@0.14.2)(reflect-metadata@0.2.2)(rxjs@7.8.2))(@nestjs/platform-express@11.1.9)(@nestjs/websockets@11.1.9)(reflect-metadata@0.2.2)(rxjs@7.8.2)
|
'@nestjs/core': 11.1.9(@nestjs/common@11.1.9(class-transformer@0.5.1)(class-validator@0.14.2)(reflect-metadata@0.2.2)(rxjs@7.8.2))(@nestjs/platform-express@11.1.9)(@nestjs/websockets@11.1.9)(reflect-metadata@0.2.2)(rxjs@7.8.2)
|
||||||
|
|||||||
@@ -1,3 +1,6 @@
|
|||||||
|
packages:
|
||||||
|
- '.'
|
||||||
|
|
||||||
onlyBuiltDependencies:
|
onlyBuiltDependencies:
|
||||||
- '@nestjs/core'
|
- '@nestjs/core'
|
||||||
- '@prisma/engines'
|
- '@prisma/engines'
|
||||||
|
|||||||
@@ -3,46 +3,16 @@ import { ConfigModule } from '@nestjs/config';
|
|||||||
import { PassportModule } from '@nestjs/passport';
|
import { PassportModule } from '@nestjs/passport';
|
||||||
import { AuthService } from './auth.service';
|
import { AuthService } from './auth.service';
|
||||||
import { JwtStrategy } from './strategies/jwt.strategy';
|
import { JwtStrategy } from './strategies/jwt.strategy';
|
||||||
|
import { JwtVerificationService } from './services/jwt-verification.service';
|
||||||
import { UsersModule } from '../users/users.module';
|
import { UsersModule } from '../users/users.module';
|
||||||
|
|
||||||
/**
|
|
||||||
* Authentication Module
|
|
||||||
*
|
|
||||||
* Provides Keycloak OpenID Connect authentication using JWT tokens.
|
|
||||||
* This module configures:
|
|
||||||
* - Passport for authentication strategies
|
|
||||||
* - JWT strategy for validating Keycloak tokens
|
|
||||||
* - Integration with UsersModule for user synchronization
|
|
||||||
*
|
|
||||||
* The module requires the following environment variables:
|
|
||||||
* - JWT_ISSUER: Expected JWT issuer
|
|
||||||
* - JWT_AUDIENCE: Expected JWT audience
|
|
||||||
* - JWKS_URI: URI for fetching Keycloak's public keys
|
|
||||||
*/
|
|
||||||
@Module({
|
@Module({
|
||||||
imports: [
|
imports: [
|
||||||
// Import ConfigModule to access environment variables
|
|
||||||
ConfigModule,
|
ConfigModule,
|
||||||
|
|
||||||
// Import PassportModule for authentication strategies
|
|
||||||
PassportModule.register({ defaultStrategy: 'jwt' }),
|
PassportModule.register({ defaultStrategy: 'jwt' }),
|
||||||
|
|
||||||
// Import UsersModule to enable user synchronization (with forwardRef to avoid circular dependency)
|
|
||||||
forwardRef(() => UsersModule),
|
forwardRef(() => UsersModule),
|
||||||
],
|
],
|
||||||
providers: [
|
providers: [JwtStrategy, AuthService, JwtVerificationService],
|
||||||
// Register the JWT strategy for validating Keycloak tokens
|
exports: [AuthService, PassportModule, JwtVerificationService],
|
||||||
JwtStrategy,
|
|
||||||
|
|
||||||
// Register the auth service for business logic
|
|
||||||
AuthService,
|
|
||||||
],
|
|
||||||
exports: [
|
|
||||||
// Export AuthService so other modules can use it
|
|
||||||
AuthService,
|
|
||||||
|
|
||||||
// Export PassportModule so guards can be used in other modules
|
|
||||||
PassportModule,
|
|
||||||
],
|
|
||||||
})
|
})
|
||||||
export class AuthModule {}
|
export class AuthModule {}
|
||||||
|
|||||||
79
src/auth/services/jwt-verification.service.spec.ts
Normal file
79
src/auth/services/jwt-verification.service.spec.ts
Normal file
@@ -0,0 +1,79 @@
|
|||||||
|
import { Test, TestingModule } from '@nestjs/testing';
|
||||||
|
import { ConfigService } from '@nestjs/config';
|
||||||
|
import { JwtVerificationService } from './jwt-verification.service';
|
||||||
|
|
||||||
|
describe('JwtVerificationService', () => {
|
||||||
|
let service: JwtVerificationService;
|
||||||
|
|
||||||
|
beforeEach(async () => {
|
||||||
|
const mockConfigService = {
|
||||||
|
get: jest.fn((key: string) => {
|
||||||
|
const config: Record<string, string> = {
|
||||||
|
JWKS_URI: 'https://test.com/.well-known/jwks.json',
|
||||||
|
JWT_ISSUER: 'https://test.com',
|
||||||
|
JWT_AUDIENCE: 'test-audience',
|
||||||
|
};
|
||||||
|
return config[key];
|
||||||
|
}),
|
||||||
|
};
|
||||||
|
|
||||||
|
const module: TestingModule = await Test.createTestingModule({
|
||||||
|
providers: [
|
||||||
|
JwtVerificationService,
|
||||||
|
{ provide: ConfigService, useValue: mockConfigService },
|
||||||
|
],
|
||||||
|
}).compile();
|
||||||
|
|
||||||
|
service = module.get<JwtVerificationService>(JwtVerificationService);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should be defined', () => {
|
||||||
|
expect(service).toBeDefined();
|
||||||
|
});
|
||||||
|
|
||||||
|
describe('extractToken', () => {
|
||||||
|
it('should extract token from auth object', () => {
|
||||||
|
const handshake = {
|
||||||
|
auth: { token: 'test-token' },
|
||||||
|
headers: {},
|
||||||
|
};
|
||||||
|
|
||||||
|
const token = service.extractToken(handshake);
|
||||||
|
|
||||||
|
expect(token).toBe('test-token');
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should extract token from Authorization header', () => {
|
||||||
|
const handshake = {
|
||||||
|
auth: {},
|
||||||
|
headers: { authorization: 'Bearer test-token' },
|
||||||
|
};
|
||||||
|
|
||||||
|
const token = service.extractToken(handshake);
|
||||||
|
|
||||||
|
expect(token).toBe('test-token');
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should prioritize auth.token over header', () => {
|
||||||
|
const handshake = {
|
||||||
|
auth: { token: 'auth-token' },
|
||||||
|
headers: { authorization: 'Bearer header-token' },
|
||||||
|
};
|
||||||
|
|
||||||
|
const token = service.extractToken(handshake);
|
||||||
|
|
||||||
|
expect(token).toBe('auth-token');
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should return undefined when no token present', () => {
|
||||||
|
const handshake = {
|
||||||
|
auth: {},
|
||||||
|
headers: {},
|
||||||
|
};
|
||||||
|
|
||||||
|
const token = service.extractToken(handshake);
|
||||||
|
|
||||||
|
expect(token).toBeUndefined();
|
||||||
|
});
|
||||||
|
});
|
||||||
|
});
|
||||||
93
src/auth/services/jwt-verification.service.ts
Normal file
93
src/auth/services/jwt-verification.service.ts
Normal file
@@ -0,0 +1,93 @@
|
|||||||
|
import { Injectable, Logger } from '@nestjs/common';
|
||||||
|
import { ConfigService } from '@nestjs/config';
|
||||||
|
import { verify, type JwtHeader } from 'jsonwebtoken';
|
||||||
|
import { JwksClient, type SigningKey } from 'jwks-rsa';
|
||||||
|
import type { JwtPayload } from '../strategies/jwt.strategy';
|
||||||
|
|
||||||
|
const JWT_ALGORITHM = 'RS256';
|
||||||
|
const BEARER_PREFIX = 'Bearer ';
|
||||||
|
|
||||||
|
@Injectable()
|
||||||
|
export class JwtVerificationService {
|
||||||
|
private readonly logger = new Logger(JwtVerificationService.name);
|
||||||
|
private readonly jwksClient: JwksClient;
|
||||||
|
private readonly issuer: string;
|
||||||
|
private readonly audience: string | undefined;
|
||||||
|
|
||||||
|
constructor(private readonly configService: ConfigService) {
|
||||||
|
const jwksUri = this.configService.get<string>('JWKS_URI');
|
||||||
|
this.issuer = this.configService.get<string>('JWT_ISSUER') || '';
|
||||||
|
this.audience = this.configService.get<string>('JWT_AUDIENCE');
|
||||||
|
|
||||||
|
if (!jwksUri) {
|
||||||
|
throw new Error('JWKS_URI must be configured');
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!this.issuer) {
|
||||||
|
throw new Error('JWT_ISSUER must be configured');
|
||||||
|
}
|
||||||
|
|
||||||
|
this.jwksClient = new JwksClient({
|
||||||
|
jwksUri,
|
||||||
|
cache: true,
|
||||||
|
rateLimit: true,
|
||||||
|
jwksRequestsPerMinute: 5,
|
||||||
|
});
|
||||||
|
|
||||||
|
this.logger.log('JWT Verification Service initialized');
|
||||||
|
}
|
||||||
|
|
||||||
|
async verifyToken(token: string): Promise<JwtPayload> {
|
||||||
|
return new Promise((resolve, reject) => {
|
||||||
|
const getKey = (
|
||||||
|
header: JwtHeader,
|
||||||
|
callback: (err: Error | null, signingKey?: string | Buffer) => void,
|
||||||
|
) => {
|
||||||
|
this.jwksClient.getSigningKey(
|
||||||
|
header.kid,
|
||||||
|
(err: Error | null, key?: SigningKey) => {
|
||||||
|
if (err) {
|
||||||
|
callback(err);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
const signingKey = key?.getPublicKey();
|
||||||
|
callback(null, signingKey);
|
||||||
|
},
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
verify(
|
||||||
|
token,
|
||||||
|
getKey,
|
||||||
|
{
|
||||||
|
issuer: this.issuer,
|
||||||
|
audience: this.audience,
|
||||||
|
algorithms: [JWT_ALGORITHM],
|
||||||
|
},
|
||||||
|
(err, decoded) => {
|
||||||
|
if (err) {
|
||||||
|
reject(err);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
resolve(decoded as JwtPayload);
|
||||||
|
},
|
||||||
|
);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
extractToken(handshake: {
|
||||||
|
auth?: { token?: string };
|
||||||
|
headers?: { authorization?: string };
|
||||||
|
}): string | undefined {
|
||||||
|
if (handshake.auth?.token) {
|
||||||
|
return handshake.auth.token;
|
||||||
|
}
|
||||||
|
|
||||||
|
const authHeader = handshake.headers?.authorization;
|
||||||
|
if (authHeader?.startsWith(BEARER_PREFIX)) {
|
||||||
|
return authHeader.replace(BEARER_PREFIX, '');
|
||||||
|
}
|
||||||
|
|
||||||
|
return undefined;
|
||||||
|
}
|
||||||
|
}
|
||||||
10
src/types/socket.d.ts
vendored
Normal file
10
src/types/socket.d.ts
vendored
Normal file
@@ -0,0 +1,10 @@
|
|||||||
|
import type { Socket as BaseSocket } from 'socket.io';
|
||||||
|
import type { AuthenticatedUser } from '../auth/decorators/current-user.decorator';
|
||||||
|
import type { DefaultEventsMap } from 'socket.io/dist/typed-events';
|
||||||
|
|
||||||
|
export type AuthenticatedSocket = BaseSocket<
|
||||||
|
DefaultEventsMap, // ClientToServerEvents
|
||||||
|
DefaultEventsMap, // ServerToClientEvents
|
||||||
|
DefaultEventsMap, // InterServerEvents
|
||||||
|
{ user?: AuthenticatedUser }
|
||||||
|
>;
|
||||||
@@ -16,6 +16,6 @@ import { AuthModule } from '../auth/auth.module';
|
|||||||
imports: [forwardRef(() => AuthModule)],
|
imports: [forwardRef(() => AuthModule)],
|
||||||
providers: [UsersService],
|
providers: [UsersService],
|
||||||
controllers: [UsersController],
|
controllers: [UsersController],
|
||||||
exports: [UsersService], // Export so AuthModule can use it
|
exports: [UsersService],
|
||||||
})
|
})
|
||||||
export class UsersModule {}
|
export class UsersModule {}
|
||||||
|
|||||||
9
src/ws/dto/cursor-position.dto.ts
Normal file
9
src/ws/dto/cursor-position.dto.ts
Normal file
@@ -0,0 +1,9 @@
|
|||||||
|
import { IsNumber } from 'class-validator';
|
||||||
|
|
||||||
|
export class CursorPositionDto {
|
||||||
|
@IsNumber()
|
||||||
|
x: number;
|
||||||
|
|
||||||
|
@IsNumber()
|
||||||
|
y: number;
|
||||||
|
}
|
||||||
@@ -1,12 +1,28 @@
|
|||||||
import { Test, TestingModule } from '@nestjs/testing';
|
import { Test, TestingModule } from '@nestjs/testing';
|
||||||
import { StateGateway } from './state.gateway';
|
import { StateGateway } from './state.gateway';
|
||||||
import { Socket } from 'socket.io';
|
import { AuthenticatedSocket } from '../../types/socket';
|
||||||
|
import { AuthService } from '../../auth/auth.service';
|
||||||
|
import { JwtVerificationService } from '../../auth/services/jwt-verification.service';
|
||||||
|
|
||||||
|
interface MockSocket extends Partial<AuthenticatedSocket> {
|
||||||
|
id: string;
|
||||||
|
data: {
|
||||||
|
user?: {
|
||||||
|
keycloakSub: string;
|
||||||
|
};
|
||||||
|
};
|
||||||
|
handshake?: any;
|
||||||
|
disconnect?: jest.Mock;
|
||||||
|
}
|
||||||
|
|
||||||
describe('StateGateway', () => {
|
describe('StateGateway', () => {
|
||||||
let gateway: StateGateway;
|
let gateway: StateGateway;
|
||||||
let mockLoggerLog: jest.SpyInstance;
|
let mockLoggerLog: jest.SpyInstance;
|
||||||
let mockLoggerDebug: jest.SpyInstance;
|
let mockLoggerDebug: jest.SpyInstance;
|
||||||
let mockServer: any;
|
let mockLoggerWarn: jest.SpyInstance;
|
||||||
|
let mockServer: { sockets: { sockets: { size: number } } };
|
||||||
|
let mockAuthService: Partial<AuthService>;
|
||||||
|
let mockJwtVerificationService: Partial<JwtVerificationService>;
|
||||||
|
|
||||||
beforeEach(async () => {
|
beforeEach(async () => {
|
||||||
mockServer = {
|
mockServer = {
|
||||||
@@ -17,23 +33,40 @@ describe('StateGateway', () => {
|
|||||||
},
|
},
|
||||||
};
|
};
|
||||||
|
|
||||||
|
mockAuthService = {
|
||||||
|
syncUserFromToken: jest.fn().mockResolvedValue({
|
||||||
|
id: 'user-id',
|
||||||
|
keycloakSub: 'test-sub',
|
||||||
|
}),
|
||||||
|
};
|
||||||
|
|
||||||
|
mockJwtVerificationService = {
|
||||||
|
extractToken: jest.fn((handshake) => handshake.auth?.token),
|
||||||
|
verifyToken: jest.fn().mockResolvedValue({
|
||||||
|
sub: 'test-sub',
|
||||||
|
email: 'test@example.com',
|
||||||
|
}),
|
||||||
|
};
|
||||||
|
|
||||||
const module: TestingModule = await Test.createTestingModule({
|
const module: TestingModule = await Test.createTestingModule({
|
||||||
providers: [StateGateway],
|
providers: [
|
||||||
|
StateGateway,
|
||||||
|
{ provide: AuthService, useValue: mockAuthService },
|
||||||
|
{
|
||||||
|
provide: JwtVerificationService,
|
||||||
|
useValue: mockJwtVerificationService,
|
||||||
|
},
|
||||||
|
],
|
||||||
}).compile();
|
}).compile();
|
||||||
|
|
||||||
gateway = module.get<StateGateway>(StateGateway);
|
gateway = module.get<StateGateway>(StateGateway);
|
||||||
gateway.io = mockServer;
|
gateway.io = mockServer as any;
|
||||||
|
|
||||||
// Spy on logger methods
|
mockLoggerLog = jest.spyOn(gateway['logger'], 'log').mockImplementation();
|
||||||
mockLoggerLog = jest
|
|
||||||
// gateway is private in `state.gateway.ts` so we have to
|
|
||||||
// eslint-disable-next-line @typescript-eslint/no-unsafe-member-access
|
|
||||||
.spyOn((gateway as any).logger, 'log')
|
|
||||||
.mockImplementation();
|
|
||||||
mockLoggerDebug = jest
|
mockLoggerDebug = jest
|
||||||
// eslint-disable-next-line @typescript-eslint/no-unsafe-member-access
|
.spyOn(gateway['logger'], 'debug')
|
||||||
.spyOn((gateway as any).logger, 'debug')
|
|
||||||
.mockImplementation();
|
.mockImplementation();
|
||||||
|
mockLoggerWarn = jest.spyOn(gateway['logger'], 'warn').mockImplementation();
|
||||||
});
|
});
|
||||||
|
|
||||||
afterEach(() => {
|
afterEach(() => {
|
||||||
@@ -53,45 +86,127 @@ describe('StateGateway', () => {
|
|||||||
});
|
});
|
||||||
|
|
||||||
describe('handleConnection', () => {
|
describe('handleConnection', () => {
|
||||||
it('should log client connection and number of connected clients', () => {
|
it('should log client connection and sync user when authenticated', async () => {
|
||||||
const mockClient = { id: 'client1' } as Socket;
|
const mockClient: MockSocket = {
|
||||||
|
id: 'client1',
|
||||||
|
data: { user: { keycloakSub: 'test-sub' } },
|
||||||
|
handshake: {
|
||||||
|
auth: { token: 'mock-token' },
|
||||||
|
headers: {},
|
||||||
|
},
|
||||||
|
};
|
||||||
|
|
||||||
gateway.handleConnection(mockClient);
|
await gateway.handleConnection(
|
||||||
|
mockClient as unknown as AuthenticatedSocket,
|
||||||
|
);
|
||||||
|
|
||||||
|
expect(mockJwtVerificationService.extractToken).toHaveBeenCalledWith(
|
||||||
|
mockClient.handshake,
|
||||||
|
);
|
||||||
|
expect(mockJwtVerificationService.verifyToken).toHaveBeenCalledWith(
|
||||||
|
'mock-token',
|
||||||
|
);
|
||||||
|
expect(mockAuthService.syncUserFromToken).toHaveBeenCalledWith(
|
||||||
|
expect.objectContaining({
|
||||||
|
keycloakSub: 'test-sub',
|
||||||
|
}),
|
||||||
|
);
|
||||||
expect(mockLoggerLog).toHaveBeenCalledWith(
|
expect(mockLoggerLog).toHaveBeenCalledWith(
|
||||||
`Client id: ${mockClient.id} connected`,
|
`Client id: ${mockClient.id} connected (user: test-sub)`,
|
||||||
);
|
);
|
||||||
expect(mockLoggerDebug).toHaveBeenCalledWith(
|
expect(mockLoggerDebug).toHaveBeenCalledWith(
|
||||||
'Number of connected clients: 5',
|
'Number of connected clients: 5',
|
||||||
);
|
);
|
||||||
});
|
});
|
||||||
|
|
||||||
|
it('should disconnect client when no token provided', async () => {
|
||||||
|
const mockClient: MockSocket = {
|
||||||
|
id: 'client1',
|
||||||
|
data: {},
|
||||||
|
handshake: {
|
||||||
|
auth: {},
|
||||||
|
headers: {},
|
||||||
|
},
|
||||||
|
disconnect: jest.fn(),
|
||||||
|
};
|
||||||
|
|
||||||
|
(mockJwtVerificationService.extractToken as jest.Mock).mockReturnValue(
|
||||||
|
undefined,
|
||||||
|
);
|
||||||
|
|
||||||
|
await gateway.handleConnection(
|
||||||
|
mockClient as unknown as AuthenticatedSocket,
|
||||||
|
);
|
||||||
|
|
||||||
|
expect(mockLoggerWarn).toHaveBeenCalledWith(
|
||||||
|
'WebSocket connection attempt without token',
|
||||||
|
);
|
||||||
|
expect(mockClient.disconnect).toHaveBeenCalled();
|
||||||
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
describe('handleDisconnect', () => {
|
describe('handleDisconnect', () => {
|
||||||
it('should log client disconnection', () => {
|
it('should log client disconnection', () => {
|
||||||
const mockClient = { id: 'client1' } as Socket;
|
const mockClient: MockSocket = {
|
||||||
|
id: 'client1',
|
||||||
|
data: { user: { keycloakSub: 'test-sub' } },
|
||||||
|
};
|
||||||
|
|
||||||
gateway.handleDisconnect(mockClient);
|
gateway.handleDisconnect(mockClient as unknown as AuthenticatedSocket);
|
||||||
|
|
||||||
expect(mockLoggerLog).toHaveBeenCalledWith(
|
expect(mockLoggerLog).toHaveBeenCalledWith(
|
||||||
`Cliend id:${mockClient.id} disconnected`,
|
`Client id: ${mockClient.id} disconnected (user: test-sub)`,
|
||||||
|
);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should handle disconnection when no user data', () => {
|
||||||
|
const mockClient: MockSocket = {
|
||||||
|
id: 'client1',
|
||||||
|
data: {},
|
||||||
|
};
|
||||||
|
|
||||||
|
gateway.handleDisconnect(mockClient as unknown as AuthenticatedSocket);
|
||||||
|
|
||||||
|
expect(mockLoggerLog).toHaveBeenCalledWith(
|
||||||
|
`Client id: ${mockClient.id} disconnected (user: unknown)`,
|
||||||
);
|
);
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
describe('handleCursorReportPosition', () => {
|
describe('handleCursorReportPosition', () => {
|
||||||
it('should log message received from client', () => {
|
it('should log message received from authenticated client', () => {
|
||||||
const mockClient = { id: 'client1' } as Socket;
|
const mockClient: MockSocket = {
|
||||||
|
id: 'client1',
|
||||||
|
data: { user: { keycloakSub: 'test-sub' } },
|
||||||
|
};
|
||||||
const data = { x: 100, y: 200 };
|
const data = { x: 100, y: 200 };
|
||||||
|
|
||||||
gateway.handleCursorReportPosition(mockClient, data);
|
gateway.handleCursorReportPosition(
|
||||||
|
mockClient as unknown as AuthenticatedSocket,
|
||||||
|
data,
|
||||||
|
);
|
||||||
|
|
||||||
expect(mockLoggerLog).toHaveBeenCalledWith(
|
expect(mockLoggerLog).toHaveBeenCalledWith(
|
||||||
`Message received from client id: ${mockClient.id}`,
|
`Message received from client id: ${mockClient.id} (user: test-sub)`,
|
||||||
);
|
);
|
||||||
expect(mockLoggerDebug).toHaveBeenCalledWith(
|
expect(mockLoggerDebug).toHaveBeenCalledWith(
|
||||||
`Payload: ${JSON.stringify(data, null, 0)}`,
|
`Payload: ${JSON.stringify(data, null, 0)}`,
|
||||||
);
|
);
|
||||||
});
|
});
|
||||||
|
|
||||||
|
it('should throw exception when client is not authenticated', () => {
|
||||||
|
const mockClient: MockSocket = {
|
||||||
|
id: 'client1',
|
||||||
|
data: {},
|
||||||
|
};
|
||||||
|
const data = { x: 100, y: 200 };
|
||||||
|
|
||||||
|
expect(() => {
|
||||||
|
gateway.handleCursorReportPosition(
|
||||||
|
mockClient as unknown as AuthenticatedSocket,
|
||||||
|
data,
|
||||||
|
);
|
||||||
|
}).toThrow('Unauthorized');
|
||||||
|
});
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|||||||
@@ -6,11 +6,25 @@ import {
|
|||||||
SubscribeMessage,
|
SubscribeMessage,
|
||||||
WebSocketGateway,
|
WebSocketGateway,
|
||||||
WebSocketServer,
|
WebSocketServer,
|
||||||
|
WsException,
|
||||||
} from '@nestjs/websockets';
|
} from '@nestjs/websockets';
|
||||||
|
|
||||||
import { Server, Socket } from 'socket.io';
|
import type { Server } from 'socket.io';
|
||||||
|
import type { AuthenticatedSocket } from '../../types/socket';
|
||||||
|
import { AuthService } from '../../auth/auth.service';
|
||||||
|
import { JwtVerificationService } from '../../auth/services/jwt-verification.service';
|
||||||
|
import { CursorPositionDto } from '../dto/cursor-position.dto';
|
||||||
|
|
||||||
@WebSocketGateway()
|
const WS_EVENT = {
|
||||||
|
CURSOR_REPORT_POSITION: 'cursor-report-position',
|
||||||
|
} as const;
|
||||||
|
|
||||||
|
@WebSocketGateway({
|
||||||
|
cors: {
|
||||||
|
origin: true,
|
||||||
|
credentials: true,
|
||||||
|
},
|
||||||
|
})
|
||||||
export class StateGateway
|
export class StateGateway
|
||||||
implements OnGatewayInit, OnGatewayConnection, OnGatewayDisconnect
|
implements OnGatewayInit, OnGatewayConnection, OnGatewayDisconnect
|
||||||
{
|
{
|
||||||
@@ -18,24 +32,84 @@ export class StateGateway
|
|||||||
|
|
||||||
@WebSocketServer() io: Server;
|
@WebSocketServer() io: Server;
|
||||||
|
|
||||||
|
constructor(
|
||||||
|
private readonly authService: AuthService,
|
||||||
|
private readonly jwtVerificationService: JwtVerificationService,
|
||||||
|
) {}
|
||||||
|
|
||||||
afterInit() {
|
afterInit() {
|
||||||
this.logger.log('Initialized');
|
this.logger.log('Initialized');
|
||||||
}
|
}
|
||||||
|
|
||||||
handleConnection(client: Socket) {
|
async handleConnection(client: AuthenticatedSocket) {
|
||||||
const { sockets } = this.io.sockets;
|
try {
|
||||||
|
this.logger.debug(
|
||||||
|
`Connection attempt - handshake auth: ${JSON.stringify(client.handshake.auth)}`,
|
||||||
|
);
|
||||||
|
this.logger.debug(
|
||||||
|
`Connection attempt - handshake headers: ${JSON.stringify(client.handshake.headers)}`,
|
||||||
|
);
|
||||||
|
|
||||||
this.logger.log(`Client id: ${client.id} connected`);
|
const token = this.jwtVerificationService.extractToken(client.handshake);
|
||||||
this.logger.debug(`Number of connected clients: ${sockets.size}`);
|
|
||||||
|
if (!token) {
|
||||||
|
this.logger.warn('WebSocket connection attempt without token');
|
||||||
|
client.disconnect();
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const payload = await this.jwtVerificationService.verifyToken(token);
|
||||||
|
|
||||||
|
if (!payload.sub) {
|
||||||
|
throw new WsException('Invalid token: missing subject');
|
||||||
|
}
|
||||||
|
|
||||||
|
client.data.user = {
|
||||||
|
keycloakSub: payload.sub,
|
||||||
|
email: payload.email,
|
||||||
|
name: payload.name,
|
||||||
|
username: payload.preferred_username,
|
||||||
|
picture: payload.picture,
|
||||||
|
};
|
||||||
|
|
||||||
|
this.logger.log(`WebSocket authenticated: ${payload.sub}`);
|
||||||
|
|
||||||
|
await this.authService.syncUserFromToken(client.data.user);
|
||||||
|
|
||||||
|
const { sockets } = this.io.sockets;
|
||||||
|
this.logger.log(
|
||||||
|
`Client id: ${client.id} connected (user: ${payload.sub})`,
|
||||||
|
);
|
||||||
|
this.logger.debug(`Number of connected clients: ${sockets.size}`);
|
||||||
|
} catch (error: unknown) {
|
||||||
|
const errorMessage =
|
||||||
|
error instanceof Error ? error.message : 'Unknown error';
|
||||||
|
this.logger.error(`Connection error: ${errorMessage}`);
|
||||||
|
client.disconnect();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
handleDisconnect(client: Socket) {
|
handleDisconnect(client: AuthenticatedSocket) {
|
||||||
this.logger.log(`Cliend id:${client.id} disconnected`);
|
const user = client.data.user;
|
||||||
|
this.logger.log(
|
||||||
|
`Client id: ${client.id} disconnected (user: ${user?.keycloakSub || 'unknown'})`,
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
@SubscribeMessage('cursor-report-position')
|
@SubscribeMessage(WS_EVENT.CURSOR_REPORT_POSITION)
|
||||||
handleCursorReportPosition(client: Socket, data: any) {
|
handleCursorReportPosition(
|
||||||
this.logger.log(`Message received from client id: ${client.id}`);
|
client: AuthenticatedSocket,
|
||||||
|
data: CursorPositionDto,
|
||||||
|
) {
|
||||||
|
const user = client.data.user;
|
||||||
|
|
||||||
|
if (!user) {
|
||||||
|
throw new WsException('Unauthorized');
|
||||||
|
}
|
||||||
|
|
||||||
|
this.logger.log(
|
||||||
|
`Message received from client id: ${client.id} (user: ${user.keycloakSub})`,
|
||||||
|
);
|
||||||
this.logger.debug(`Payload: ${JSON.stringify(data, null, 0)}`);
|
this.logger.debug(`Payload: ${JSON.stringify(data, null, 0)}`);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,7 +1,9 @@
|
|||||||
import { Module } from '@nestjs/common';
|
import { Module } from '@nestjs/common';
|
||||||
import { StateGateway } from './state/state.gateway';
|
import { StateGateway } from './state/state.gateway';
|
||||||
|
import { AuthModule } from '../auth/auth.module';
|
||||||
|
|
||||||
@Module({
|
@Module({
|
||||||
|
imports: [AuthModule],
|
||||||
providers: [StateGateway],
|
providers: [StateGateway],
|
||||||
})
|
})
|
||||||
export class WsModule {}
|
export class WsModule {}
|
||||||
|
|||||||
Reference in New Issue
Block a user