feat(ws): harden Redis socket adapter lifecycle
This commit is contained in:
2
src/types/socket.d.ts
vendored
2
src/types/socket.d.ts
vendored
@@ -11,5 +11,7 @@ export type AuthenticatedSocket = BaseSocket<
|
|||||||
userId?: string;
|
userId?: string;
|
||||||
activeDollId?: string | null;
|
activeDollId?: string | null;
|
||||||
friends?: Set<string>; // Set of friend user IDs
|
friends?: Set<string>; // Set of friend user IDs
|
||||||
|
senderName?: string;
|
||||||
|
senderNameCachedAt?: number;
|
||||||
}
|
}
|
||||||
>;
|
>;
|
||||||
|
|||||||
@@ -4,10 +4,18 @@ import { createAdapter } from '@socket.io/redis-adapter';
|
|||||||
import Redis from 'ioredis';
|
import Redis from 'ioredis';
|
||||||
import { ConfigService } from '@nestjs/config';
|
import { ConfigService } from '@nestjs/config';
|
||||||
import { INestApplicationContext, Logger } from '@nestjs/common';
|
import { INestApplicationContext, Logger } from '@nestjs/common';
|
||||||
|
import {
|
||||||
|
parsePositiveInteger,
|
||||||
|
parseRedisRequired,
|
||||||
|
} from '../common/config/env.utils';
|
||||||
|
|
||||||
|
const DEFAULT_REDIS_STARTUP_RETRIES = 10;
|
||||||
|
|
||||||
export class RedisIoAdapter extends IoAdapter {
|
export class RedisIoAdapter extends IoAdapter {
|
||||||
private adapterConstructor: ReturnType<typeof createAdapter>;
|
private adapterConstructor: ReturnType<typeof createAdapter>;
|
||||||
private readonly logger = new Logger(RedisIoAdapter.name);
|
private readonly logger = new Logger(RedisIoAdapter.name);
|
||||||
|
private pubClient: Redis | null = null;
|
||||||
|
private subClient: Redis | null = null;
|
||||||
|
|
||||||
constructor(
|
constructor(
|
||||||
private app: INestApplicationContext,
|
private app: INestApplicationContext,
|
||||||
@@ -18,41 +26,63 @@ export class RedisIoAdapter extends IoAdapter {
|
|||||||
|
|
||||||
async connectToRedis(): Promise<void> {
|
async connectToRedis(): Promise<void> {
|
||||||
const host = this.configService.get<string>('REDIS_HOST');
|
const host = this.configService.get<string>('REDIS_HOST');
|
||||||
const port = this.configService.get<number>('REDIS_PORT');
|
const port = parsePositiveInteger(
|
||||||
|
this.configService.get<string>('REDIS_PORT'),
|
||||||
|
6379,
|
||||||
|
);
|
||||||
const password = this.configService.get<string>('REDIS_PASSWORD');
|
const password = this.configService.get<string>('REDIS_PASSWORD');
|
||||||
|
const startupRetries = parsePositiveInteger(
|
||||||
|
this.configService.get<string>('REDIS_STARTUP_RETRIES'),
|
||||||
|
DEFAULT_REDIS_STARTUP_RETRIES,
|
||||||
|
);
|
||||||
|
const redisRequired = parseRedisRequired({
|
||||||
|
nodeEnv: this.configService.get<string>('NODE_ENV'),
|
||||||
|
redisRequired: this.configService.get<string>('REDIS_REQUIRED'),
|
||||||
|
});
|
||||||
|
|
||||||
// Only set up Redis adapter if host is configured
|
|
||||||
if (!host) {
|
if (!host) {
|
||||||
|
if (redisRequired) {
|
||||||
|
throw new Error(
|
||||||
|
'REDIS_REQUIRED is enabled but REDIS_HOST is not configured',
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
this.logger.log('Redis adapter disabled (REDIS_HOST not set)');
|
this.logger.log('Redis adapter disabled (REDIS_HOST not set)');
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
this.logger.log(`Connecting Redis adapter to ${host}:${port || 6379}`);
|
this.logger.log(`Connecting Redis adapter to ${host}:${port}`);
|
||||||
|
|
||||||
try {
|
try {
|
||||||
|
const connectTimeout = parsePositiveInteger(
|
||||||
|
this.configService.get<string>('REDIS_CONNECT_TIMEOUT_MS'),
|
||||||
|
5000,
|
||||||
|
);
|
||||||
const pubClient = new Redis({
|
const pubClient = new Redis({
|
||||||
host,
|
host,
|
||||||
port: port || 6379,
|
port,
|
||||||
password: password,
|
password,
|
||||||
|
lazyConnect: true,
|
||||||
|
connectTimeout,
|
||||||
|
maxRetriesPerRequest: 1,
|
||||||
|
enableOfflineQueue: false,
|
||||||
retryStrategy(times) {
|
retryStrategy(times) {
|
||||||
// Retry connecting but don't crash if Redis is temporarily down during startup
|
if (times > startupRetries) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
return Math.min(times * 50, 2000);
|
return Math.min(times * 50, 2000);
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
|
|
||||||
const subClient = pubClient.duplicate();
|
const subClient = pubClient.duplicate();
|
||||||
|
|
||||||
// Wait for connection to ensure it's valid
|
await pubClient.connect();
|
||||||
await new Promise<void>((resolve, reject) => {
|
await subClient.connect();
|
||||||
pubClient.once('connect', () => {
|
await pubClient.ping();
|
||||||
this.logger.log('Redis Pub client connected');
|
await subClient.ping();
|
||||||
resolve();
|
|
||||||
});
|
this.logger.log('Redis Pub/Sub clients connected');
|
||||||
pubClient.once('error', (err) => {
|
|
||||||
this.logger.error('Redis Pub client error', err);
|
|
||||||
reject(err);
|
|
||||||
});
|
|
||||||
});
|
|
||||||
|
|
||||||
// Handle subsequent errors gracefully
|
// Handle subsequent errors gracefully
|
||||||
pubClient.on('error', (err) => {
|
pubClient.on('error', (err) => {
|
||||||
@@ -73,21 +103,53 @@ export class RedisIoAdapter extends IoAdapter {
|
|||||||
});
|
});
|
||||||
|
|
||||||
this.adapterConstructor = createAdapter(pubClient, subClient);
|
this.adapterConstructor = createAdapter(pubClient, subClient);
|
||||||
|
this.pubClient = pubClient;
|
||||||
|
this.subClient = subClient;
|
||||||
this.logger.log('Redis adapter initialized successfully');
|
this.logger.log('Redis adapter initialized successfully');
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
|
await this.close();
|
||||||
this.logger.error('Failed to initialize Redis adapter', error);
|
this.logger.error('Failed to initialize Redis adapter', error);
|
||||||
// We don't throw here to allow the app to start without Redis if connection fails,
|
|
||||||
// though functionality will be degraded if multiple instances are running.
|
if (redisRequired) {
|
||||||
|
throw error;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
createIOServer(port: number, options?: ServerOptions): any {
|
createIOServer(port: number, options?: ServerOptions): any {
|
||||||
|
const cors = {
|
||||||
|
origin: true,
|
||||||
|
credentials: true,
|
||||||
|
};
|
||||||
|
|
||||||
// eslint-disable-next-line @typescript-eslint/no-unsafe-assignment
|
// eslint-disable-next-line @typescript-eslint/no-unsafe-assignment
|
||||||
const server = super.createIOServer(port, options);
|
const server = super.createIOServer(port, {
|
||||||
|
...(options ?? {}),
|
||||||
|
cors,
|
||||||
|
});
|
||||||
if (this.adapterConstructor) {
|
if (this.adapterConstructor) {
|
||||||
// eslint-disable-next-line @typescript-eslint/no-unsafe-call, @typescript-eslint/no-unsafe-member-access
|
// eslint-disable-next-line @typescript-eslint/no-unsafe-call, @typescript-eslint/no-unsafe-member-access
|
||||||
server.adapter(this.adapterConstructor);
|
server.adapter(this.adapterConstructor);
|
||||||
}
|
}
|
||||||
return server;
|
return server;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async close(): Promise<void> {
|
||||||
|
const clients = [this.pubClient, this.subClient].filter(
|
||||||
|
(client): client is Redis => client !== null,
|
||||||
|
);
|
||||||
|
|
||||||
|
await Promise.all(
|
||||||
|
clients.map(async (client) => {
|
||||||
|
try {
|
||||||
|
await client.quit();
|
||||||
|
} catch {
|
||||||
|
client.disconnect();
|
||||||
|
}
|
||||||
|
}),
|
||||||
|
);
|
||||||
|
|
||||||
|
this.pubClient = null;
|
||||||
|
this.subClient = null;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -6,13 +6,11 @@ import { PrismaService } from '../../../database/prisma.service';
|
|||||||
import { UserSocketService } from '../user-socket.service';
|
import { UserSocketService } from '../user-socket.service';
|
||||||
import { WsNotificationService } from '../ws-notification.service';
|
import { WsNotificationService } from '../ws-notification.service';
|
||||||
import { WS_EVENT } from '../ws-events';
|
import { WS_EVENT } from '../ws-events';
|
||||||
import { UsersService } from '../../../users/users.service';
|
|
||||||
|
|
||||||
export class ConnectionHandler {
|
export class ConnectionHandler {
|
||||||
constructor(
|
constructor(
|
||||||
private readonly jwtVerificationService: JwtVerificationService,
|
private readonly jwtVerificationService: JwtVerificationService,
|
||||||
private readonly prisma: PrismaService,
|
private readonly prisma: PrismaService,
|
||||||
private readonly usersService: UsersService,
|
|
||||||
private readonly userSocketService: UserSocketService,
|
private readonly userSocketService: UserSocketService,
|
||||||
private readonly wsNotificationService: WsNotificationService,
|
private readonly wsNotificationService: WsNotificationService,
|
||||||
private readonly logger: Logger,
|
private readonly logger: Logger,
|
||||||
@@ -94,42 +92,42 @@ export class ConnectionHandler {
|
|||||||
this.logger.log(
|
this.logger.log(
|
||||||
`WebSocket authenticated via initialize fallback (Pending Init): ${payload.sub}`,
|
`WebSocket authenticated via initialize fallback (Pending Init): ${payload.sub}`,
|
||||||
);
|
);
|
||||||
|
|
||||||
this.logger.log(
|
|
||||||
`WebSocket authenticated via initialize fallback (Pending Init): ${payload.sub}`,
|
|
||||||
);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!userTokenData) {
|
if (!userTokenData) {
|
||||||
throw new WsException('Unauthorized: No user data found');
|
throw new WsException('Unauthorized: No user data found');
|
||||||
}
|
}
|
||||||
|
|
||||||
const user = await this.usersService.findOne(userTokenData.userId);
|
// 2. Fetch initial state (DB Read)
|
||||||
|
const [userState, friends] = await Promise.all([
|
||||||
// 2. Register socket mapping (Redis Write)
|
|
||||||
await this.userSocketService.setSocket(user.id, client.id);
|
|
||||||
client.data.userId = user.id;
|
|
||||||
|
|
||||||
// 3. Fetch initial state (DB Read)
|
|
||||||
const [userWithDoll, friends] = await Promise.all([
|
|
||||||
this.prisma.user.findUnique({
|
this.prisma.user.findUnique({
|
||||||
where: { id: user.id },
|
where: { id: userTokenData.userId },
|
||||||
select: { activeDollId: true },
|
select: { id: true, name: true, username: true, activeDollId: true },
|
||||||
}),
|
}),
|
||||||
this.prisma.friendship.findMany({
|
this.prisma.friendship.findMany({
|
||||||
where: { userId: user.id },
|
where: { userId: userTokenData.userId },
|
||||||
select: { friendId: true },
|
select: { friendId: true },
|
||||||
}),
|
}),
|
||||||
]);
|
]);
|
||||||
|
|
||||||
client.data.activeDollId = userWithDoll?.activeDollId || null;
|
if (!userState) {
|
||||||
client.data.friends = new Set(friends.map((f) => f.friendId));
|
throw new WsException('Unauthorized: No user data found');
|
||||||
|
}
|
||||||
|
|
||||||
this.logger.log(`Client initialized: ${user.id} (${client.id})`);
|
// 3. Register socket mapping (Redis Write)
|
||||||
|
await this.userSocketService.setSocket(userState.id, client.id);
|
||||||
|
client.data.userId = userState.id;
|
||||||
|
|
||||||
|
client.data.activeDollId = userState.activeDollId || null;
|
||||||
|
client.data.friends = new Set(friends.map((f) => f.friendId));
|
||||||
|
client.data.senderName = userState.name || userState.username;
|
||||||
|
client.data.senderNameCachedAt = Date.now();
|
||||||
|
|
||||||
|
this.logger.log(`Client initialized: ${userState.id} (${client.id})`);
|
||||||
|
|
||||||
// 4. Notify client
|
// 4. Notify client
|
||||||
client.emit(WS_EVENT.INITIALIZED, {
|
client.emit(WS_EVENT.INITIALIZED, {
|
||||||
userId: user.id,
|
userId: userState.id,
|
||||||
activeDollId: client.data.activeDollId,
|
activeDollId: client.data.activeDollId,
|
||||||
});
|
});
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
@@ -157,7 +155,9 @@ export class ConnectionHandler {
|
|||||||
// Notify friends that this user has disconnected
|
// Notify friends that this user has disconnected
|
||||||
const friends = client.data.friends;
|
const friends = client.data.friends;
|
||||||
if (friends) {
|
if (friends) {
|
||||||
const friendIds = Array.from(friends);
|
const friendIds = Array.from(friends).filter(
|
||||||
|
(friendId): friendId is string => typeof friendId === 'string',
|
||||||
|
);
|
||||||
const friendSockets =
|
const friendSockets =
|
||||||
await this.userSocketService.getFriendsSockets(friendIds);
|
await this.userSocketService.getFriendsSockets(friendIds);
|
||||||
|
|
||||||
@@ -179,9 +179,5 @@ export class ConnectionHandler {
|
|||||||
this.logger.log(
|
this.logger.log(
|
||||||
`Client id: ${client.id} disconnected (user: ${user?.userId || 'unknown'})`,
|
`Client id: ${client.id} disconnected (user: ${user?.userId || 'unknown'})`,
|
||||||
);
|
);
|
||||||
|
|
||||||
this.logger.log(
|
|
||||||
`Client id: ${client.id} disconnected (user: ${user?.userId || 'unknown'})`,
|
|
||||||
);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
import { Logger, Inject } from '@nestjs/common';
|
import { Logger, Inject, OnModuleDestroy } from '@nestjs/common';
|
||||||
import {
|
import {
|
||||||
OnGatewayConnection,
|
OnGatewayConnection,
|
||||||
OnGatewayDisconnect,
|
OnGatewayDisconnect,
|
||||||
@@ -22,7 +22,6 @@ import { PrismaService } from '../../database/prisma.service';
|
|||||||
import { UserSocketService } from './user-socket.service';
|
import { UserSocketService } from './user-socket.service';
|
||||||
import { WsNotificationService } from './ws-notification.service';
|
import { WsNotificationService } from './ws-notification.service';
|
||||||
import { WS_EVENT, REDIS_CHANNEL } from './ws-events';
|
import { WS_EVENT, REDIS_CHANNEL } from './ws-events';
|
||||||
import { UsersService } from '../../users/users.service';
|
|
||||||
import { ConnectionHandler } from './connection/handler';
|
import { ConnectionHandler } from './connection/handler';
|
||||||
import { CursorHandler } from './cursor/handler';
|
import { CursorHandler } from './cursor/handler';
|
||||||
import { StatusHandler } from './status/handler';
|
import { StatusHandler } from './status/handler';
|
||||||
@@ -31,14 +30,13 @@ import { RedisHandler } from './utils/redis-handler';
|
|||||||
import { Broadcaster } from './utils/broadcasting';
|
import { Broadcaster } from './utils/broadcasting';
|
||||||
import { Throttler } from './utils/throttling';
|
import { Throttler } from './utils/throttling';
|
||||||
|
|
||||||
@WebSocketGateway({
|
@WebSocketGateway()
|
||||||
cors: {
|
|
||||||
origin: true,
|
|
||||||
credentials: true,
|
|
||||||
},
|
|
||||||
})
|
|
||||||
export class StateGateway
|
export class StateGateway
|
||||||
implements OnGatewayInit, OnGatewayConnection, OnGatewayDisconnect
|
implements
|
||||||
|
OnGatewayInit,
|
||||||
|
OnGatewayConnection,
|
||||||
|
OnGatewayDisconnect,
|
||||||
|
OnModuleDestroy
|
||||||
{
|
{
|
||||||
private readonly logger = new Logger(StateGateway.name);
|
private readonly logger = new Logger(StateGateway.name);
|
||||||
|
|
||||||
@@ -55,7 +53,6 @@ export class StateGateway
|
|||||||
constructor(
|
constructor(
|
||||||
private readonly jwtVerificationService: JwtVerificationService,
|
private readonly jwtVerificationService: JwtVerificationService,
|
||||||
private readonly prisma: PrismaService,
|
private readonly prisma: PrismaService,
|
||||||
private readonly usersService: UsersService,
|
|
||||||
private readonly userSocketService: UserSocketService,
|
private readonly userSocketService: UserSocketService,
|
||||||
private readonly wsNotificationService: WsNotificationService,
|
private readonly wsNotificationService: WsNotificationService,
|
||||||
@Inject(REDIS_CLIENT) private readonly redisClient: Redis | null,
|
@Inject(REDIS_CLIENT) private readonly redisClient: Redis | null,
|
||||||
@@ -70,7 +67,6 @@ export class StateGateway
|
|||||||
this.connectionHandler = new ConnectionHandler(
|
this.connectionHandler = new ConnectionHandler(
|
||||||
this.jwtVerificationService,
|
this.jwtVerificationService,
|
||||||
this.prisma,
|
this.prisma,
|
||||||
this.usersService,
|
|
||||||
this.userSocketService,
|
this.userSocketService,
|
||||||
this.wsNotificationService,
|
this.wsNotificationService,
|
||||||
this.logger,
|
this.logger,
|
||||||
@@ -156,11 +152,16 @@ export class StateGateway
|
|||||||
await this.statusHandler.handleClientReportUserStatus(client, data);
|
await this.statusHandler.handleClientReportUserStatus(client, data);
|
||||||
}
|
}
|
||||||
|
|
||||||
@SubscribeMessage(WS_EVENT.CLIENT_SEND_INTERACTION)
|
|
||||||
async handleSendInteraction(
|
async handleSendInteraction(
|
||||||
client: AuthenticatedSocket,
|
client: AuthenticatedSocket,
|
||||||
data: SendInteractionDto,
|
data: SendInteractionDto,
|
||||||
) {
|
) {
|
||||||
await this.interactionHandler.handleSendInteraction(client, data);
|
await this.interactionHandler.handleSendInteraction(client, data);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
onModuleDestroy() {
|
||||||
|
if (this.redisSubscriber) {
|
||||||
|
this.redisSubscriber.removeAllListeners('message');
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user