diff --git a/src/types/socket.d.ts b/src/types/socket.d.ts index aa96222..8d028c7 100644 --- a/src/types/socket.d.ts +++ b/src/types/socket.d.ts @@ -11,5 +11,7 @@ export type AuthenticatedSocket = BaseSocket< userId?: string; activeDollId?: string | null; friends?: Set; // Set of friend user IDs + senderName?: string; + senderNameCachedAt?: number; } >; diff --git a/src/ws/redis-io.adapter.ts b/src/ws/redis-io.adapter.ts index 206cfc3..f35b924 100644 --- a/src/ws/redis-io.adapter.ts +++ b/src/ws/redis-io.adapter.ts @@ -4,10 +4,18 @@ import { createAdapter } from '@socket.io/redis-adapter'; import Redis from 'ioredis'; import { ConfigService } from '@nestjs/config'; import { INestApplicationContext, Logger } from '@nestjs/common'; +import { + parsePositiveInteger, + parseRedisRequired, +} from '../common/config/env.utils'; + +const DEFAULT_REDIS_STARTUP_RETRIES = 10; export class RedisIoAdapter extends IoAdapter { private adapterConstructor: ReturnType; private readonly logger = new Logger(RedisIoAdapter.name); + private pubClient: Redis | null = null; + private subClient: Redis | null = null; constructor( private app: INestApplicationContext, @@ -18,41 +26,63 @@ export class RedisIoAdapter extends IoAdapter { async connectToRedis(): Promise { const host = this.configService.get('REDIS_HOST'); - const port = this.configService.get('REDIS_PORT'); + const port = parsePositiveInteger( + this.configService.get('REDIS_PORT'), + 6379, + ); const password = this.configService.get('REDIS_PASSWORD'); + const startupRetries = parsePositiveInteger( + this.configService.get('REDIS_STARTUP_RETRIES'), + DEFAULT_REDIS_STARTUP_RETRIES, + ); + const redisRequired = parseRedisRequired({ + nodeEnv: this.configService.get('NODE_ENV'), + redisRequired: this.configService.get('REDIS_REQUIRED'), + }); - // Only set up Redis adapter if host is configured if (!host) { + if (redisRequired) { + throw new Error( + 'REDIS_REQUIRED is enabled but REDIS_HOST is not configured', + ); + } + this.logger.log('Redis adapter disabled (REDIS_HOST not set)'); return; } - this.logger.log(`Connecting Redis adapter to ${host}:${port || 6379}`); + this.logger.log(`Connecting Redis adapter to ${host}:${port}`); try { + const connectTimeout = parsePositiveInteger( + this.configService.get('REDIS_CONNECT_TIMEOUT_MS'), + 5000, + ); const pubClient = new Redis({ host, - port: port || 6379, - password: password, + port, + password, + lazyConnect: true, + connectTimeout, + maxRetriesPerRequest: 1, + enableOfflineQueue: false, retryStrategy(times) { - // Retry connecting but don't crash if Redis is temporarily down during startup + if (times > startupRetries) { + return null; + } + return Math.min(times * 50, 2000); }, }); const subClient = pubClient.duplicate(); - // Wait for connection to ensure it's valid - await new Promise((resolve, reject) => { - pubClient.once('connect', () => { - this.logger.log('Redis Pub client connected'); - resolve(); - }); - pubClient.once('error', (err) => { - this.logger.error('Redis Pub client error', err); - reject(err); - }); - }); + await pubClient.connect(); + await subClient.connect(); + await pubClient.ping(); + await subClient.ping(); + + this.logger.log('Redis Pub/Sub clients connected'); // Handle subsequent errors gracefully pubClient.on('error', (err) => { @@ -73,21 +103,53 @@ export class RedisIoAdapter extends IoAdapter { }); this.adapterConstructor = createAdapter(pubClient, subClient); + this.pubClient = pubClient; + this.subClient = subClient; this.logger.log('Redis adapter initialized successfully'); } catch (error) { + await this.close(); this.logger.error('Failed to initialize Redis adapter', error); - // We don't throw here to allow the app to start without Redis if connection fails, - // though functionality will be degraded if multiple instances are running. + + if (redisRequired) { + throw error; + } } } createIOServer(port: number, options?: ServerOptions): any { + const cors = { + origin: true, + credentials: true, + }; + // eslint-disable-next-line @typescript-eslint/no-unsafe-assignment - const server = super.createIOServer(port, options); + const server = super.createIOServer(port, { + ...(options ?? {}), + cors, + }); if (this.adapterConstructor) { // eslint-disable-next-line @typescript-eslint/no-unsafe-call, @typescript-eslint/no-unsafe-member-access server.adapter(this.adapterConstructor); } return server; } + + async close(): Promise { + const clients = [this.pubClient, this.subClient].filter( + (client): client is Redis => client !== null, + ); + + await Promise.all( + clients.map(async (client) => { + try { + await client.quit(); + } catch { + client.disconnect(); + } + }), + ); + + this.pubClient = null; + this.subClient = null; + } } diff --git a/src/ws/state/connection/handler.ts b/src/ws/state/connection/handler.ts index 450eabf..706759b 100644 --- a/src/ws/state/connection/handler.ts +++ b/src/ws/state/connection/handler.ts @@ -6,13 +6,11 @@ import { PrismaService } from '../../../database/prisma.service'; import { UserSocketService } from '../user-socket.service'; import { WsNotificationService } from '../ws-notification.service'; import { WS_EVENT } from '../ws-events'; -import { UsersService } from '../../../users/users.service'; export class ConnectionHandler { constructor( private readonly jwtVerificationService: JwtVerificationService, private readonly prisma: PrismaService, - private readonly usersService: UsersService, private readonly userSocketService: UserSocketService, private readonly wsNotificationService: WsNotificationService, private readonly logger: Logger, @@ -94,42 +92,42 @@ export class ConnectionHandler { this.logger.log( `WebSocket authenticated via initialize fallback (Pending Init): ${payload.sub}`, ); - - this.logger.log( - `WebSocket authenticated via initialize fallback (Pending Init): ${payload.sub}`, - ); } if (!userTokenData) { throw new WsException('Unauthorized: No user data found'); } - const user = await this.usersService.findOne(userTokenData.userId); - - // 2. Register socket mapping (Redis Write) - await this.userSocketService.setSocket(user.id, client.id); - client.data.userId = user.id; - - // 3. Fetch initial state (DB Read) - const [userWithDoll, friends] = await Promise.all([ + // 2. Fetch initial state (DB Read) + const [userState, friends] = await Promise.all([ this.prisma.user.findUnique({ - where: { id: user.id }, - select: { activeDollId: true }, + where: { id: userTokenData.userId }, + select: { id: true, name: true, username: true, activeDollId: true }, }), this.prisma.friendship.findMany({ - where: { userId: user.id }, + where: { userId: userTokenData.userId }, select: { friendId: true }, }), ]); - client.data.activeDollId = userWithDoll?.activeDollId || null; - client.data.friends = new Set(friends.map((f) => f.friendId)); + if (!userState) { + throw new WsException('Unauthorized: No user data found'); + } - this.logger.log(`Client initialized: ${user.id} (${client.id})`); + // 3. Register socket mapping (Redis Write) + await this.userSocketService.setSocket(userState.id, client.id); + client.data.userId = userState.id; + + client.data.activeDollId = userState.activeDollId || null; + client.data.friends = new Set(friends.map((f) => f.friendId)); + client.data.senderName = userState.name || userState.username; + client.data.senderNameCachedAt = Date.now(); + + this.logger.log(`Client initialized: ${userState.id} (${client.id})`); // 4. Notify client client.emit(WS_EVENT.INITIALIZED, { - userId: user.id, + userId: userState.id, activeDollId: client.data.activeDollId, }); } catch (error) { @@ -157,7 +155,9 @@ export class ConnectionHandler { // Notify friends that this user has disconnected const friends = client.data.friends; if (friends) { - const friendIds = Array.from(friends); + const friendIds = Array.from(friends).filter( + (friendId): friendId is string => typeof friendId === 'string', + ); const friendSockets = await this.userSocketService.getFriendsSockets(friendIds); @@ -179,9 +179,5 @@ export class ConnectionHandler { this.logger.log( `Client id: ${client.id} disconnected (user: ${user?.userId || 'unknown'})`, ); - - this.logger.log( - `Client id: ${client.id} disconnected (user: ${user?.userId || 'unknown'})`, - ); } } diff --git a/src/ws/state/state.gateway.ts b/src/ws/state/state.gateway.ts index 102a890..a541a44 100644 --- a/src/ws/state/state.gateway.ts +++ b/src/ws/state/state.gateway.ts @@ -1,4 +1,4 @@ -import { Logger, Inject } from '@nestjs/common'; +import { Logger, Inject, OnModuleDestroy } from '@nestjs/common'; import { OnGatewayConnection, OnGatewayDisconnect, @@ -22,7 +22,6 @@ import { PrismaService } from '../../database/prisma.service'; import { UserSocketService } from './user-socket.service'; import { WsNotificationService } from './ws-notification.service'; import { WS_EVENT, REDIS_CHANNEL } from './ws-events'; -import { UsersService } from '../../users/users.service'; import { ConnectionHandler } from './connection/handler'; import { CursorHandler } from './cursor/handler'; import { StatusHandler } from './status/handler'; @@ -31,14 +30,13 @@ import { RedisHandler } from './utils/redis-handler'; import { Broadcaster } from './utils/broadcasting'; import { Throttler } from './utils/throttling'; -@WebSocketGateway({ - cors: { - origin: true, - credentials: true, - }, -}) +@WebSocketGateway() export class StateGateway - implements OnGatewayInit, OnGatewayConnection, OnGatewayDisconnect + implements + OnGatewayInit, + OnGatewayConnection, + OnGatewayDisconnect, + OnModuleDestroy { private readonly logger = new Logger(StateGateway.name); @@ -55,7 +53,6 @@ export class StateGateway constructor( private readonly jwtVerificationService: JwtVerificationService, private readonly prisma: PrismaService, - private readonly usersService: UsersService, private readonly userSocketService: UserSocketService, private readonly wsNotificationService: WsNotificationService, @Inject(REDIS_CLIENT) private readonly redisClient: Redis | null, @@ -70,7 +67,6 @@ export class StateGateway this.connectionHandler = new ConnectionHandler( this.jwtVerificationService, this.prisma, - this.usersService, this.userSocketService, this.wsNotificationService, this.logger, @@ -156,11 +152,16 @@ export class StateGateway await this.statusHandler.handleClientReportUserStatus(client, data); } - @SubscribeMessage(WS_EVENT.CLIENT_SEND_INTERACTION) async handleSendInteraction( client: AuthenticatedSocket, data: SendInteractionDto, ) { await this.interactionHandler.handleSendInteraction(client, data); } + + onModuleDestroy() { + if (this.redisSubscriber) { + this.redisSubscriber.removeAllListeners('message'); + } + } }