diff --git a/.env.example b/.env.example index 8d6b502..7ca792a 100644 --- a/.env.example +++ b/.env.example @@ -6,6 +6,10 @@ NODE_ENV=development # Database connection string DATABASE_URL="postgresql://postgres:postgres@localhost:5432/friendolls_dev?schema=public" +# Redis +REDIS_HOST=localhost +REDIS_PORT=6379 + # JWT Configuration # The expected issuer of the JWT token (usually {KEYCLOAK_AUTH_SERVER_URL}/realms/{KEYCLOAK_REALM}) JWT_ISSUER=https://your-keycloak-instance.com/auth/realms/your-realm-name diff --git a/package.json b/package.json index 9b184ab..0c71023 100644 --- a/package.json +++ b/package.json @@ -29,7 +29,7 @@ "dependencies": { "@nestjs/common": "^11.0.1", "@nestjs/config": "^4.0.2", - "@nestjs/core": "^11.0.1", + "@nestjs/core": "^11.1.9", "@nestjs/event-emitter": "^3.0.1", "@nestjs/passport": "^11.0.5", "@nestjs/platform-express": "^11.0.1", @@ -39,9 +39,11 @@ "@nestjs/websockets": "^11.1.9", "@prisma/adapter-pg": "^7.0.0", "@prisma/client": "^7.0.0", + "@socket.io/redis-adapter": "^8.3.0", "class-transformer": "^0.5.1", "class-validator": "^0.14.2", "dotenv": "^17.2.3", + "ioredis": "^5.8.2", "jsonwebtoken": "^9.0.2", "jwks-rsa": "^3.2.0", "passport": "^0.7.0", @@ -58,6 +60,7 @@ "@nestjs/schematics": "^11.0.0", "@nestjs/testing": "^11.0.1", "@types/express": "^5.0.0", + "@types/ioredis": "^5.0.0", "@types/jest": "^30.0.0", "@types/jsonwebtoken": "^9.0.7", "@types/node": "^22.10.7", diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index 647a182..d7d2108 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -15,7 +15,7 @@ importers: specifier: ^4.0.2 version: 4.0.2(@nestjs/common@11.1.9(class-transformer@0.5.1)(class-validator@0.14.2)(reflect-metadata@0.2.2)(rxjs@7.8.2))(rxjs@7.8.2) '@nestjs/core': - specifier: ^11.0.1 + specifier: ^11.1.9 version: 11.1.9(@nestjs/common@11.1.9(class-transformer@0.5.1)(class-validator@0.14.2)(reflect-metadata@0.2.2)(rxjs@7.8.2))(@nestjs/platform-express@11.1.9)(@nestjs/websockets@11.1.9)(reflect-metadata@0.2.2)(rxjs@7.8.2) '@nestjs/event-emitter': specifier: ^3.0.1 @@ -44,6 +44,9 @@ importers: '@prisma/client': specifier: ^7.0.0 version: 7.0.0(prisma@7.0.0(@types/react@19.2.6)(react-dom@19.2.0(react@19.2.0))(react@19.2.0)(typescript@5.9.3))(typescript@5.9.3) + '@socket.io/redis-adapter': + specifier: ^8.3.0 + version: 8.3.0(socket.io-adapter@2.5.5) class-transformer: specifier: ^0.5.1 version: 0.5.1 @@ -53,6 +56,9 @@ importers: dotenv: specifier: ^17.2.3 version: 17.2.3 + ioredis: + specifier: ^5.8.2 + version: 5.8.2 jsonwebtoken: specifier: ^9.0.2 version: 9.0.2 @@ -96,6 +102,9 @@ importers: '@types/express': specifier: ^5.0.0 version: 5.0.5 + '@types/ioredis': + specifier: ^5.0.0 + version: 5.0.0 '@types/jest': specifier: ^30.0.0 version: 30.0.0 @@ -607,6 +616,9 @@ packages: '@types/node': optional: true + '@ioredis/commands@1.4.0': + resolution: {integrity: sha512-aFT2yemJJo+TZCmieA7qnYGQooOS7QfNmYrzGtsYd3g9j5iDP8AimYYAesf79ohjbLG12XxC4nG5DyEnC88AsQ==} + '@isaacs/balanced-match@4.0.1': resolution: {integrity: sha512-yzMTt9lEb8Gv7zRioUilSglI0c0smZ9k5D65677DLWLtWJaXIS3CqcGyUFByYKlnUj6TkjLVs54fBl6+TiGQDQ==} engines: {node: 20 || >=22} @@ -992,6 +1004,12 @@ packages: '@socket.io/component-emitter@3.1.2': resolution: {integrity: sha512-9BCxFwvbGg/RsZK9tjXd8s4UcwR0MWeFQ1XEKIQVVvAGJyINdrqKMcTRyLoK8Rse1GjzLV9cwjWV1olXRWEXVA==} + '@socket.io/redis-adapter@8.3.0': + resolution: {integrity: sha512-ly0cra+48hDmChxmIpnESKrc94LjRL80TEmZVscuQ/WWkRP81nNj8W8cCGMqbI4L6NCuAaPRSzZF1a9GlAxxnA==} + engines: {node: '>=10.0.0'} + peerDependencies: + socket.io-adapter: ^2.5.4 + '@standard-schema/spec@1.0.0': resolution: {integrity: sha512-m2bOd0f2RT9k8QJx1JN85cZYyH1RqFBdlwtkSlf4tBDYLCiiZnv1fIIwacK6cqwXavOydf0NPToMQgpKq+dVlA==} @@ -1065,6 +1083,10 @@ packages: '@types/http-errors@2.0.5': resolution: {integrity: sha512-r8Tayk8HJnX0FztbZN7oVqGccWgw98T/0neJphO91KkmOzug1KkofZURD4UaD5uH8AqcFLfdPErnBod0u71/qg==} + '@types/ioredis@5.0.0': + resolution: {integrity: sha512-zJbJ3FVE17CNl5KXzdeSPtdltc4tMT3TzC6fxQS0sQngkbFZ6h+0uTafsRqu+eSLIugf6Yb0Ea0SUuRr42Nk9g==} + deprecated: This is a stub types definition. ioredis provides its own type definitions, so you do not need this installed. + '@types/istanbul-lib-coverage@2.0.6': resolution: {integrity: sha512-2QF/t/auWm0lsy8XtKVPG19v3sSOQlJe/YHZgfjb/KBBHOGSV+J2q/S671rcq9uTBrLAXmZpqJiaQbMT+zNU1w==} @@ -1650,6 +1672,10 @@ packages: resolution: {integrity: sha512-JQHZ2QMW6l3aH/j6xCqQThY/9OH4D/9ls34cgkUBiEeocRTU04tHfKPBsUK1PqZCUQM7GiA0IIXJSuXHI64Kbg==} engines: {node: '>=0.8'} + cluster-key-slot@1.1.2: + resolution: {integrity: sha512-RMr0FhtfXemyinomL4hrWcYJxmX6deFdCxpJzhDttxgO1+bcCnkk+9drydLVDmAMG7NE6aN/fl4F7ucU/90gAA==} + engines: {node: '>=0.10.0'} + co@4.6.0: resolution: {integrity: sha512-QVb0dM5HvG+uaxitm8wONl7jltx8dqhfU33DcqtOZcLSVIKSDDLDi7+0LbAKiyI8hD9u42m2YxXSkMGWThaecQ==} engines: {iojs: '>= 1.0.0', node: '>= 0.12.0'} @@ -2288,6 +2314,10 @@ packages: inherits@2.0.4: resolution: {integrity: sha512-k/vGaX4/Yla3WzyMCvTQOXYeIHvqOKtnqBduzTHpzpQZzAskKMhZ2K+EnBiSM9zGSoIFeMpXKxa4dYeZIQqewQ==} + ioredis@5.8.2: + resolution: {integrity: sha512-C6uC+kleiIMmjViJINWk80sOQw5lEzse1ZmvD+S/s8p8CWapftSaC+kocGTx6xrbrJ4WmYQGC08ffHLr6ToR6Q==} + engines: {node: '>=12.22.0'} + ipaddr.js@1.9.1: resolution: {integrity: sha512-0KI/607xoxSToH7GjN1FfSbLoU0+btTicjsQSWQlh/hZykN8KpmMf7uYwPW3R+akZ6R/w18ZlXSHBYXiYUPO3g==} engines: {node: '>= 0.10'} @@ -2605,9 +2635,15 @@ packages: lodash.clonedeep@4.5.0: resolution: {integrity: sha512-H5ZhCF25riFd9uB5UCkVKo61m3S/xZk1x4wA6yp/L3RFP6Z/eHH1ymQcGLo7J3GMPfm0V/7m1tryHuGVxpqEBQ==} + lodash.defaults@4.2.0: + resolution: {integrity: sha512-qjxPLHd3r5DnsdGacqOMU6pb/avJzdh9tFX2ymgoZE27BmjXrNy/y4LoaiTeAb+O3gL8AfpJGtqfX/ae2leYYQ==} + lodash.includes@4.3.0: resolution: {integrity: sha512-W3Bx6mdkRTGtlJISOvVD/lbqjTlPPUDTMnlXZFnVwi9NKJ6tiAk6LVdlhZMm17VZisqhKcgzpO5Wz91PCt5b0w==} + lodash.isarguments@3.1.0: + resolution: {integrity: sha512-chi4NHZlZqZD18a0imDHnZPrDeBbTtVN7GXMwuGdRH9qotxAjYs3aVLKc7zNOG9eddR5Ksd8rvFEBc9SsggPpg==} + lodash.isboolean@3.0.3: resolution: {integrity: sha512-Bz5mupy2SVbPHURB98VAcw+aHh4vRV5IPNhILUCsOzRmsTmSQ17jIuqopAentWoehktxGd9e/hbIXq980/1QJg==} @@ -2819,6 +2855,9 @@ packages: resolution: {integrity: sha512-6eZs5Ls3WtCisHWp9S2GUy8dqkpGi4BVSz3GaqiE6ezub0512ESztXUwUB6C6IKbQkY2Pnb/mD4WYojCRwcwLA==} engines: {node: '>=0.10.0'} + notepack.io@3.0.1: + resolution: {integrity: sha512-TKC/8zH5pXIAMVQio2TvVDTtPRX+DJPHDqjRbxogtFiByHyzKmy96RA0JtCQJ+WouyyL4A10xomQzgbUT+1jCg==} + npm-run-path@4.0.1: resolution: {integrity: sha512-S48WzZW777zhNIrn7gxOlISNAqi9ZC/uQFnRdbeIHhZhCA6UqpkOT8T1G7BvfdgP4Er8gF4sUbaS0i7QvIfCWw==} engines: {node: '>=8'} @@ -3120,6 +3159,14 @@ packages: resolution: {integrity: sha512-GDhwkLfywWL2s6vEjyhri+eXmfH6j1L7JE27WhqLeYzoh/A3DBaYGEj2H/HFZCn/kMfim73FXxEJTw06WtxQwg==} engines: {node: '>= 14.18.0'} + redis-errors@1.2.0: + resolution: {integrity: sha512-1qny3OExCf0UvUV/5wpYKf2YwPcOqXzkwKKSmKHiE6ZMQs5heeE/c8eXK+PNllPvmjgAbfnsbpkGZWy8cBpn9w==} + engines: {node: '>=4'} + + redis-parser@3.0.0: + resolution: {integrity: sha512-DJnGAeenTdpMEH6uAJRK/uiyEIH9WVsUmoLwzudwGJUwZPp80PDBWPHXSAGNPwNvIXAbe7MSUB1zQFugFml66A==} + engines: {node: '>=4'} + reflect-metadata@0.2.2: resolution: {integrity: sha512-urBwgfrvVP/eAyXx4hluJivBKzuEbSQs9rKWCrCkbSxNv8mxPcUZKeuoF3Uy4mJl3Lwprp6yy5/39VWigZ4K6Q==} @@ -3296,6 +3343,9 @@ packages: resolution: {integrity: sha512-XlkWvfIm6RmsWtNJx+uqtKLS8eqFbxUg0ZzLXqY0caEy9l7hruX8IpiDnjsLavoBgqCCR71TqWO8MaXYheJ3RQ==} engines: {node: '>=10'} + standard-as-callback@2.1.0: + resolution: {integrity: sha512-qoRRSyROncaz1z0mvYqIE4lCd9p2R90i6GxW3uZv5ucSu8tU7B5HXUP1gG8pVZsYNVaXjk8ClXHPttLyxAL48A==} + statuses@2.0.2: resolution: {integrity: sha512-DvEy55V3DB7uknRo+4iOGT5fP1slR8wQohVdknigZPMpMstaKJQWhwiYBACJE3Ul2pTnATihhBYnRhZQHGBiRw==} engines: {node: '>= 0.8'} @@ -3534,6 +3584,10 @@ packages: engines: {node: '>=0.8.0'} hasBin: true + uid2@1.0.0: + resolution: {integrity: sha512-+I6aJUv63YAcY9n4mQreLUt0d4lvwkkopDNmpomkAUz0fAkEMV9pRWxN0EjhW1YfRhcuyHg2v3mwddCDW1+LFQ==} + engines: {node: '>= 4.0.0'} + uid@2.0.2: resolution: {integrity: sha512-u3xV3X7uzvi5b1MncmZo3i2Aw222Zk1keqLA1YkHldREkAhAqi65wuPfe7lHx8H/Wzy+8CE7S7uS3jekIM5s8g==} engines: {node: '>=8'} @@ -4199,6 +4253,8 @@ snapshots: optionalDependencies: '@types/node': 22.19.1 + '@ioredis/commands@1.4.0': {} + '@isaacs/balanced-match@4.0.1': {} '@isaacs/brace-expansion@5.0.0': @@ -4735,6 +4791,15 @@ snapshots: '@socket.io/component-emitter@3.1.2': {} + '@socket.io/redis-adapter@8.3.0(socket.io-adapter@2.5.5)': + dependencies: + debug: 4.3.7 + notepack.io: 3.0.1 + socket.io-adapter: 2.5.5 + uid2: 1.0.0 + transitivePeerDependencies: + - supports-color + '@standard-schema/spec@1.0.0': {} '@tokenizer/inflate@0.3.1': @@ -4837,6 +4902,12 @@ snapshots: '@types/http-errors@2.0.5': {} + '@types/ioredis@5.0.0': + dependencies: + ioredis: 5.8.2 + transitivePeerDependencies: + - supports-color + '@types/istanbul-lib-coverage@2.0.6': {} '@types/istanbul-lib-report@3.0.3': @@ -5485,6 +5556,8 @@ snapshots: clone@1.0.4: {} + cluster-key-slot@1.1.2: {} + co@4.6.0: {} collect-v8-coverage@1.0.3: {} @@ -6142,6 +6215,20 @@ snapshots: inherits@2.0.4: {} + ioredis@5.8.2: + dependencies: + '@ioredis/commands': 1.4.0 + cluster-key-slot: 1.1.2 + debug: 4.4.3 + denque: 2.1.0 + lodash.defaults: 4.2.0 + lodash.isarguments: 3.1.0 + redis-errors: 1.2.0 + redis-parser: 3.0.0 + standard-as-callback: 2.1.0 + transitivePeerDependencies: + - supports-color + ipaddr.js@1.9.1: {} is-arrayish@0.2.1: {} @@ -6636,8 +6723,12 @@ snapshots: lodash.clonedeep@4.5.0: {} + lodash.defaults@4.2.0: {} + lodash.includes@4.3.0: {} + lodash.isarguments@3.1.0: {} + lodash.isboolean@3.0.3: {} lodash.isinteger@4.0.4: {} @@ -6811,6 +6902,8 @@ snapshots: normalize-path@3.0.0: {} + notepack.io@3.0.1: {} + npm-run-path@4.0.1: dependencies: path-key: 3.1.1 @@ -7095,6 +7188,12 @@ snapshots: readdirp@4.1.2: {} + redis-errors@1.2.0: {} + + redis-parser@3.0.0: + dependencies: + redis-errors: 1.2.0 + reflect-metadata@0.2.2: {} regexp-to-ast@0.5.0: {} @@ -7298,6 +7397,8 @@ snapshots: dependencies: escape-string-regexp: 2.0.0 + standard-as-callback@2.1.0: {} + statuses@2.0.2: {} std-env@3.9.0: {} @@ -7530,6 +7631,8 @@ snapshots: uglify-js@3.19.3: optional: true + uid2@1.0.0: {} + uid@2.0.2: dependencies: '@lukeed/csprng': 1.1.0 diff --git a/src/app.module.ts b/src/app.module.ts index 7178ee4..2b13058 100644 --- a/src/app.module.ts +++ b/src/app.module.ts @@ -7,6 +7,7 @@ import { AppService } from './app.service'; import { UsersModule } from './users/users.module'; import { AuthModule } from './auth/auth.module'; import { DatabaseModule } from './database/database.module'; +import { RedisModule } from './database/redis.module'; import { WsModule } from './ws/ws.module'; import { FriendsModule } from './friends/friends.module'; @@ -63,6 +64,7 @@ function validateEnvironment(config: Record): Record { }), EventEmitterModule.forRoot(), DatabaseModule, + RedisModule, UsersModule, AuthModule, WsModule, diff --git a/src/database/redis.module.ts b/src/database/redis.module.ts new file mode 100644 index 0000000..893de67 --- /dev/null +++ b/src/database/redis.module.ts @@ -0,0 +1,52 @@ +import { Module, Global, Logger } from '@nestjs/common'; +import { ConfigService } from '@nestjs/config'; +import Redis from 'ioredis'; + +export const REDIS_CLIENT = 'REDIS_CLIENT'; + +@Global() +@Module({ + providers: [ + { + provide: REDIS_CLIENT, + useFactory: (configService: ConfigService) => { + const logger = new Logger('RedisModule'); + const host = configService.get('REDIS_HOST'); + const port = configService.get('REDIS_PORT'); + const password = configService.get('REDIS_PASSWORD'); + + // Fallback or "disabled" mode if no host is provided + if (!host) { + logger.warn( + 'REDIS_HOST not defined. Redis features will be disabled or fall back to local memory.', + ); + return null; + } + + const client = new Redis({ + host, + port: port || 6379, + password: password, + // Retry strategy: keep trying to reconnect + retryStrategy(times) { + const delay = Math.min(times * 50, 2000); + return delay; + }, + }); + + client.on('error', (err) => { + logger.error('Redis connection error', err); + }); + + client.on('connect', () => { + logger.log(`Connected to Redis at ${host}:${port || 6379}`); + }); + + return client; + }, + inject: [ConfigService], + }, + ], + exports: [REDIS_CLIENT], +}) +export class RedisModule {} diff --git a/src/main.ts b/src/main.ts index 2515d0f..f45c7c4 100644 --- a/src/main.ts +++ b/src/main.ts @@ -1,12 +1,20 @@ import { NestFactory } from '@nestjs/core'; import { ValidationPipe, Logger } from '@nestjs/common'; +import { ConfigService } from '@nestjs/config'; import { DocumentBuilder, SwaggerModule } from '@nestjs/swagger'; import { AppModule } from './app.module'; import { AllExceptionsFilter } from './common/filters/all-exceptions.filter'; +import { RedisIoAdapter } from './ws/redis-io.adapter'; async function bootstrap() { const logger = new Logger('Bootstrap'); const app = await NestFactory.create(AppModule); + const configService = app.get(ConfigService); + + // Configure Redis Adapter for horizontal scaling (if enabled) + const redisIoAdapter = new RedisIoAdapter(app, configService); + await redisIoAdapter.connectToRedis(); + app.useWebSocketAdapter(redisIoAdapter); // Enable global exception filter for consistent error responses app.useGlobalFilters(new AllExceptionsFilter()); diff --git a/src/ws/redis-io.adapter.ts b/src/ws/redis-io.adapter.ts new file mode 100644 index 0000000..421ba15 --- /dev/null +++ b/src/ws/redis-io.adapter.ts @@ -0,0 +1,81 @@ +import { IoAdapter } from '@nestjs/platform-socket.io'; +import { ServerOptions } from 'socket.io'; +import { createAdapter } from '@socket.io/redis-adapter'; +import Redis from 'ioredis'; +import { ConfigService } from '@nestjs/config'; +import { INestApplicationContext, Logger } from '@nestjs/common'; + +export class RedisIoAdapter extends IoAdapter { + private adapterConstructor: ReturnType; + private readonly logger = new Logger(RedisIoAdapter.name); + + constructor( + private app: INestApplicationContext, + private configService: ConfigService, + ) { + super(app); + } + + async connectToRedis(): Promise { + const host = this.configService.get('REDIS_HOST'); + const port = this.configService.get('REDIS_PORT'); + const password = this.configService.get('REDIS_PASSWORD'); + + // Only set up Redis adapter if host is configured + if (!host) { + this.logger.log('Redis adapter disabled (REDIS_HOST not set)'); + return; + } + + this.logger.log(`Connecting Redis adapter to ${host}:${port || 6379}`); + + try { + const pubClient = new Redis({ + host, + port: port || 6379, + password: password, + retryStrategy(times) { + // Retry connecting but don't crash if Redis is temporarily down during startup + return Math.min(times * 50, 2000); + }, + }); + + const subClient = pubClient.duplicate(); + + // Wait for connection to ensure it's valid + await new Promise((resolve, reject) => { + pubClient.once('connect', () => { + this.logger.log('Redis Pub client connected'); + resolve(); + }); + pubClient.once('error', (err) => { + this.logger.error('Redis Pub client error', err); + reject(err); + }); + }); + + // Handle subsequent errors gracefully + pubClient.on('error', (err) => { + this.logger.error('Redis Pub client error', err); + }); + subClient.on('error', (err) => { + this.logger.error('Redis Sub client error', err); + }); + + this.adapterConstructor = createAdapter(pubClient, subClient); + this.logger.log('Redis adapter initialized successfully'); + } catch (error) { + this.logger.error('Failed to initialize Redis adapter', error); + // We don't throw here to allow the app to start without Redis if connection fails, + // though functionality will be degraded if multiple instances are running. + } + } + + createIOServer(port: number, options?: ServerOptions): any { + const server = super.createIOServer(port, options); + if (this.adapterConstructor) { + server.adapter(this.adapterConstructor); + } + return server; + } +} diff --git a/src/ws/state/state.gateway.spec.ts b/src/ws/state/state.gateway.spec.ts index 3ec713e..6d58e19 100644 --- a/src/ws/state/state.gateway.spec.ts +++ b/src/ws/state/state.gateway.spec.ts @@ -5,12 +5,17 @@ import { AuthenticatedSocket } from '../../types/socket'; import { AuthService } from '../../auth/auth.service'; import { JwtVerificationService } from '../../auth/services/jwt-verification.service'; import { PrismaService } from '../../database/prisma.service'; +import { UserSocketService } from './user-socket.service'; interface MockSocket extends Partial { id: string; data: { user?: { keycloakSub: string; + email?: string; + name?: string; + preferred_username?: string; + picture?: string; }; userId?: string; friends?: Set; @@ -31,6 +36,7 @@ describe('StateGateway', () => { let mockAuthService: Partial; let mockJwtVerificationService: Partial; let mockPrismaService: Partial; + let mockUserSocketService: Partial; beforeEach(async () => { mockServer = { @@ -66,6 +72,14 @@ describe('StateGateway', () => { }, }; + mockUserSocketService = { + setSocket: jest.fn().mockResolvedValue(undefined), + removeSocket: jest.fn().mockResolvedValue(undefined), + getSocket: jest.fn().mockResolvedValue(null), + isUserOnline: jest.fn().mockResolvedValue(false), + getFriendsSockets: jest.fn().mockResolvedValue([]), + }; + const module: TestingModule = await Test.createTestingModule({ providers: [ StateGateway, @@ -75,6 +89,7 @@ describe('StateGateway', () => { useValue: mockJwtVerificationService, }, { provide: PrismaService, useValue: mockPrismaService }, + { provide: UserSocketService, useValue: mockUserSocketService }, ], }).compile(); @@ -130,6 +145,10 @@ describe('StateGateway', () => { keycloakSub: 'test-sub', }), ); + expect(mockUserSocketService.setSocket).toHaveBeenCalledWith( + 'user-id', + 'client1', + ); expect(mockLoggerLog).toHaveBeenCalledWith( `Client id: ${mockClient.id} connected (user: test-sub)`, ); @@ -165,35 +184,57 @@ describe('StateGateway', () => { }); describe('handleDisconnect', () => { - it('should log client disconnection', () => { + it('should log client disconnection', async () => { const mockClient: MockSocket = { id: 'client1', data: { user: { keycloakSub: 'test-sub' } }, }; - gateway.handleDisconnect(mockClient as unknown as AuthenticatedSocket); + await gateway.handleDisconnect(mockClient as unknown as AuthenticatedSocket); expect(mockLoggerLog).toHaveBeenCalledWith( `Client id: ${mockClient.id} disconnected (user: test-sub)`, ); }); - it('should handle disconnection when no user data', () => { + it('should handle disconnection when no user data', async () => { const mockClient: MockSocket = { id: 'client1', data: {}, }; - gateway.handleDisconnect(mockClient as unknown as AuthenticatedSocket); + await gateway.handleDisconnect(mockClient as unknown as AuthenticatedSocket); expect(mockLoggerLog).toHaveBeenCalledWith( `Client id: ${mockClient.id} disconnected (user: unknown)`, ); }); + + it('should remove socket if it matches', async () => { + const mockClient: MockSocket = { + id: 'client1', + data: { + user: { keycloakSub: 'test-sub' }, + userId: 'user-id', + friends: new Set(['friend-1']), + }, + }; + + (mockUserSocketService.getSocket as jest.Mock).mockResolvedValue('client1'); + (mockUserSocketService.getFriendsSockets as jest.Mock).mockResolvedValue([ + { userId: 'friend-1', socketId: 'friend-socket-id' } + ]); + + await gateway.handleDisconnect(mockClient as unknown as AuthenticatedSocket); + + expect(mockUserSocketService.getSocket).toHaveBeenCalledWith('user-id'); + expect(mockUserSocketService.removeSocket).toHaveBeenCalledWith('user-id'); + expect(mockServer.to).toHaveBeenCalledWith('friend-socket-id'); + }); }); describe('handleCursorReportPosition', () => { - it('should emit cursor position to connected friends', () => { + it('should emit cursor position to connected friends', async () => { const mockClient: MockSocket = { id: 'client1', data: { @@ -203,13 +244,14 @@ describe('StateGateway', () => { }, }; - // Setup the userSocketMap to simulate a connected friend - // eslint-disable-next-line @typescript-eslint/no-unsafe-member-access - (gateway as any).userSocketMap.set('friend-1', 'friend-socket-id'); + // Mock getFriendsSockets to return the friend's socket + (mockUserSocketService.getFriendsSockets as jest.Mock).mockResolvedValue([ + { userId: 'friend-1', socketId: 'friend-socket-id' }, + ]); const data: CursorPositionDto = { x: 100, y: 200 }; - gateway.handleCursorReportPosition( + await gateway.handleCursorReportPosition( mockClient as unknown as AuthenticatedSocket, data, ); @@ -224,7 +266,7 @@ describe('StateGateway', () => { }); }); - it('should not emit when no friends are online', () => { + it('should not emit when no friends are online', async () => { const mockClient: MockSocket = { id: 'client1', data: { @@ -234,10 +276,12 @@ describe('StateGateway', () => { }, }; - // Don't set up userSocketMap - friend is not online + // Mock getFriendsSockets to return empty array + (mockUserSocketService.getFriendsSockets as jest.Mock).mockResolvedValue([]); + const data: CursorPositionDto = { x: 100, y: 200 }; - gateway.handleCursorReportPosition( + await gateway.handleCursorReportPosition( mockClient as unknown as AuthenticatedSocket, data, ); @@ -246,7 +290,7 @@ describe('StateGateway', () => { expect(mockServer.to).not.toHaveBeenCalled(); }); - it('should log warning when userId is missing', () => { + it('should log warning when userId is missing', async () => { const mockClient: MockSocket = { id: 'client1', data: { @@ -258,7 +302,7 @@ describe('StateGateway', () => { const data: CursorPositionDto = { x: 100, y: 200 }; - gateway.handleCursorReportPosition( + await gateway.handleCursorReportPosition( mockClient as unknown as AuthenticatedSocket, data, ); @@ -271,19 +315,19 @@ describe('StateGateway', () => { expect(mockServer.to).not.toHaveBeenCalled(); }); - it('should throw exception when client is not authenticated', () => { + it('should throw exception when client is not authenticated', async () => { const mockClient: MockSocket = { id: 'client1', data: {}, }; const data: CursorPositionDto = { x: 100, y: 200 }; - expect(() => { + await expect( gateway.handleCursorReportPosition( mockClient as unknown as AuthenticatedSocket, data, - ); - }).toThrow('Unauthorized'); + ), + ).rejects.toThrow('Unauthorized'); }); }); }); diff --git a/src/ws/state/state.gateway.ts b/src/ws/state/state.gateway.ts index 9a7b93b..fc2970d 100644 --- a/src/ws/state/state.gateway.ts +++ b/src/ws/state/state.gateway.ts @@ -16,6 +16,7 @@ import { AuthService } from '../../auth/auth.service'; import { JwtVerificationService } from '../../auth/services/jwt-verification.service'; import { CursorPositionDto } from '../dto/cursor-position.dto'; import { PrismaService } from '../../database/prisma.service'; +import { UserSocketService } from './user-socket.service'; import { FriendEvents } from '../../friends/events/friend.events'; import type { @@ -45,7 +46,6 @@ export class StateGateway implements OnGatewayInit, OnGatewayConnection, OnGatewayDisconnect { private readonly logger = new Logger(StateGateway.name); - private userSocketMap: Map = new Map(); private lastBroadcastMap: Map = new Map(); @WebSocketServer() io: Server; @@ -54,6 +54,7 @@ export class StateGateway private readonly authService: AuthService, private readonly jwtVerificationService: JwtVerificationService, private readonly prisma: PrismaService, + private readonly userSocketService: UserSocketService, ) {} afterInit() { @@ -94,7 +95,7 @@ export class StateGateway this.logger.log(`WebSocket authenticated: ${payload.sub}`); const user = await this.authService.syncUserFromToken(client.data.user); - this.userSocketMap.set(user.id, client.id); + await this.userSocketService.setSocket(user.id, client.id); client.data.userId = user.id; // Initialize friends cache using Prisma directly @@ -117,7 +118,7 @@ export class StateGateway } } - handleDisconnect(client: AuthenticatedSocket) { + async handleDisconnect(client: AuthenticatedSocket) { const user = client.data.user; if (user) { @@ -125,34 +126,28 @@ export class StateGateway if (userId) { // Check if this socket is still the active one for the user - const currentSocketId = this.userSocketMap.get(userId); + const currentSocketId = await this.userSocketService.getSocket(userId); if (currentSocketId === client.id) { - this.userSocketMap.delete(userId); + await this.userSocketService.removeSocket(userId); this.lastBroadcastMap.delete(userId); // Notify friends that this user has disconnected const friends = client.data.friends; if (friends) { - for (const friendId of friends) { - const friendSocketId = this.userSocketMap.get(friendId); - if (friendSocketId) { - this.io.to(friendSocketId).emit(WS_EVENT.FRIEND_DISCONNECTED, { - userId: userId, - }); - } + const friendIds = Array.from(friends); + const friendSockets = await this.userSocketService.getFriendsSockets(friendIds); + + for (const { socketId } of friendSockets) { + this.io.to(socketId).emit(WS_EVENT.FRIEND_DISCONNECTED, { + userId: userId, + }); } } } - } else { - // Fallback for cases where client.data.userId might not be set - for (const [uid, socketId] of this.userSocketMap.entries()) { - if (socketId === client.id) { - this.userSocketMap.delete(uid); - this.lastBroadcastMap.delete(uid); - break; - } - } } + // Note: We can't iterate over Redis keys easily to find socketId match without userId + // The previous fallback loop over map entries is not efficient with Redis. + // We rely on client.data.userId being set correctly during connection. } this.logger.log( @@ -160,12 +155,12 @@ export class StateGateway ); } - isUserOnline(userId: string): boolean { - return this.userSocketMap.has(userId); + async isUserOnline(userId: string): Promise { + return this.userSocketService.isUserOnline(userId); } @SubscribeMessage(WS_EVENT.CURSOR_REPORT_POSITION) - handleCursorReportPosition( + async handleCursorReportPosition( client: AuthenticatedSocket, data: CursorPositionDto, ) { @@ -192,25 +187,25 @@ export class StateGateway // Broadcast to online friends const friends = client.data.friends; if (friends) { - for (const friendId of friends) { - const friendSocketId = this.userSocketMap.get(friendId); - if (friendSocketId) { - const payload = { - userId: currentUserId, - position: data, - }; - this.io - .to(friendSocketId) - .emit(WS_EVENT.FRIEND_CURSOR_POSITION, payload); - } + const friendIds = Array.from(friends); + const friendSockets = await this.userSocketService.getFriendsSockets(friendIds); + + for (const { socketId } of friendSockets) { + const payload = { + userId: currentUserId, + position: data, + }; + this.io + .to(socketId) + .emit(WS_EVENT.FRIEND_CURSOR_POSITION, payload); } } } @OnEvent(FriendEvents.REQUEST_RECEIVED) - handleFriendRequestReceived(payload: FriendRequestReceivedEvent) { + async handleFriendRequestReceived(payload: FriendRequestReceivedEvent) { const { userId, friendRequest } = payload; - const socketId = this.userSocketMap.get(userId); + const socketId = await this.userSocketService.getSocket(userId); if (socketId) { this.io.to(socketId).emit(WS_EVENT.FRIEND_REQUEST_RECEIVED, { id: friendRequest.id, @@ -229,10 +224,10 @@ export class StateGateway } @OnEvent(FriendEvents.REQUEST_ACCEPTED) - handleFriendRequestAccepted(payload: FriendRequestAcceptedEvent) { + async handleFriendRequestAccepted(payload: FriendRequestAcceptedEvent) { const { userId, friendRequest } = payload; - const socketId = this.userSocketMap.get(userId); + const socketId = await this.userSocketService.getSocket(userId); // 1. Update cache for the user who sent the request (userId / friendRequest.senderId) if (socketId) { @@ -259,7 +254,7 @@ export class StateGateway } // 2. Update cache for the user who accepted the request (friendRequest.receiverId) - const receiverSocketId = this.userSocketMap.get(friendRequest.receiverId); + const receiverSocketId = await this.userSocketService.getSocket(friendRequest.receiverId); if (receiverSocketId) { const receiverSocket = this.io.sockets.sockets.get( receiverSocketId, @@ -271,9 +266,9 @@ export class StateGateway } @OnEvent(FriendEvents.REQUEST_DENIED) - handleFriendRequestDenied(payload: FriendRequestDeniedEvent) { + async handleFriendRequestDenied(payload: FriendRequestDeniedEvent) { const { userId, friendRequest } = payload; - const socketId = this.userSocketMap.get(userId); + const socketId = await this.userSocketService.getSocket(userId); if (socketId) { this.io.to(socketId).emit(WS_EVENT.FRIEND_REQUEST_DENIED, { id: friendRequest.id, @@ -292,10 +287,10 @@ export class StateGateway } @OnEvent(FriendEvents.UNFRIENDED) - handleUnfriended(payload: UnfriendedEvent) { + async handleUnfriended(payload: UnfriendedEvent) { const { userId, friendId } = payload; - const socketId = this.userSocketMap.get(userId); + const socketId = await this.userSocketService.getSocket(userId); // 1. Update cache for the user receiving the notification (userId) if (socketId) { @@ -313,7 +308,7 @@ export class StateGateway } // 2. Update cache for the user initiating the unfriend (friendId) - const initiatorSocketId = this.userSocketMap.get(friendId); + const initiatorSocketId = await this.userSocketService.getSocket(friendId); if (initiatorSocketId) { const initiatorSocket = this.io.sockets.sockets.get( initiatorSocketId, diff --git a/src/ws/state/user-socket.service.ts b/src/ws/state/user-socket.service.ts new file mode 100644 index 0000000..bee9bdc --- /dev/null +++ b/src/ws/state/user-socket.service.ts @@ -0,0 +1,113 @@ +import { Injectable, Inject, Logger } from '@nestjs/common'; +import { REDIS_CLIENT } from '../../database/redis.module'; +import Redis from 'ioredis'; + +@Injectable() +export class UserSocketService { + private readonly logger = new Logger(UserSocketService.name); + private localUserSocketMap: Map = new Map(); + private readonly PREFIX = 'socket:user:'; + private readonly TTL = 86400; // 24 hours + + constructor( + @Inject(REDIS_CLIENT) private readonly redisClient: Redis | null, + ) {} + + async setSocket(userId: string, socketId: string): Promise { + if (this.redisClient) { + try { + await this.redisClient.set( + `${this.PREFIX}${userId}`, + socketId, + 'EX', + this.TTL, + ); + } catch (error) { + this.logger.error( + `Failed to set socket for user ${userId} in Redis`, + error, + ); + // Fallback to local map on error? Or just log? + // Let's use local map as backup if redis is down/null + this.localUserSocketMap.set(userId, socketId); + } + } else { + this.localUserSocketMap.set(userId, socketId); + } + } + + async removeSocket(userId: string): Promise { + if (this.redisClient) { + try { + await this.redisClient.del(`${this.PREFIX}${userId}`); + } catch (error) { + this.logger.error( + `Failed to remove socket for user ${userId} from Redis`, + error, + ); + } + } + this.localUserSocketMap.delete(userId); + } + + async getSocket(userId: string): Promise { + if (this.redisClient) { + try { + const socketId = await this.redisClient.get(`${this.PREFIX}${userId}`); + return socketId; + } catch (error) { + this.logger.error( + `Failed to get socket for user ${userId} from Redis`, + error, + ); + return this.localUserSocketMap.get(userId) || null; + } + } + return this.localUserSocketMap.get(userId) || null; + } + + async isUserOnline(userId: string): Promise { + const socketId = await this.getSocket(userId); + return !!socketId; + } + + async getFriendsSockets(friendIds: string[]): Promise<{ userId: string; socketId: string }[]> { + if (friendIds.length === 0) { + return []; + } + + if (this.redisClient) { + try { + // Use pipeline for batch fetching + const pipeline = this.redisClient.pipeline(); + friendIds.forEach((id) => pipeline.get(`${this.PREFIX}${id}`)); + const results = await pipeline.exec(); + + const sockets: { userId: string; socketId: string }[] = []; + + if (results) { + results.forEach((result, index) => { + const [err, socketId] = result; + if (!err && socketId && typeof socketId === 'string') { + sockets.push({ userId: friendIds[index], socketId }); + } + }); + } + return sockets; + } catch (error) { + this.logger.error('Failed to batch get friend sockets from Redis', error); + // Fallback to local implementation + } + } + + // Local fallback + const sockets: { userId: string; socketId: string }[] = []; + for (const friendId of friendIds) { + const socketId = this.localUserSocketMap.get(friendId); + if (socketId) { + sockets.push({ userId: friendId, socketId }); + } + } + return sockets; + } +} diff --git a/src/ws/ws.module.ts b/src/ws/ws.module.ts index 3158cd3..236413c 100644 --- a/src/ws/ws.module.ts +++ b/src/ws/ws.module.ts @@ -1,11 +1,13 @@ import { Module, forwardRef } from '@nestjs/common'; import { StateGateway } from './state/state.gateway'; +import { UserSocketService } from './state/user-socket.service'; import { AuthModule } from '../auth/auth.module'; import { FriendsModule } from '../friends/friends.module'; +import { RedisModule } from '../database/redis.module'; @Module({ - imports: [AuthModule, forwardRef(() => FriendsModule)], - providers: [StateGateway], + imports: [AuthModule, RedisModule, forwardRef(() => FriendsModule)], + providers: [StateGateway, UserSocketService], exports: [StateGateway], }) export class WsModule {}