From f04ffea61259e6899547d30e06a1b8ce18f58b62 Mon Sep 17 00:00:00 2001 From: Wind-Explorer Date: Thu, 4 Dec 2025 23:43:41 +0800 Subject: [PATCH] ws auth --- package.json | 2 + pnpm-lock.yaml | 14 +- pnpm-workspace.yaml | 3 + src/auth/auth.module.ts | 36 +--- .../services/jwt-verification.service.spec.ts | 79 +++++++++ src/auth/services/jwt-verification.service.ts | 93 ++++++++++ src/types/socket.d.ts | 10 ++ src/users/users.module.ts | 2 +- src/ws/dto/cursor-position.dto.ts | 9 + src/ws/state/state.gateway.spec.ts | 161 +++++++++++++++--- src/ws/state/state.gateway.ts | 96 +++++++++-- src/ws/ws.module.ts | 2 + 12 files changed, 435 insertions(+), 72 deletions(-) create mode 100644 src/auth/services/jwt-verification.service.spec.ts create mode 100644 src/auth/services/jwt-verification.service.ts create mode 100644 src/types/socket.d.ts create mode 100644 src/ws/dto/cursor-position.dto.ts diff --git a/package.json b/package.json index 5898ae8..4699fe7 100644 --- a/package.json +++ b/package.json @@ -40,6 +40,7 @@ "class-transformer": "^0.5.1", "class-validator": "^0.14.2", "dotenv": "^17.2.3", + "jsonwebtoken": "^9.0.2", "jwks-rsa": "^3.2.0", "passport": "^0.7.0", "passport-jwt": "^4.0.1", @@ -56,6 +57,7 @@ "@nestjs/testing": "^11.0.1", "@types/express": "^5.0.0", "@types/jest": "^30.0.0", + "@types/jsonwebtoken": "^9.0.7", "@types/node": "^22.10.7", "@types/passport-jwt": "^4.0.1", "@types/pg": "^8.15.6", diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index 1b08cb5..2612bc6 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -28,7 +28,7 @@ importers: version: 11.1.9(@nestjs/common@11.1.9(class-transformer@0.5.1)(class-validator@0.14.2)(reflect-metadata@0.2.2)(rxjs@7.8.2))(@nestjs/websockets@11.1.9)(rxjs@7.8.2) '@nestjs/swagger': specifier: ^11.2.3 - version: 11.2.3(@nestjs/common@11.1.9(class-transformer@0.5.1)(class-validator@0.14.2)(reflect-metadata@0.2.2)(rxjs@7.8.2))(@nestjs/core@11.1.9)(class-transformer@0.5.1)(class-validator@0.14.2)(reflect-metadata@0.2.2) + version: 11.2.3(@nestjs/common@11.1.9(class-transformer@0.5.1)(class-validator@0.14.2)(reflect-metadata@0.2.2)(rxjs@7.8.2))(@nestjs/core@11.1.9(@nestjs/common@11.1.9(class-transformer@0.5.1)(class-validator@0.14.2)(reflect-metadata@0.2.2)(rxjs@7.8.2))(@nestjs/platform-express@11.1.9)(@nestjs/websockets@11.1.9)(reflect-metadata@0.2.2)(rxjs@7.8.2))(class-transformer@0.5.1)(class-validator@0.14.2)(reflect-metadata@0.2.2) '@nestjs/websockets': specifier: ^11.1.9 version: 11.1.9(@nestjs/common@11.1.9(class-transformer@0.5.1)(class-validator@0.14.2)(reflect-metadata@0.2.2)(rxjs@7.8.2))(@nestjs/core@11.1.9)(@nestjs/platform-socket.io@11.1.9)(reflect-metadata@0.2.2)(rxjs@7.8.2) @@ -47,6 +47,9 @@ importers: dotenv: specifier: ^17.2.3 version: 17.2.3 + jsonwebtoken: + specifier: ^9.0.2 + version: 9.0.2 jwks-rsa: specifier: ^3.2.0 version: 3.2.0 @@ -83,13 +86,16 @@ importers: version: 11.0.9(chokidar@4.0.3)(typescript@5.9.3) '@nestjs/testing': specifier: ^11.0.1 - version: 11.1.9(@nestjs/common@11.1.9(class-transformer@0.5.1)(class-validator@0.14.2)(reflect-metadata@0.2.2)(rxjs@7.8.2))(@nestjs/core@11.1.9)(@nestjs/platform-express@11.1.9) + version: 11.1.9(@nestjs/common@11.1.9(class-transformer@0.5.1)(class-validator@0.14.2)(reflect-metadata@0.2.2)(rxjs@7.8.2))(@nestjs/core@11.1.9(@nestjs/common@11.1.9(class-transformer@0.5.1)(class-validator@0.14.2)(reflect-metadata@0.2.2)(rxjs@7.8.2))(@nestjs/platform-express@11.1.9)(@nestjs/websockets@11.1.9)(reflect-metadata@0.2.2)(rxjs@7.8.2))(@nestjs/platform-express@11.1.9(@nestjs/common@11.1.9(class-transformer@0.5.1)(class-validator@0.14.2)(reflect-metadata@0.2.2)(rxjs@7.8.2))(@nestjs/core@11.1.9)) '@types/express': specifier: ^5.0.0 version: 5.0.5 '@types/jest': specifier: ^30.0.0 version: 30.0.0 + '@types/jsonwebtoken': + specifier: ^9.0.7 + version: 9.0.10 '@types/node': specifier: ^22.10.7 version: 22.19.1 @@ -4532,7 +4538,7 @@ snapshots: transitivePeerDependencies: - chokidar - '@nestjs/swagger@11.2.3(@nestjs/common@11.1.9(class-transformer@0.5.1)(class-validator@0.14.2)(reflect-metadata@0.2.2)(rxjs@7.8.2))(@nestjs/core@11.1.9)(class-transformer@0.5.1)(class-validator@0.14.2)(reflect-metadata@0.2.2)': + '@nestjs/swagger@11.2.3(@nestjs/common@11.1.9(class-transformer@0.5.1)(class-validator@0.14.2)(reflect-metadata@0.2.2)(rxjs@7.8.2))(@nestjs/core@11.1.9(@nestjs/common@11.1.9(class-transformer@0.5.1)(class-validator@0.14.2)(reflect-metadata@0.2.2)(rxjs@7.8.2))(@nestjs/platform-express@11.1.9)(@nestjs/websockets@11.1.9)(reflect-metadata@0.2.2)(rxjs@7.8.2))(class-transformer@0.5.1)(class-validator@0.14.2)(reflect-metadata@0.2.2)': dependencies: '@microsoft/tsdoc': 0.16.0 '@nestjs/common': 11.1.9(class-transformer@0.5.1)(class-validator@0.14.2)(reflect-metadata@0.2.2)(rxjs@7.8.2) @@ -4547,7 +4553,7 @@ snapshots: class-transformer: 0.5.1 class-validator: 0.14.2 - '@nestjs/testing@11.1.9(@nestjs/common@11.1.9(class-transformer@0.5.1)(class-validator@0.14.2)(reflect-metadata@0.2.2)(rxjs@7.8.2))(@nestjs/core@11.1.9)(@nestjs/platform-express@11.1.9)': + '@nestjs/testing@11.1.9(@nestjs/common@11.1.9(class-transformer@0.5.1)(class-validator@0.14.2)(reflect-metadata@0.2.2)(rxjs@7.8.2))(@nestjs/core@11.1.9(@nestjs/common@11.1.9(class-transformer@0.5.1)(class-validator@0.14.2)(reflect-metadata@0.2.2)(rxjs@7.8.2))(@nestjs/platform-express@11.1.9)(@nestjs/websockets@11.1.9)(reflect-metadata@0.2.2)(rxjs@7.8.2))(@nestjs/platform-express@11.1.9(@nestjs/common@11.1.9(class-transformer@0.5.1)(class-validator@0.14.2)(reflect-metadata@0.2.2)(rxjs@7.8.2))(@nestjs/core@11.1.9))': dependencies: '@nestjs/common': 11.1.9(class-transformer@0.5.1)(class-validator@0.14.2)(reflect-metadata@0.2.2)(rxjs@7.8.2) '@nestjs/core': 11.1.9(@nestjs/common@11.1.9(class-transformer@0.5.1)(class-validator@0.14.2)(reflect-metadata@0.2.2)(rxjs@7.8.2))(@nestjs/platform-express@11.1.9)(@nestjs/websockets@11.1.9)(reflect-metadata@0.2.2)(rxjs@7.8.2) diff --git a/pnpm-workspace.yaml b/pnpm-workspace.yaml index 0f71e7a..8622d67 100644 --- a/pnpm-workspace.yaml +++ b/pnpm-workspace.yaml @@ -1,3 +1,6 @@ +packages: + - '.' + onlyBuiltDependencies: - '@nestjs/core' - '@prisma/engines' diff --git a/src/auth/auth.module.ts b/src/auth/auth.module.ts index a63dd4e..de5aec3 100644 --- a/src/auth/auth.module.ts +++ b/src/auth/auth.module.ts @@ -3,46 +3,16 @@ import { ConfigModule } from '@nestjs/config'; import { PassportModule } from '@nestjs/passport'; import { AuthService } from './auth.service'; import { JwtStrategy } from './strategies/jwt.strategy'; +import { JwtVerificationService } from './services/jwt-verification.service'; import { UsersModule } from '../users/users.module'; -/** - * Authentication Module - * - * Provides Keycloak OpenID Connect authentication using JWT tokens. - * This module configures: - * - Passport for authentication strategies - * - JWT strategy for validating Keycloak tokens - * - Integration with UsersModule for user synchronization - * - * The module requires the following environment variables: - * - JWT_ISSUER: Expected JWT issuer - * - JWT_AUDIENCE: Expected JWT audience - * - JWKS_URI: URI for fetching Keycloak's public keys - */ @Module({ imports: [ - // Import ConfigModule to access environment variables ConfigModule, - - // Import PassportModule for authentication strategies PassportModule.register({ defaultStrategy: 'jwt' }), - - // Import UsersModule to enable user synchronization (with forwardRef to avoid circular dependency) forwardRef(() => UsersModule), ], - providers: [ - // Register the JWT strategy for validating Keycloak tokens - JwtStrategy, - - // Register the auth service for business logic - AuthService, - ], - exports: [ - // Export AuthService so other modules can use it - AuthService, - - // Export PassportModule so guards can be used in other modules - PassportModule, - ], + providers: [JwtStrategy, AuthService, JwtVerificationService], + exports: [AuthService, PassportModule, JwtVerificationService], }) export class AuthModule {} diff --git a/src/auth/services/jwt-verification.service.spec.ts b/src/auth/services/jwt-verification.service.spec.ts new file mode 100644 index 0000000..d6dda64 --- /dev/null +++ b/src/auth/services/jwt-verification.service.spec.ts @@ -0,0 +1,79 @@ +import { Test, TestingModule } from '@nestjs/testing'; +import { ConfigService } from '@nestjs/config'; +import { JwtVerificationService } from './jwt-verification.service'; + +describe('JwtVerificationService', () => { + let service: JwtVerificationService; + + beforeEach(async () => { + const mockConfigService = { + get: jest.fn((key: string) => { + const config: Record = { + JWKS_URI: 'https://test.com/.well-known/jwks.json', + JWT_ISSUER: 'https://test.com', + JWT_AUDIENCE: 'test-audience', + }; + return config[key]; + }), + }; + + const module: TestingModule = await Test.createTestingModule({ + providers: [ + JwtVerificationService, + { provide: ConfigService, useValue: mockConfigService }, + ], + }).compile(); + + service = module.get(JwtVerificationService); + }); + + it('should be defined', () => { + expect(service).toBeDefined(); + }); + + describe('extractToken', () => { + it('should extract token from auth object', () => { + const handshake = { + auth: { token: 'test-token' }, + headers: {}, + }; + + const token = service.extractToken(handshake); + + expect(token).toBe('test-token'); + }); + + it('should extract token from Authorization header', () => { + const handshake = { + auth: {}, + headers: { authorization: 'Bearer test-token' }, + }; + + const token = service.extractToken(handshake); + + expect(token).toBe('test-token'); + }); + + it('should prioritize auth.token over header', () => { + const handshake = { + auth: { token: 'auth-token' }, + headers: { authorization: 'Bearer header-token' }, + }; + + const token = service.extractToken(handshake); + + expect(token).toBe('auth-token'); + }); + + it('should return undefined when no token present', () => { + const handshake = { + auth: {}, + headers: {}, + }; + + const token = service.extractToken(handshake); + + expect(token).toBeUndefined(); + }); + }); +}); diff --git a/src/auth/services/jwt-verification.service.ts b/src/auth/services/jwt-verification.service.ts new file mode 100644 index 0000000..27814c5 --- /dev/null +++ b/src/auth/services/jwt-verification.service.ts @@ -0,0 +1,93 @@ +import { Injectable, Logger } from '@nestjs/common'; +import { ConfigService } from '@nestjs/config'; +import { verify, type JwtHeader } from 'jsonwebtoken'; +import { JwksClient, type SigningKey } from 'jwks-rsa'; +import type { JwtPayload } from '../strategies/jwt.strategy'; + +const JWT_ALGORITHM = 'RS256'; +const BEARER_PREFIX = 'Bearer '; + +@Injectable() +export class JwtVerificationService { + private readonly logger = new Logger(JwtVerificationService.name); + private readonly jwksClient: JwksClient; + private readonly issuer: string; + private readonly audience: string | undefined; + + constructor(private readonly configService: ConfigService) { + const jwksUri = this.configService.get('JWKS_URI'); + this.issuer = this.configService.get('JWT_ISSUER') || ''; + this.audience = this.configService.get('JWT_AUDIENCE'); + + if (!jwksUri) { + throw new Error('JWKS_URI must be configured'); + } + + if (!this.issuer) { + throw new Error('JWT_ISSUER must be configured'); + } + + this.jwksClient = new JwksClient({ + jwksUri, + cache: true, + rateLimit: true, + jwksRequestsPerMinute: 5, + }); + + this.logger.log('JWT Verification Service initialized'); + } + + async verifyToken(token: string): Promise { + return new Promise((resolve, reject) => { + const getKey = ( + header: JwtHeader, + callback: (err: Error | null, signingKey?: string | Buffer) => void, + ) => { + this.jwksClient.getSigningKey( + header.kid, + (err: Error | null, key?: SigningKey) => { + if (err) { + callback(err); + return; + } + const signingKey = key?.getPublicKey(); + callback(null, signingKey); + }, + ); + }; + + verify( + token, + getKey, + { + issuer: this.issuer, + audience: this.audience, + algorithms: [JWT_ALGORITHM], + }, + (err, decoded) => { + if (err) { + reject(err); + return; + } + resolve(decoded as JwtPayload); + }, + ); + }); + } + + extractToken(handshake: { + auth?: { token?: string }; + headers?: { authorization?: string }; + }): string | undefined { + if (handshake.auth?.token) { + return handshake.auth.token; + } + + const authHeader = handshake.headers?.authorization; + if (authHeader?.startsWith(BEARER_PREFIX)) { + return authHeader.replace(BEARER_PREFIX, ''); + } + + return undefined; + } +} diff --git a/src/types/socket.d.ts b/src/types/socket.d.ts new file mode 100644 index 0000000..04dde24 --- /dev/null +++ b/src/types/socket.d.ts @@ -0,0 +1,10 @@ +import type { Socket as BaseSocket } from 'socket.io'; +import type { AuthenticatedUser } from '../auth/decorators/current-user.decorator'; +import type { DefaultEventsMap } from 'socket.io/dist/typed-events'; + +export type AuthenticatedSocket = BaseSocket< + DefaultEventsMap, // ClientToServerEvents + DefaultEventsMap, // ServerToClientEvents + DefaultEventsMap, // InterServerEvents + { user?: AuthenticatedUser } +>; diff --git a/src/users/users.module.ts b/src/users/users.module.ts index 914b70d..03a1448 100644 --- a/src/users/users.module.ts +++ b/src/users/users.module.ts @@ -16,6 +16,6 @@ import { AuthModule } from '../auth/auth.module'; imports: [forwardRef(() => AuthModule)], providers: [UsersService], controllers: [UsersController], - exports: [UsersService], // Export so AuthModule can use it + exports: [UsersService], }) export class UsersModule {} diff --git a/src/ws/dto/cursor-position.dto.ts b/src/ws/dto/cursor-position.dto.ts new file mode 100644 index 0000000..b15ce91 --- /dev/null +++ b/src/ws/dto/cursor-position.dto.ts @@ -0,0 +1,9 @@ +import { IsNumber } from 'class-validator'; + +export class CursorPositionDto { + @IsNumber() + x: number; + + @IsNumber() + y: number; +} diff --git a/src/ws/state/state.gateway.spec.ts b/src/ws/state/state.gateway.spec.ts index 0a64ba4..0706ee9 100644 --- a/src/ws/state/state.gateway.spec.ts +++ b/src/ws/state/state.gateway.spec.ts @@ -1,12 +1,28 @@ import { Test, TestingModule } from '@nestjs/testing'; import { StateGateway } from './state.gateway'; -import { Socket } from 'socket.io'; +import { AuthenticatedSocket } from '../../types/socket'; +import { AuthService } from '../../auth/auth.service'; +import { JwtVerificationService } from '../../auth/services/jwt-verification.service'; + +interface MockSocket extends Partial { + id: string; + data: { + user?: { + keycloakSub: string; + }; + }; + handshake?: any; + disconnect?: jest.Mock; +} describe('StateGateway', () => { let gateway: StateGateway; let mockLoggerLog: jest.SpyInstance; let mockLoggerDebug: jest.SpyInstance; - let mockServer: any; + let mockLoggerWarn: jest.SpyInstance; + let mockServer: { sockets: { sockets: { size: number } } }; + let mockAuthService: Partial; + let mockJwtVerificationService: Partial; beforeEach(async () => { mockServer = { @@ -17,23 +33,40 @@ describe('StateGateway', () => { }, }; + mockAuthService = { + syncUserFromToken: jest.fn().mockResolvedValue({ + id: 'user-id', + keycloakSub: 'test-sub', + }), + }; + + mockJwtVerificationService = { + extractToken: jest.fn((handshake) => handshake.auth?.token), + verifyToken: jest.fn().mockResolvedValue({ + sub: 'test-sub', + email: 'test@example.com', + }), + }; + const module: TestingModule = await Test.createTestingModule({ - providers: [StateGateway], + providers: [ + StateGateway, + { provide: AuthService, useValue: mockAuthService }, + { + provide: JwtVerificationService, + useValue: mockJwtVerificationService, + }, + ], }).compile(); gateway = module.get(StateGateway); - gateway.io = mockServer; + gateway.io = mockServer as any; - // Spy on logger methods - mockLoggerLog = jest - // gateway is private in `state.gateway.ts` so we have to - // eslint-disable-next-line @typescript-eslint/no-unsafe-member-access - .spyOn((gateway as any).logger, 'log') - .mockImplementation(); + mockLoggerLog = jest.spyOn(gateway['logger'], 'log').mockImplementation(); mockLoggerDebug = jest - // eslint-disable-next-line @typescript-eslint/no-unsafe-member-access - .spyOn((gateway as any).logger, 'debug') + .spyOn(gateway['logger'], 'debug') .mockImplementation(); + mockLoggerWarn = jest.spyOn(gateway['logger'], 'warn').mockImplementation(); }); afterEach(() => { @@ -53,45 +86,127 @@ describe('StateGateway', () => { }); describe('handleConnection', () => { - it('should log client connection and number of connected clients', () => { - const mockClient = { id: 'client1' } as Socket; + it('should log client connection and sync user when authenticated', async () => { + const mockClient: MockSocket = { + id: 'client1', + data: { user: { keycloakSub: 'test-sub' } }, + handshake: { + auth: { token: 'mock-token' }, + headers: {}, + }, + }; - gateway.handleConnection(mockClient); + await gateway.handleConnection( + mockClient as unknown as AuthenticatedSocket, + ); + expect(mockJwtVerificationService.extractToken).toHaveBeenCalledWith( + mockClient.handshake, + ); + expect(mockJwtVerificationService.verifyToken).toHaveBeenCalledWith( + 'mock-token', + ); + expect(mockAuthService.syncUserFromToken).toHaveBeenCalledWith( + expect.objectContaining({ + keycloakSub: 'test-sub', + }), + ); expect(mockLoggerLog).toHaveBeenCalledWith( - `Client id: ${mockClient.id} connected`, + `Client id: ${mockClient.id} connected (user: test-sub)`, ); expect(mockLoggerDebug).toHaveBeenCalledWith( 'Number of connected clients: 5', ); }); + + it('should disconnect client when no token provided', async () => { + const mockClient: MockSocket = { + id: 'client1', + data: {}, + handshake: { + auth: {}, + headers: {}, + }, + disconnect: jest.fn(), + }; + + (mockJwtVerificationService.extractToken as jest.Mock).mockReturnValue( + undefined, + ); + + await gateway.handleConnection( + mockClient as unknown as AuthenticatedSocket, + ); + + expect(mockLoggerWarn).toHaveBeenCalledWith( + 'WebSocket connection attempt without token', + ); + expect(mockClient.disconnect).toHaveBeenCalled(); + }); }); describe('handleDisconnect', () => { it('should log client disconnection', () => { - const mockClient = { id: 'client1' } as Socket; + const mockClient: MockSocket = { + id: 'client1', + data: { user: { keycloakSub: 'test-sub' } }, + }; - gateway.handleDisconnect(mockClient); + gateway.handleDisconnect(mockClient as unknown as AuthenticatedSocket); expect(mockLoggerLog).toHaveBeenCalledWith( - `Cliend id:${mockClient.id} disconnected`, + `Client id: ${mockClient.id} disconnected (user: test-sub)`, + ); + }); + + it('should handle disconnection when no user data', () => { + const mockClient: MockSocket = { + id: 'client1', + data: {}, + }; + + gateway.handleDisconnect(mockClient as unknown as AuthenticatedSocket); + + expect(mockLoggerLog).toHaveBeenCalledWith( + `Client id: ${mockClient.id} disconnected (user: unknown)`, ); }); }); describe('handleCursorReportPosition', () => { - it('should log message received from client', () => { - const mockClient = { id: 'client1' } as Socket; + it('should log message received from authenticated client', () => { + const mockClient: MockSocket = { + id: 'client1', + data: { user: { keycloakSub: 'test-sub' } }, + }; const data = { x: 100, y: 200 }; - gateway.handleCursorReportPosition(mockClient, data); + gateway.handleCursorReportPosition( + mockClient as unknown as AuthenticatedSocket, + data, + ); expect(mockLoggerLog).toHaveBeenCalledWith( - `Message received from client id: ${mockClient.id}`, + `Message received from client id: ${mockClient.id} (user: test-sub)`, ); expect(mockLoggerDebug).toHaveBeenCalledWith( `Payload: ${JSON.stringify(data, null, 0)}`, ); }); + + it('should throw exception when client is not authenticated', () => { + const mockClient: MockSocket = { + id: 'client1', + data: {}, + }; + const data = { x: 100, y: 200 }; + + expect(() => { + gateway.handleCursorReportPosition( + mockClient as unknown as AuthenticatedSocket, + data, + ); + }).toThrow('Unauthorized'); + }); }); }); diff --git a/src/ws/state/state.gateway.ts b/src/ws/state/state.gateway.ts index 6a6f765..7d54ac1 100644 --- a/src/ws/state/state.gateway.ts +++ b/src/ws/state/state.gateway.ts @@ -6,11 +6,25 @@ import { SubscribeMessage, WebSocketGateway, WebSocketServer, + WsException, } from '@nestjs/websockets'; -import { Server, Socket } from 'socket.io'; +import type { Server } from 'socket.io'; +import type { AuthenticatedSocket } from '../../types/socket'; +import { AuthService } from '../../auth/auth.service'; +import { JwtVerificationService } from '../../auth/services/jwt-verification.service'; +import { CursorPositionDto } from '../dto/cursor-position.dto'; -@WebSocketGateway() +const WS_EVENT = { + CURSOR_REPORT_POSITION: 'cursor-report-position', +} as const; + +@WebSocketGateway({ + cors: { + origin: true, + credentials: true, + }, +}) export class StateGateway implements OnGatewayInit, OnGatewayConnection, OnGatewayDisconnect { @@ -18,24 +32,84 @@ export class StateGateway @WebSocketServer() io: Server; + constructor( + private readonly authService: AuthService, + private readonly jwtVerificationService: JwtVerificationService, + ) {} + afterInit() { this.logger.log('Initialized'); } - handleConnection(client: Socket) { - const { sockets } = this.io.sockets; + async handleConnection(client: AuthenticatedSocket) { + try { + this.logger.debug( + `Connection attempt - handshake auth: ${JSON.stringify(client.handshake.auth)}`, + ); + this.logger.debug( + `Connection attempt - handshake headers: ${JSON.stringify(client.handshake.headers)}`, + ); - this.logger.log(`Client id: ${client.id} connected`); - this.logger.debug(`Number of connected clients: ${sockets.size}`); + const token = this.jwtVerificationService.extractToken(client.handshake); + + if (!token) { + this.logger.warn('WebSocket connection attempt without token'); + client.disconnect(); + return; + } + + const payload = await this.jwtVerificationService.verifyToken(token); + + if (!payload.sub) { + throw new WsException('Invalid token: missing subject'); + } + + client.data.user = { + keycloakSub: payload.sub, + email: payload.email, + name: payload.name, + username: payload.preferred_username, + picture: payload.picture, + }; + + this.logger.log(`WebSocket authenticated: ${payload.sub}`); + + await this.authService.syncUserFromToken(client.data.user); + + const { sockets } = this.io.sockets; + this.logger.log( + `Client id: ${client.id} connected (user: ${payload.sub})`, + ); + this.logger.debug(`Number of connected clients: ${sockets.size}`); + } catch (error: unknown) { + const errorMessage = + error instanceof Error ? error.message : 'Unknown error'; + this.logger.error(`Connection error: ${errorMessage}`); + client.disconnect(); + } } - handleDisconnect(client: Socket) { - this.logger.log(`Cliend id:${client.id} disconnected`); + handleDisconnect(client: AuthenticatedSocket) { + const user = client.data.user; + this.logger.log( + `Client id: ${client.id} disconnected (user: ${user?.keycloakSub || 'unknown'})`, + ); } - @SubscribeMessage('cursor-report-position') - handleCursorReportPosition(client: Socket, data: any) { - this.logger.log(`Message received from client id: ${client.id}`); + @SubscribeMessage(WS_EVENT.CURSOR_REPORT_POSITION) + handleCursorReportPosition( + client: AuthenticatedSocket, + data: CursorPositionDto, + ) { + const user = client.data.user; + + if (!user) { + throw new WsException('Unauthorized'); + } + + this.logger.log( + `Message received from client id: ${client.id} (user: ${user.keycloakSub})`, + ); this.logger.debug(`Payload: ${JSON.stringify(data, null, 0)}`); } } diff --git a/src/ws/ws.module.ts b/src/ws/ws.module.ts index 2c2b007..62d8171 100644 --- a/src/ws/ws.module.ts +++ b/src/ws/ws.module.ts @@ -1,7 +1,9 @@ import { Module } from '@nestjs/common'; import { StateGateway } from './state/state.gateway'; +import { AuthModule } from '../auth/auth.module'; @Module({ + imports: [AuthModule], providers: [StateGateway], }) export class WsModule {}