Minor refinements & tuning to system structure

This commit is contained in:
2025-12-23 15:57:05 +08:00
parent 6c63f2d803
commit d06a58cf93
8 changed files with 394 additions and 87 deletions

View File

@@ -18,10 +18,12 @@ interface MockSocket extends Partial<AuthenticatedSocket> {
picture?: string;
};
userId?: string;
activeDollId?: string | null;
friends?: Set<string>;
};
handshake?: any;
disconnect?: jest.Mock;
emit?: jest.Mock;
}
describe('StateGateway', () => {
@@ -29,6 +31,7 @@ describe('StateGateway', () => {
let mockLoggerLog: jest.SpyInstance;
let mockLoggerDebug: jest.SpyInstance;
let mockLoggerWarn: jest.SpyInstance;
let mockLoggerError: jest.SpyInstance;
let mockServer: {
sockets: { sockets: { size: number; get: jest.Mock } };
to: jest.Mock;
@@ -37,6 +40,8 @@ describe('StateGateway', () => {
let mockJwtVerificationService: Partial<JwtVerificationService>;
let mockPrismaService: Partial<PrismaService>;
let mockUserSocketService: Partial<UserSocketService>;
let mockRedisClient: { publish: jest.Mock };
let mockRedisSubscriber: { subscribe: jest.Mock; on: jest.Mock };
beforeEach(async () => {
mockServer = {
@@ -67,9 +72,12 @@ describe('StateGateway', () => {
};
mockPrismaService = {
user: {
findUnique: jest.fn().mockResolvedValue({ activeDollId: 'doll-123' }),
} as any,
friendship: {
findMany: jest.fn().mockResolvedValue([]),
},
} as any,
};
mockUserSocketService = {
@@ -80,6 +88,15 @@ describe('StateGateway', () => {
getFriendsSockets: jest.fn().mockResolvedValue([]),
};
mockRedisClient = {
publish: jest.fn().mockResolvedValue(1),
};
mockRedisSubscriber = {
subscribe: jest.fn().mockResolvedValue(undefined),
on: jest.fn(),
};
const module: TestingModule = await Test.createTestingModule({
providers: [
StateGateway,
@@ -90,6 +107,8 @@ describe('StateGateway', () => {
},
{ provide: PrismaService, useValue: mockPrismaService },
{ provide: UserSocketService, useValue: mockUserSocketService },
{ provide: 'REDIS_CLIENT', useValue: mockRedisClient },
{ provide: 'REDIS_SUBSCRIBER_CLIENT', useValue: mockRedisSubscriber },
],
}).compile();
@@ -101,6 +120,9 @@ describe('StateGateway', () => {
.spyOn(gateway['logger'], 'debug')
.mockImplementation();
mockLoggerWarn = jest.spyOn(gateway['logger'], 'warn').mockImplementation();
mockLoggerError = jest
.spyOn(gateway['logger'], 'error')
.mockImplementation();
});
afterEach(() => {
@@ -114,20 +136,27 @@ describe('StateGateway', () => {
describe('afterInit', () => {
it('should log initialization message', () => {
gateway.afterInit();
expect(mockLoggerLog).toHaveBeenCalledWith('Initialized');
});
it('should subscribe to redis channel', () => {
expect(mockRedisSubscriber.subscribe).toHaveBeenCalledWith(
'active-doll-update',
expect.any(Function),
);
});
});
describe('handleConnection', () => {
it('should log client connection and sync user when authenticated', async () => {
it('should verify token and set basic user data (but NOT sync DB)', async () => {
const mockClient: MockSocket = {
id: 'client1',
data: { user: { keycloakSub: 'test-sub' } },
data: {},
handshake: {
auth: { token: 'mock-token' },
headers: {},
},
disconnect: jest.fn(),
};
await gateway.handleConnection(
@@ -140,20 +169,21 @@ describe('StateGateway', () => {
expect(mockJwtVerificationService.verifyToken).toHaveBeenCalledWith(
'mock-token',
);
expect(mockAuthService.syncUserFromToken).toHaveBeenCalledWith(
// Should NOT call these anymore in handleConnection
expect(mockAuthService.syncUserFromToken).not.toHaveBeenCalled();
expect(mockUserSocketService.setSocket).not.toHaveBeenCalled();
// Should set data on client
expect(mockClient.data.user).toEqual(
expect.objectContaining({
keycloakSub: 'test-sub',
}),
);
expect(mockUserSocketService.setSocket).toHaveBeenCalledWith(
'user-id',
'client1',
);
expect(mockClient.data.activeDollId).toBeNull();
expect(mockLoggerLog).toHaveBeenCalledWith(
`Client id: ${mockClient.id} connected (user: test-sub)`,
);
expect(mockLoggerDebug).toHaveBeenCalledWith(
'Number of connected clients: 5',
expect.stringContaining('WebSocket authenticated (Pending Init)'),
);
});
@@ -183,6 +213,83 @@ describe('StateGateway', () => {
});
});
describe('handleClientInitialize', () => {
it('should sync user, fetch state, and emit initialized event', async () => {
const mockClient: MockSocket = {
id: 'client1',
data: {
user: { keycloakSub: 'test-sub' },
friends: new Set(),
},
emit: jest.fn(),
disconnect: jest.fn(),
};
// Mock Prisma responses
(mockPrismaService.user!.findUnique as jest.Mock).mockResolvedValue({
activeDollId: 'doll-123',
});
(mockPrismaService.friendship!.findMany as jest.Mock).mockResolvedValue([
{ friendId: 'friend-1' },
{ friendId: 'friend-2' },
]);
await gateway.handleClientInitialize(
mockClient as unknown as AuthenticatedSocket,
);
// 1. Sync User
expect(mockAuthService.syncUserFromToken).toHaveBeenCalledWith(
mockClient.data.user,
);
// 2. Set Socket
expect(mockUserSocketService.setSocket).toHaveBeenCalledWith(
'user-id',
'client1',
);
// 3. Fetch State (DB)
expect(mockPrismaService.user!.findUnique).toHaveBeenCalledWith({
where: { id: 'user-id' },
select: { activeDollId: true },
});
expect(mockPrismaService.friendship!.findMany).toHaveBeenCalledWith({
where: { userId: 'user-id' },
select: { friendId: true },
});
// 4. Update Client Data
expect(mockClient.data.userId).toBe('user-id');
expect(mockClient.data.activeDollId).toBe('doll-123');
expect(mockClient.data.friends).toContain('friend-1');
expect(mockClient.data.friends).toContain('friend-2');
// 5. Emit Initialized
expect(mockClient.emit).toHaveBeenCalledWith('initialized', {
userId: 'user-id',
activeDollId: 'doll-123',
});
});
it('should disconnect if no user data present on socket', async () => {
const mockClient: MockSocket = {
id: 'client1',
data: {}, // Missing user data
disconnect: jest.fn(),
};
await gateway.handleClientInitialize(
mockClient as unknown as AuthenticatedSocket,
);
expect(mockLoggerError).toHaveBeenCalledWith(
expect.stringContaining('Unauthorized: No user data found'),
);
expect(mockClient.disconnect).toHaveBeenCalled();
});
});
describe('handleDisconnect', () => {
it('should log client disconnection', async () => {
const mockClient: MockSocket = {
@@ -250,6 +357,7 @@ describe('StateGateway', () => {
data: {
user: { keycloakSub: 'test-sub' },
userId: 'user-1',
activeDollId: 'doll-1', // User must have active doll
friends: new Set(['friend-1']),
},
};
@@ -261,6 +369,9 @@ describe('StateGateway', () => {
const data: CursorPositionDto = { x: 100, y: 200 };
// Force time to pass for throttle check if needed, or rely on first call passing
// The implementation uses lastBroadcastMap, initialized to empty, so first call should pass if now > 0
await gateway.handleCursorReportPosition(
mockClient as unknown as AuthenticatedSocket,
data,
@@ -276,21 +387,17 @@ describe('StateGateway', () => {
});
});
it('should not emit when no friends are online', async () => {
it('should NOT emit if user has no active doll', async () => {
const mockClient: MockSocket = {
id: 'client1',
data: {
user: { keycloakSub: 'test-sub' },
userId: 'user-1',
activeDollId: null, // No doll
friends: new Set(['friend-1']),
},
};
// Mock getFriendsSockets to return empty array
(mockUserSocketService.getFriendsSockets as jest.Mock).mockResolvedValue(
[],
);
const data: CursorPositionDto = { x: 100, y: 200 };
await gateway.handleCursorReportPosition(
@@ -298,11 +405,10 @@ describe('StateGateway', () => {
data,
);
// Verify that no message was emitted
expect(mockServer.to).not.toHaveBeenCalled();
});
it('should log warning when userId is missing', async () => {
it('should return early when userId is missing (not initialized)', async () => {
const mockClient: MockSocket = {
id: 'client1',
data: {
@@ -319,12 +425,9 @@ describe('StateGateway', () => {
data,
);
// Verify that a warning was logged
expect(mockLoggerWarn).toHaveBeenCalledWith(
`Could not find user ID for client ${mockClient.id}`,
);
// Verify that no message was emitted
expect(mockServer.to).not.toHaveBeenCalled();
// No explicit warning log expected in new implementation for just return
});
it('should throw exception when client is not authenticated', async () => {

View File

@@ -1,4 +1,4 @@
import { Logger } from '@nestjs/common';
import { Logger, Inject } from '@nestjs/common';
import {
OnGatewayConnection,
OnGatewayDisconnect,
@@ -9,8 +9,12 @@ import {
WsException,
} from '@nestjs/websockets';
import { OnEvent } from '@nestjs/event-emitter';
import Redis from 'ioredis';
import type { Server } from 'socket.io';
import {
REDIS_CLIENT,
REDIS_SUBSCRIBER_CLIENT,
} from '../../database/redis.module';
import type { AuthenticatedSocket } from '../../types/socket';
import { AuthService } from '../../auth/auth.service';
import { JwtVerificationService } from '../../auth/services/jwt-verification.service';
@@ -19,6 +23,7 @@ import { PrismaService } from '../../database/prisma.service';
import { UserSocketService } from './user-socket.service';
import { FriendEvents } from '../../friends/events/friend.events';
import type {
FriendRequestReceivedEvent,
FriendRequestAcceptedEvent,
@@ -37,6 +42,8 @@ import { UserEvents } from '../../users/events/user.events';
import type { UserActiveDollChangedEvent } from '../../users/events/user.events';
const WS_EVENT = {
CLIENT_INITIALIZE: 'client-initialize',
INITIALIZED: 'initialized',
CURSOR_REPORT_POSITION: 'cursor-report-position',
FRIEND_REQUEST_RECEIVED: 'friend-request-received',
FRIEND_REQUEST_ACCEPTED: 'friend-request-accepted',
@@ -50,6 +57,10 @@ const WS_EVENT = {
FRIEND_ACTIVE_DOLL_CHANGED: 'friend-active-doll-changed',
} as const;
const REDIS_CHANNEL = {
ACTIVE_DOLL_UPDATE: 'active-doll-update',
} as const;
@WebSocketGateway({
cors: {
origin: true,
@@ -69,12 +80,78 @@ export class StateGateway
private readonly jwtVerificationService: JwtVerificationService,
private readonly prisma: PrismaService,
private readonly userSocketService: UserSocketService,
) {}
@Inject(REDIS_CLIENT) private readonly redisClient: Redis | null,
@Inject(REDIS_SUBSCRIBER_CLIENT)
private readonly redisSubscriber: Redis | null,
) {
// Setup Redis subscription for cross-instance communication
if (this.redisSubscriber) {
this.redisSubscriber
.subscribe(REDIS_CHANNEL.ACTIVE_DOLL_UPDATE, (err) => {
if (err) {
this.logger.error(
`Failed to subscribe to ${REDIS_CHANNEL.ACTIVE_DOLL_UPDATE}`,
err,
);
} else {
this.logger.log(
`Subscribed to ${REDIS_CHANNEL.ACTIVE_DOLL_UPDATE} channel`,
);
}
})
.catch((err) => {
this.logger.error(
`Error subscribing to ${REDIS_CHANNEL.ACTIVE_DOLL_UPDATE}`,
err,
);
});
this.redisSubscriber.on('message', (channel, message) => {
if (channel === REDIS_CHANNEL.ACTIVE_DOLL_UPDATE) {
// eslint-disable-next-line @typescript-eslint/no-floating-promises
this.handleActiveDollUpdateMessage(message);
}
});
}
}
afterInit() {
this.logger.log('Initialized');
}
private async handleActiveDollUpdateMessage(message: string) {
try {
const data = JSON.parse(message) as {
userId: string;
dollId: string | null;
};
const { userId, dollId } = data;
// Check if the user is connected to THIS instance
// Note: We need a local way to check if we hold the socket connection.
// io.sockets.sockets is a Map of all connected sockets on this server instance.
// We first get the socket ID from the shared store (UserSocketService)
// to see which socket ID belongs to the user.
const socketId = await this.userSocketService.getSocket(userId);
if (socketId) {
// Now check if we actually have this socket locally
const localSocket = this.io.sockets.sockets.get(socketId);
if (localSocket) {
// We own this connection! Update the local state.
const authSocket = localSocket as AuthenticatedSocket;
authSocket.data.activeDollId = dollId;
this.logger.debug(
`Updated activeDollId locally for user ${userId} to ${dollId}`,
);
}
}
} catch (error) {
this.logger.error('Error handling redis message', error);
}
}
async handleConnection(client: AuthenticatedSocket) {
try {
this.logger.debug(
@@ -106,25 +183,12 @@ export class StateGateway
picture: payload.picture,
};
this.logger.log(`WebSocket authenticated: ${payload.sub}`);
// Initialize defaults
client.data.activeDollId = null;
client.data.friends = new Set();
// userId is not set yet, it will be set in handleClientInitialize
const user = await this.authService.syncUserFromToken(client.data.user);
await this.userSocketService.setSocket(user.id, client.id);
client.data.userId = user.id;
// Sync active doll state to socket
const userWithDoll = await this.prisma.user.findUnique({
where: { id: user.id },
select: { activeDollId: true },
});
client.data.activeDollId = userWithDoll?.activeDollId || null;
// Initialize friends cache using Prisma directly
const friends = await this.prisma.friendship.findMany({
where: { userId: user.id },
select: { friendId: true },
});
client.data.friends = new Set(friends.map((f) => f.friendId));
this.logger.log(`WebSocket authenticated (Pending Init): ${payload.sub}`);
const { sockets } = this.io.sockets;
this.logger.log(
@@ -139,6 +203,49 @@ export class StateGateway
}
}
@SubscribeMessage(WS_EVENT.CLIENT_INITIALIZE)
async handleClientInitialize(client: AuthenticatedSocket) {
try {
const userTokenData = client.data.user;
if (!userTokenData) {
throw new WsException('Unauthorized: No user data found');
}
// 1. Sync user from token (DB Write/Read)
const user = await this.authService.syncUserFromToken(userTokenData);
// 2. Register socket mapping (Redis Write)
await this.userSocketService.setSocket(user.id, client.id);
client.data.userId = user.id;
// 3. Fetch initial state (DB Read)
const [userWithDoll, friends] = await Promise.all([
this.prisma.user.findUnique({
where: { id: user.id },
select: { activeDollId: true },
}),
this.prisma.friendship.findMany({
where: { userId: user.id },
select: { friendId: true },
}),
]);
client.data.activeDollId = userWithDoll?.activeDollId || null;
client.data.friends = new Set(friends.map((f) => f.friendId));
this.logger.log(`Client initialized: ${user.id} (${client.id})`);
// 4. Notify client
client.emit(WS_EVENT.INITIALIZED, {
userId: user.id,
activeDollId: client.data.activeDollId,
});
} catch (error) {
this.logger.error(`Initialization error: ${error}`);
client.disconnect();
}
}
async handleDisconnect(client: AuthenticatedSocket) {
const user = client.data.user;
@@ -167,9 +274,7 @@ export class StateGateway
}
}
}
// Note: We can't iterate over Redis keys easily to find socketId match without userId
// The previous fallback loop over map entries is not efficient with Redis.
// We rely on client.data.userId being set correctly during connection.
// If userId is undefined, client never initialized, so no cleanup needed
}
this.logger.log(
@@ -194,13 +299,13 @@ export class StateGateway
const currentUserId = client.data.userId;
// Do not broadcast cursor position if user has no active doll
if (!client.data.activeDollId) {
if (!currentUserId) {
// User has not initialized yet
return;
}
if (!currentUserId) {
this.logger.warn(`Could not find user ID for client ${client.id}`);
// Do not broadcast cursor position if user has no active doll
if (!client.data.activeDollId) {
return;
}
@@ -408,14 +513,23 @@ export class StateGateway
async handleActiveDollChanged(payload: UserActiveDollChangedEvent) {
const { userId, dollId, doll } = payload;
// 1. Update the user's socket data to reflect the change
const socketId = await this.userSocketService.getSocket(userId);
if (socketId) {
const userSocket = this.io.sockets.sockets.get(
socketId,
) as AuthenticatedSocket;
if (userSocket) {
userSocket.data.activeDollId = dollId;
// 1. Publish update to all instances via Redis so they can update local socket state
if (this.redisClient) {
await this.redisClient.publish(
REDIS_CHANNEL.ACTIVE_DOLL_UPDATE,
JSON.stringify({ userId, dollId }),
);
} else {
// Fallback for single instance (no redis) - update locally directly
// This mimics what handleActiveDollUpdateMessage does
const socketId = await this.userSocketService.getSocket(userId);
if (socketId) {
const userSocket = this.io.sockets.sockets.get(
socketId,
) as AuthenticatedSocket;
if (userSocket) {
userSocket.data.activeDollId = dollId;
}
}
}