production hardening
This commit is contained in:
@@ -4,10 +4,18 @@ import { createAdapter } from '@socket.io/redis-adapter';
|
||||
import Redis from 'ioredis';
|
||||
import { ConfigService } from '@nestjs/config';
|
||||
import { INestApplicationContext, Logger } from '@nestjs/common';
|
||||
import {
|
||||
parsePositiveInteger,
|
||||
parseRedisRequired,
|
||||
} from '../common/config/env.utils';
|
||||
|
||||
const DEFAULT_REDIS_STARTUP_RETRIES = 10;
|
||||
|
||||
export class RedisIoAdapter extends IoAdapter {
|
||||
private adapterConstructor: ReturnType<typeof createAdapter>;
|
||||
private readonly logger = new Logger(RedisIoAdapter.name);
|
||||
private pubClient: Redis | null = null;
|
||||
private subClient: Redis | null = null;
|
||||
|
||||
constructor(
|
||||
private app: INestApplicationContext,
|
||||
@@ -18,41 +26,63 @@ export class RedisIoAdapter extends IoAdapter {
|
||||
|
||||
async connectToRedis(): Promise<void> {
|
||||
const host = this.configService.get<string>('REDIS_HOST');
|
||||
const port = this.configService.get<number>('REDIS_PORT');
|
||||
const port = parsePositiveInteger(
|
||||
this.configService.get<string>('REDIS_PORT'),
|
||||
6379,
|
||||
);
|
||||
const password = this.configService.get<string>('REDIS_PASSWORD');
|
||||
const startupRetries = parsePositiveInteger(
|
||||
this.configService.get<string>('REDIS_STARTUP_RETRIES'),
|
||||
DEFAULT_REDIS_STARTUP_RETRIES,
|
||||
);
|
||||
const redisRequired = parseRedisRequired({
|
||||
nodeEnv: this.configService.get<string>('NODE_ENV'),
|
||||
redisRequired: this.configService.get<string>('REDIS_REQUIRED'),
|
||||
});
|
||||
|
||||
// Only set up Redis adapter if host is configured
|
||||
if (!host) {
|
||||
if (redisRequired) {
|
||||
throw new Error(
|
||||
'REDIS_REQUIRED is enabled but REDIS_HOST is not configured',
|
||||
);
|
||||
}
|
||||
|
||||
this.logger.log('Redis adapter disabled (REDIS_HOST not set)');
|
||||
return;
|
||||
}
|
||||
|
||||
this.logger.log(`Connecting Redis adapter to ${host}:${port || 6379}`);
|
||||
this.logger.log(`Connecting Redis adapter to ${host}:${port}`);
|
||||
|
||||
try {
|
||||
const connectTimeout = parsePositiveInteger(
|
||||
this.configService.get<string>('REDIS_CONNECT_TIMEOUT_MS'),
|
||||
5000,
|
||||
);
|
||||
const pubClient = new Redis({
|
||||
host,
|
||||
port: port || 6379,
|
||||
password: password,
|
||||
port,
|
||||
password,
|
||||
lazyConnect: true,
|
||||
connectTimeout,
|
||||
maxRetriesPerRequest: 1,
|
||||
enableOfflineQueue: false,
|
||||
retryStrategy(times) {
|
||||
// Retry connecting but don't crash if Redis is temporarily down during startup
|
||||
if (times > startupRetries) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return Math.min(times * 50, 2000);
|
||||
},
|
||||
});
|
||||
|
||||
const subClient = pubClient.duplicate();
|
||||
|
||||
// Wait for connection to ensure it's valid
|
||||
await new Promise<void>((resolve, reject) => {
|
||||
pubClient.once('connect', () => {
|
||||
this.logger.log('Redis Pub client connected');
|
||||
resolve();
|
||||
});
|
||||
pubClient.once('error', (err) => {
|
||||
this.logger.error('Redis Pub client error', err);
|
||||
reject(err);
|
||||
});
|
||||
});
|
||||
await pubClient.connect();
|
||||
await subClient.connect();
|
||||
await pubClient.ping();
|
||||
await subClient.ping();
|
||||
|
||||
this.logger.log('Redis Pub/Sub clients connected');
|
||||
|
||||
// Handle subsequent errors gracefully
|
||||
pubClient.on('error', (err) => {
|
||||
@@ -73,21 +103,53 @@ export class RedisIoAdapter extends IoAdapter {
|
||||
});
|
||||
|
||||
this.adapterConstructor = createAdapter(pubClient, subClient);
|
||||
this.pubClient = pubClient;
|
||||
this.subClient = subClient;
|
||||
this.logger.log('Redis adapter initialized successfully');
|
||||
} catch (error) {
|
||||
await this.close();
|
||||
this.logger.error('Failed to initialize Redis adapter', error);
|
||||
// We don't throw here to allow the app to start without Redis if connection fails,
|
||||
// though functionality will be degraded if multiple instances are running.
|
||||
|
||||
if (redisRequired) {
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
createIOServer(port: number, options?: ServerOptions): any {
|
||||
const cors = {
|
||||
origin: true,
|
||||
credentials: true,
|
||||
};
|
||||
|
||||
// eslint-disable-next-line @typescript-eslint/no-unsafe-assignment
|
||||
const server = super.createIOServer(port, options);
|
||||
const server = super.createIOServer(port, {
|
||||
...(options ?? {}),
|
||||
cors,
|
||||
});
|
||||
if (this.adapterConstructor) {
|
||||
// eslint-disable-next-line @typescript-eslint/no-unsafe-call, @typescript-eslint/no-unsafe-member-access
|
||||
server.adapter(this.adapterConstructor);
|
||||
}
|
||||
return server;
|
||||
}
|
||||
|
||||
async close(): Promise<void> {
|
||||
const clients = [this.pubClient, this.subClient].filter(
|
||||
(client): client is Redis => client !== null,
|
||||
);
|
||||
|
||||
await Promise.all(
|
||||
clients.map(async (client) => {
|
||||
try {
|
||||
await client.quit();
|
||||
} catch {
|
||||
client.disconnect();
|
||||
}
|
||||
}),
|
||||
);
|
||||
|
||||
this.pubClient = null;
|
||||
this.subClient = null;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -6,13 +6,11 @@ import { PrismaService } from '../../../database/prisma.service';
|
||||
import { UserSocketService } from '../user-socket.service';
|
||||
import { WsNotificationService } from '../ws-notification.service';
|
||||
import { WS_EVENT } from '../ws-events';
|
||||
import { UsersService } from '../../../users/users.service';
|
||||
|
||||
export class ConnectionHandler {
|
||||
constructor(
|
||||
private readonly jwtVerificationService: JwtVerificationService,
|
||||
private readonly prisma: PrismaService,
|
||||
private readonly usersService: UsersService,
|
||||
private readonly userSocketService: UserSocketService,
|
||||
private readonly wsNotificationService: WsNotificationService,
|
||||
private readonly logger: Logger,
|
||||
@@ -43,6 +41,7 @@ export class ConnectionHandler {
|
||||
// Initialize defaults
|
||||
client.data.activeDollId = null;
|
||||
client.data.friends = new Set();
|
||||
client.data.senderName = undefined;
|
||||
// userId is not set yet, it will be set in handleClientInitialize
|
||||
|
||||
this.logger.log(`WebSocket authenticated (Pending Init): ${payload.sub}`);
|
||||
@@ -94,42 +93,42 @@ export class ConnectionHandler {
|
||||
this.logger.log(
|
||||
`WebSocket authenticated via initialize fallback (Pending Init): ${payload.sub}`,
|
||||
);
|
||||
|
||||
this.logger.log(
|
||||
`WebSocket authenticated via initialize fallback (Pending Init): ${payload.sub}`,
|
||||
);
|
||||
}
|
||||
|
||||
if (!userTokenData) {
|
||||
throw new WsException('Unauthorized: No user data found');
|
||||
}
|
||||
|
||||
const user = await this.usersService.findOne(userTokenData.userId);
|
||||
|
||||
// 2. Register socket mapping (Redis Write)
|
||||
await this.userSocketService.setSocket(user.id, client.id);
|
||||
client.data.userId = user.id;
|
||||
|
||||
// 3. Fetch initial state (DB Read)
|
||||
const [userWithDoll, friends] = await Promise.all([
|
||||
// 2. Fetch initial state (DB Read)
|
||||
const [userState, friends] = await Promise.all([
|
||||
this.prisma.user.findUnique({
|
||||
where: { id: user.id },
|
||||
select: { activeDollId: true },
|
||||
where: { id: userTokenData.userId },
|
||||
select: { id: true, name: true, username: true, activeDollId: true },
|
||||
}),
|
||||
this.prisma.friendship.findMany({
|
||||
where: { userId: user.id },
|
||||
where: { userId: userTokenData.userId },
|
||||
select: { friendId: true },
|
||||
}),
|
||||
]);
|
||||
|
||||
client.data.activeDollId = userWithDoll?.activeDollId || null;
|
||||
client.data.friends = new Set(friends.map((f) => f.friendId));
|
||||
if (!userState) {
|
||||
throw new WsException('Unauthorized: No user data found');
|
||||
}
|
||||
|
||||
this.logger.log(`Client initialized: ${user.id} (${client.id})`);
|
||||
// 3. Register socket mapping (Redis Write)
|
||||
await this.userSocketService.setSocket(userState.id, client.id);
|
||||
client.data.userId = userState.id;
|
||||
|
||||
client.data.activeDollId = userState.activeDollId || null;
|
||||
client.data.friends = new Set(friends.map((f) => f.friendId));
|
||||
client.data.senderName = userState.name || userState.username;
|
||||
client.data.senderNameCachedAt = Date.now();
|
||||
|
||||
this.logger.log(`Client initialized: ${userState.id} (${client.id})`);
|
||||
|
||||
// 4. Notify client
|
||||
client.emit(WS_EVENT.INITIALIZED, {
|
||||
userId: user.id,
|
||||
userId: userState.id,
|
||||
activeDollId: client.data.activeDollId,
|
||||
});
|
||||
} catch (error) {
|
||||
@@ -157,7 +156,9 @@ export class ConnectionHandler {
|
||||
// Notify friends that this user has disconnected
|
||||
const friends = client.data.friends;
|
||||
if (friends) {
|
||||
const friendIds = Array.from(friends);
|
||||
const friendIds = Array.from(friends).filter(
|
||||
(friendId): friendId is string => typeof friendId === 'string',
|
||||
);
|
||||
const friendSockets =
|
||||
await this.userSocketService.getFriendsSockets(friendIds);
|
||||
|
||||
@@ -179,9 +180,5 @@ export class ConnectionHandler {
|
||||
this.logger.log(
|
||||
`Client id: ${client.id} disconnected (user: ${user?.userId || 'unknown'})`,
|
||||
);
|
||||
|
||||
this.logger.log(
|
||||
`Client id: ${client.id} disconnected (user: ${user?.userId || 'unknown'})`,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,14 +1,16 @@
|
||||
import { Logger } from '@nestjs/common';
|
||||
import { WsException } from '@nestjs/websockets';
|
||||
import type { AuthenticatedSocket } from '../../../types/socket';
|
||||
import { PrismaService } from '../../../database/prisma.service';
|
||||
import { SendInteractionDto } from '../../dto/send-interaction.dto';
|
||||
import { InteractionPayloadDto } from '../../dto/interaction-payload.dto';
|
||||
import { PrismaService } from '../../../database/prisma.service';
|
||||
import { UserSocketService } from '../user-socket.service';
|
||||
import { WsNotificationService } from '../ws-notification.service';
|
||||
import { WS_EVENT } from '../ws-events';
|
||||
import { Validator } from '../utils/validation';
|
||||
|
||||
const SENDER_NAME_CACHE_TTL_MS = 10 * 60 * 1000;
|
||||
|
||||
export class InteractionHandler {
|
||||
private readonly logger = new Logger(InteractionHandler.name);
|
||||
|
||||
@@ -18,6 +20,32 @@ export class InteractionHandler {
|
||||
private readonly wsNotificationService: WsNotificationService,
|
||||
) {}
|
||||
|
||||
private async resolveSenderName(
|
||||
client: AuthenticatedSocket,
|
||||
userId: string,
|
||||
): Promise<string> {
|
||||
const cachedName = client.data.senderName;
|
||||
const cachedAt = client.data.senderNameCachedAt;
|
||||
const cacheIsFresh =
|
||||
cachedName &&
|
||||
typeof cachedAt === 'number' &&
|
||||
Date.now() - cachedAt < SENDER_NAME_CACHE_TTL_MS;
|
||||
|
||||
if (cacheIsFresh) {
|
||||
return cachedName;
|
||||
}
|
||||
|
||||
const sender = await this.prisma.user.findUnique({
|
||||
where: { id: userId },
|
||||
select: { name: true, username: true },
|
||||
});
|
||||
|
||||
const senderName = sender?.name || sender?.username || 'Unknown';
|
||||
client.data.senderName = senderName;
|
||||
client.data.senderNameCachedAt = Date.now();
|
||||
return senderName;
|
||||
}
|
||||
|
||||
async handleSendInteraction(
|
||||
client: AuthenticatedSocket,
|
||||
data: SendInteractionDto,
|
||||
@@ -61,11 +89,7 @@ export class InteractionHandler {
|
||||
}
|
||||
|
||||
// 3. Construct payload
|
||||
const sender = await this.prisma.user.findUnique({
|
||||
where: { id: currentUserId },
|
||||
select: { name: true, username: true },
|
||||
});
|
||||
const senderName = sender?.name || sender?.username || 'Unknown';
|
||||
const senderName = await this.resolveSenderName(client, currentUserId);
|
||||
|
||||
const payload: InteractionPayloadDto = {
|
||||
senderUserId: currentUserId,
|
||||
|
||||
@@ -3,7 +3,6 @@ import { Test, TestingModule } from '@nestjs/testing';
|
||||
import { StateGateway } from './state.gateway';
|
||||
import { AuthenticatedSocket } from '../../types/socket';
|
||||
import { JwtVerificationService } from '../../auth/services/jwt-verification.service';
|
||||
import { UsersService } from '../../users/users.service';
|
||||
import { PrismaService } from '../../database/prisma.service';
|
||||
import { UserSocketService } from './user-socket.service';
|
||||
import { WsNotificationService } from './ws-notification.service';
|
||||
@@ -39,7 +38,6 @@ describe('StateGateway', () => {
|
||||
sockets: { sockets: { size: number; get: jest.Mock } };
|
||||
to: jest.Mock;
|
||||
};
|
||||
let mockUsersService: Partial<UsersService>;
|
||||
let mockJwtVerificationService: Partial<JwtVerificationService>;
|
||||
let mockPrismaService: Partial<PrismaService>;
|
||||
let mockUserSocketService: Partial<UserSocketService>;
|
||||
@@ -67,12 +65,6 @@ describe('StateGateway', () => {
|
||||
}),
|
||||
};
|
||||
|
||||
mockUsersService = {
|
||||
findOne: jest.fn().mockResolvedValue({
|
||||
id: 'user-id',
|
||||
}),
|
||||
};
|
||||
|
||||
mockJwtVerificationService = {
|
||||
extractToken: jest.fn((handshake) => handshake.auth?.token),
|
||||
verifyToken: jest.fn().mockReturnValue({
|
||||
@@ -83,7 +75,12 @@ describe('StateGateway', () => {
|
||||
|
||||
mockPrismaService = {
|
||||
user: {
|
||||
findUnique: jest.fn().mockResolvedValue({ activeDollId: 'doll-123' }),
|
||||
findUnique: jest.fn().mockResolvedValue({
|
||||
id: 'user-id',
|
||||
name: 'Test User',
|
||||
username: 'test-user',
|
||||
activeDollId: 'doll-123',
|
||||
}),
|
||||
} as any,
|
||||
friendship: {
|
||||
findMany: jest.fn().mockResolvedValue([]),
|
||||
@@ -119,7 +116,6 @@ describe('StateGateway', () => {
|
||||
const module: TestingModule = await Test.createTestingModule({
|
||||
providers: [
|
||||
StateGateway,
|
||||
{ provide: UsersService, useValue: mockUsersService },
|
||||
{
|
||||
provide: JwtVerificationService,
|
||||
useValue: mockJwtVerificationService,
|
||||
@@ -190,7 +186,6 @@ describe('StateGateway', () => {
|
||||
);
|
||||
|
||||
// Should NOT call these anymore in handleConnection
|
||||
expect(mockUsersService.findOne).not.toHaveBeenCalled();
|
||||
expect(mockUserSocketService.setSocket).not.toHaveBeenCalled();
|
||||
|
||||
// Should set data on client
|
||||
@@ -244,6 +239,9 @@ describe('StateGateway', () => {
|
||||
|
||||
// Mock Prisma responses
|
||||
(mockPrismaService.user!.findUnique as jest.Mock).mockResolvedValue({
|
||||
id: 'user-id',
|
||||
name: 'Test User',
|
||||
username: 'test-user',
|
||||
activeDollId: 'doll-123',
|
||||
});
|
||||
(mockPrismaService.friendship!.findMany as jest.Mock).mockResolvedValue([
|
||||
@@ -255,32 +253,29 @@ describe('StateGateway', () => {
|
||||
mockClient as unknown as AuthenticatedSocket,
|
||||
);
|
||||
|
||||
// 1. Load User
|
||||
expect(mockUsersService.findOne).toHaveBeenCalledWith('test-sub');
|
||||
|
||||
// 2. Set Socket
|
||||
// 1. Set Socket
|
||||
expect(mockUserSocketService.setSocket).toHaveBeenCalledWith(
|
||||
'user-id',
|
||||
'client1',
|
||||
);
|
||||
|
||||
// 3. Fetch State (DB)
|
||||
// 2. Fetch State (DB)
|
||||
expect(mockPrismaService.user!.findUnique).toHaveBeenCalledWith({
|
||||
where: { id: 'user-id' },
|
||||
select: { activeDollId: true },
|
||||
where: { id: 'test-sub' },
|
||||
select: { id: true, name: true, username: true, activeDollId: true },
|
||||
});
|
||||
expect(mockPrismaService.friendship!.findMany).toHaveBeenCalledWith({
|
||||
where: { userId: 'user-id' },
|
||||
where: { userId: 'test-sub' },
|
||||
select: { friendId: true },
|
||||
});
|
||||
|
||||
// 4. Update Client Data
|
||||
// 3. Update Client Data
|
||||
expect(mockClient.data.userId).toBe('user-id');
|
||||
expect(mockClient.data.activeDollId).toBe('doll-123');
|
||||
expect(mockClient.data.friends).toContain('friend-1');
|
||||
expect(mockClient.data.friends).toContain('friend-2');
|
||||
|
||||
// 5. Emit Initialized
|
||||
// 4. Emit Initialized
|
||||
expect(mockClient.emit).toHaveBeenCalledWith('initialized', {
|
||||
userId: 'user-id',
|
||||
activeDollId: 'doll-123',
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import { Logger, Inject } from '@nestjs/common';
|
||||
import { Logger, Inject, OnModuleDestroy } from '@nestjs/common';
|
||||
import {
|
||||
OnGatewayConnection,
|
||||
OnGatewayDisconnect,
|
||||
@@ -22,7 +22,6 @@ import { PrismaService } from '../../database/prisma.service';
|
||||
import { UserSocketService } from './user-socket.service';
|
||||
import { WsNotificationService } from './ws-notification.service';
|
||||
import { WS_EVENT, REDIS_CHANNEL } from './ws-events';
|
||||
import { UsersService } from '../../users/users.service';
|
||||
import { ConnectionHandler } from './connection/handler';
|
||||
import { CursorHandler } from './cursor/handler';
|
||||
import { StatusHandler } from './status/handler';
|
||||
@@ -31,14 +30,13 @@ import { RedisHandler } from './utils/redis-handler';
|
||||
import { Broadcaster } from './utils/broadcasting';
|
||||
import { Throttler } from './utils/throttling';
|
||||
|
||||
@WebSocketGateway({
|
||||
cors: {
|
||||
origin: true,
|
||||
credentials: true,
|
||||
},
|
||||
})
|
||||
@WebSocketGateway()
|
||||
export class StateGateway
|
||||
implements OnGatewayInit, OnGatewayConnection, OnGatewayDisconnect
|
||||
implements
|
||||
OnGatewayInit,
|
||||
OnGatewayConnection,
|
||||
OnGatewayDisconnect,
|
||||
OnModuleDestroy
|
||||
{
|
||||
private readonly logger = new Logger(StateGateway.name);
|
||||
|
||||
@@ -55,7 +53,6 @@ export class StateGateway
|
||||
constructor(
|
||||
private readonly jwtVerificationService: JwtVerificationService,
|
||||
private readonly prisma: PrismaService,
|
||||
private readonly usersService: UsersService,
|
||||
private readonly userSocketService: UserSocketService,
|
||||
private readonly wsNotificationService: WsNotificationService,
|
||||
@Inject(REDIS_CLIENT) private readonly redisClient: Redis | null,
|
||||
@@ -70,7 +67,6 @@ export class StateGateway
|
||||
this.connectionHandler = new ConnectionHandler(
|
||||
this.jwtVerificationService,
|
||||
this.prisma,
|
||||
this.usersService,
|
||||
this.userSocketService,
|
||||
this.wsNotificationService,
|
||||
this.logger,
|
||||
@@ -136,6 +132,12 @@ export class StateGateway
|
||||
}
|
||||
}
|
||||
|
||||
onModuleDestroy() {
|
||||
if (this.redisSubscriber) {
|
||||
this.redisSubscriber.removeAllListeners('message');
|
||||
}
|
||||
}
|
||||
|
||||
async isUserOnline(userId: string): Promise<boolean> {
|
||||
return this.userSocketService.isUserOnline(userId);
|
||||
}
|
||||
|
||||
@@ -48,13 +48,27 @@ export class WsNotificationService {
|
||||
action: 'add' | 'delete',
|
||||
) {
|
||||
if (this.redisClient) {
|
||||
await this.redisClient.publish(
|
||||
REDIS_CHANNEL.FRIEND_CACHE_UPDATE,
|
||||
JSON.stringify({ userId, friendId, action }),
|
||||
);
|
||||
} else {
|
||||
// Fallback: update locally
|
||||
try {
|
||||
await this.redisClient.publish(
|
||||
REDIS_CHANNEL.FRIEND_CACHE_UPDATE,
|
||||
JSON.stringify({ userId, friendId, action }),
|
||||
);
|
||||
return;
|
||||
} catch (error) {
|
||||
this.logger.warn(
|
||||
'Redis publish failed for friend cache update; applying local cache update only',
|
||||
error as Error,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
try {
|
||||
await this.updateFriendsCacheLocal(userId, friendId, action);
|
||||
} catch (error) {
|
||||
this.logger.error(
|
||||
'Failed to apply local friend cache update',
|
||||
error as Error,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -89,13 +103,27 @@ export class WsNotificationService {
|
||||
|
||||
async publishActiveDollUpdate(userId: string, dollId: string | null) {
|
||||
if (this.redisClient) {
|
||||
await this.redisClient.publish(
|
||||
REDIS_CHANNEL.ACTIVE_DOLL_UPDATE,
|
||||
JSON.stringify({ userId, dollId }),
|
||||
);
|
||||
} else {
|
||||
// Fallback: update locally
|
||||
try {
|
||||
await this.redisClient.publish(
|
||||
REDIS_CHANNEL.ACTIVE_DOLL_UPDATE,
|
||||
JSON.stringify({ userId, dollId }),
|
||||
);
|
||||
return;
|
||||
} catch (error) {
|
||||
this.logger.warn(
|
||||
'Redis publish failed for active doll update; applying local cache update only',
|
||||
error as Error,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
try {
|
||||
await this.updateActiveDollCache(userId, dollId);
|
||||
} catch (error) {
|
||||
this.logger.error(
|
||||
'Failed to apply local active doll cache update',
|
||||
error as Error,
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user