refactored ws module

This commit is contained in:
2026-01-10 13:48:22 +08:00
parent d1f2f9089e
commit 7e7e21c0e6
11 changed files with 475 additions and 330 deletions

View File

@@ -6,6 +6,7 @@ import { AuthService } from '../../auth/auth.service';
import { JwtVerificationService } from '../../auth/services/jwt-verification.service';
import { PrismaService } from '../../database/prisma.service';
import { UserSocketService } from './user-socket.service';
import { WsNotificationService } from './ws-notification.service';
interface MockSocket extends Partial<AuthenticatedSocket> {
id: string;
@@ -42,6 +43,14 @@ describe('StateGateway', () => {
let mockUserSocketService: Partial<UserSocketService>;
let mockRedisClient: { publish: jest.Mock };
let mockRedisSubscriber: { subscribe: jest.Mock; on: jest.Mock };
let mockWsNotificationService: {
setIo: jest.Mock;
emitToUser: jest.Mock;
emitToFriends: jest.Mock;
emitToSocket: jest.Mock;
updateActiveDollCache: jest.Mock;
publishActiveDollUpdate: jest.Mock;
};
beforeEach(async () => {
mockServer = {
@@ -97,6 +106,15 @@ describe('StateGateway', () => {
on: jest.fn(),
};
mockWsNotificationService = {
setIo: jest.fn(),
emitToUser: jest.fn(),
emitToFriends: jest.fn(),
emitToSocket: jest.fn(),
updateActiveDollCache: jest.fn(),
publishActiveDollUpdate: jest.fn(),
};
const module: TestingModule = await Test.createTestingModule({
providers: [
StateGateway,
@@ -107,6 +125,7 @@ describe('StateGateway', () => {
},
{ provide: PrismaService, useValue: mockPrismaService },
{ provide: UserSocketService, useValue: mockUserSocketService },
{ provide: WsNotificationService, useValue: mockWsNotificationService },
{ provide: 'REDIS_CLIENT', useValue: mockRedisClient },
{ provide: 'REDIS_SUBSCRIBER_CLIENT', useValue: mockRedisSubscriber },
],
@@ -142,6 +161,7 @@ describe('StateGateway', () => {
it('should subscribe to redis channel', () => {
expect(mockRedisSubscriber.subscribe).toHaveBeenCalledWith(
'active-doll-update',
'friend-cache-update',
expect.any(Function),
);
});
@@ -348,7 +368,11 @@ describe('StateGateway', () => {
expect(mockUserSocketService.removeSocket).toHaveBeenCalledWith(
'user-id',
);
expect(mockServer.to).toHaveBeenCalledWith('friend-socket-id');
expect(mockWsNotificationService.emitToSocket).toHaveBeenCalledWith(
'friend-socket-id',
expect.any(String),
expect.any(Object),
);
});
});
@@ -379,14 +403,15 @@ describe('StateGateway', () => {
data,
);
// Verify that the message was emitted to the friend
expect(mockServer.to).toHaveBeenCalledWith('friend-socket-id');
// eslint-disable-next-line @typescript-eslint/no-unsafe-member-access
const emitMock = mockServer.to().emit as jest.Mock;
expect(emitMock).toHaveBeenCalledWith('friend-cursor-position', {
userId: 'user-1',
position: data,
});
// Verify that message was emitted via WsNotificationService
expect(mockWsNotificationService.emitToSocket).toHaveBeenCalledWith(
'friend-socket-id',
'friend-cursor-position',
{
userId: 'user-1',
position: data,
},
);
});
it('should NOT emit if user has no active doll', async () => {

View File

@@ -8,7 +8,6 @@ import {
WebSocketServer,
WsException,
} from '@nestjs/websockets';
import { OnEvent } from '@nestjs/event-emitter';
import Redis from 'ioredis';
import type { Server } from 'socket.io';
import {
@@ -21,45 +20,8 @@ import { JwtVerificationService } from '../../auth/services/jwt-verification.ser
import { CursorPositionDto } from '../dto/cursor-position.dto';
import { PrismaService } from '../../database/prisma.service';
import { UserSocketService } from './user-socket.service';
import { FriendEvents } from '../../friends/events/friend.events';
import type {
FriendRequestReceivedEvent,
FriendRequestAcceptedEvent,
FriendRequestDeniedEvent,
UnfriendedEvent,
} from '../../friends/events/friend.events';
import { DollEvents } from '../../dolls/events/doll.events';
import type {
DollCreatedEvent,
DollUpdatedEvent,
DollDeletedEvent,
} from '../../dolls/events/doll.events';
import { UserEvents } from '../../users/events/user.events';
import type { UserActiveDollChangedEvent } from '../../users/events/user.events';
const WS_EVENT = {
CLIENT_INITIALIZE: 'client-initialize',
INITIALIZED: 'initialized',
CURSOR_REPORT_POSITION: 'cursor-report-position',
FRIEND_REQUEST_RECEIVED: 'friend-request-received',
FRIEND_REQUEST_ACCEPTED: 'friend-request-accepted',
FRIEND_REQUEST_DENIED: 'friend-request-denied',
UNFRIENDED: 'unfriended',
FRIEND_CURSOR_POSITION: 'friend-cursor-position',
FRIEND_DISCONNECTED: 'friend-disconnected',
FRIEND_DOLL_CREATED: 'friend-doll-created',
FRIEND_DOLL_UPDATED: 'friend-doll-updated',
FRIEND_DOLL_DELETED: 'friend-doll-deleted',
FRIEND_ACTIVE_DOLL_CHANGED: 'friend-active-doll-changed',
} as const;
const REDIS_CHANNEL = {
ACTIVE_DOLL_UPDATE: 'active-doll-update',
} as const;
import { WsNotificationService } from './ws-notification.service';
import { WS_EVENT, REDIS_CHANNEL } from './ws-events';
@WebSocketGateway({
cors: {
@@ -80,6 +42,7 @@ export class StateGateway
private readonly jwtVerificationService: JwtVerificationService,
private readonly prisma: PrismaService,
private readonly userSocketService: UserSocketService,
private readonly wsNotificationService: WsNotificationService,
@Inject(REDIS_CLIENT) private readonly redisClient: Redis | null,
@Inject(REDIS_SUBSCRIBER_CLIENT)
private readonly redisSubscriber: Redis | null,
@@ -87,29 +50,28 @@ export class StateGateway
// Setup Redis subscription for cross-instance communication
if (this.redisSubscriber) {
this.redisSubscriber
.subscribe(REDIS_CHANNEL.ACTIVE_DOLL_UPDATE, (err) => {
if (err) {
this.logger.error(
`Failed to subscribe to ${REDIS_CHANNEL.ACTIVE_DOLL_UPDATE}`,
err,
);
} else {
this.logger.log(
`Subscribed to ${REDIS_CHANNEL.ACTIVE_DOLL_UPDATE} channel`,
);
}
})
.subscribe(
REDIS_CHANNEL.ACTIVE_DOLL_UPDATE,
REDIS_CHANNEL.FRIEND_CACHE_UPDATE,
(err) => {
if (err) {
this.logger.error(`Failed to subscribe to Redis channels`, err);
} else {
this.logger.log(`Subscribed to Redis channels`);
}
},
)
.catch((err) => {
this.logger.error(
`Error subscribing to ${REDIS_CHANNEL.ACTIVE_DOLL_UPDATE}`,
err,
);
this.logger.error(`Error subscribing to Redis channels`, err);
});
this.redisSubscriber.on('message', (channel, message) => {
if (channel === REDIS_CHANNEL.ACTIVE_DOLL_UPDATE) {
// eslint-disable-next-line @typescript-eslint/no-floating-promises
this.handleActiveDollUpdateMessage(message);
} else if (channel === REDIS_CHANNEL.FRIEND_CACHE_UPDATE) {
// eslint-disable-next-line @typescript-eslint/no-floating-promises
this.handleFriendCacheUpdateMessage(message);
}
});
}
@@ -117,6 +79,7 @@ export class StateGateway
afterInit() {
this.logger.log('Initialized');
this.wsNotificationService.setIo(this.io);
}
private async handleActiveDollUpdateMessage(message: string) {
@@ -126,29 +89,27 @@ export class StateGateway
dollId: string | null;
};
const { userId, dollId } = data;
// Check if the user is connected to THIS instance
// Note: We need a local way to check if we hold the socket connection.
// io.sockets.sockets is a Map of all connected sockets on this server instance.
// We first get the socket ID from the shared store (UserSocketService)
// to see which socket ID belongs to the user.
const socketId = await this.userSocketService.getSocket(userId);
if (socketId) {
// Now check if we actually have this socket locally
const localSocket = this.io.sockets.sockets.get(socketId);
if (localSocket) {
// We own this connection! Update the local state.
const authSocket = localSocket as AuthenticatedSocket;
authSocket.data.activeDollId = dollId;
this.logger.debug(
`Updated activeDollId locally for user ${userId} to ${dollId}`,
);
}
}
await this.wsNotificationService.updateActiveDollCache(userId, dollId);
} catch (error) {
this.logger.error('Error handling redis message', error);
this.logger.error('Error handling active doll update message', error);
}
}
private async handleFriendCacheUpdateMessage(message: string) {
try {
const data = JSON.parse(message) as {
userId: string;
friendId: string;
action: 'add' | 'delete';
};
const { userId, friendId, action } = data;
await this.wsNotificationService.updateFriendsCacheLocal(
userId,
friendId,
action,
);
} catch (error) {
this.logger.error('Error handling friend cache update message', error);
}
}
@@ -302,9 +263,13 @@ export class StateGateway
await this.userSocketService.getFriendsSockets(friendIds);
for (const { socketId } of friendSockets) {
this.io.to(socketId).emit(WS_EVENT.FRIEND_DISCONNECTED, {
userId: userId,
});
this.wsNotificationService.emitToSocket(
socketId,
WS_EVENT.FRIEND_DISCONNECTED,
{
userId: userId,
},
);
}
}
}
@@ -363,238 +328,12 @@ export class StateGateway
userId: currentUserId,
position: data,
};
this.io.to(socketId).emit(WS_EVENT.FRIEND_CURSOR_POSITION, payload);
}
}
}
@OnEvent(FriendEvents.REQUEST_RECEIVED)
async handleFriendRequestReceived(payload: FriendRequestReceivedEvent) {
const { userId, friendRequest } = payload;
const socketId = await this.userSocketService.getSocket(userId);
if (socketId) {
this.io.to(socketId).emit(WS_EVENT.FRIEND_REQUEST_RECEIVED, {
id: friendRequest.id,
sender: {
id: friendRequest.sender.id,
name: friendRequest.sender.name,
username: friendRequest.sender.username,
picture: friendRequest.sender.picture,
},
createdAt: friendRequest.createdAt,
});
this.logger.debug(
`Emitted friend request notification to user ${userId}`,
);
}
}
@OnEvent(FriendEvents.REQUEST_ACCEPTED)
async handleFriendRequestAccepted(payload: FriendRequestAcceptedEvent) {
const { userId, friendRequest } = payload;
const socketId = await this.userSocketService.getSocket(userId);
// 1. Update cache for the user who sent the request (userId / friendRequest.senderId)
if (socketId) {
const senderSocket = this.io.sockets.sockets.get(
socketId,
) as AuthenticatedSocket;
if (senderSocket && senderSocket.data.friends) {
senderSocket.data.friends.add(friendRequest.receiverId);
}
this.io.to(socketId).emit(WS_EVENT.FRIEND_REQUEST_ACCEPTED, {
id: friendRequest.id,
friend: {
id: friendRequest.receiver.id,
name: friendRequest.receiver.name,
username: friendRequest.receiver.username,
picture: friendRequest.receiver.picture,
},
acceptedAt: friendRequest.updatedAt,
});
this.logger.debug(
`Emitted friend request accepted notification to user ${userId}`,
);
}
// 2. Update cache for the user who accepted the request (friendRequest.receiverId)
const receiverSocketId = await this.userSocketService.getSocket(
friendRequest.receiverId,
);
if (receiverSocketId) {
const receiverSocket = this.io.sockets.sockets.get(
receiverSocketId,
) as AuthenticatedSocket;
if (receiverSocket && receiverSocket.data.friends) {
receiverSocket.data.friends.add(friendRequest.senderId);
}
}
}
@OnEvent(FriendEvents.REQUEST_DENIED)
async handleFriendRequestDenied(payload: FriendRequestDeniedEvent) {
const { userId, friendRequest } = payload;
const socketId = await this.userSocketService.getSocket(userId);
if (socketId) {
this.io.to(socketId).emit(WS_EVENT.FRIEND_REQUEST_DENIED, {
id: friendRequest.id,
denier: {
id: friendRequest.receiver.id,
name: friendRequest.receiver.name,
username: friendRequest.receiver.username,
picture: friendRequest.receiver.picture,
},
deniedAt: friendRequest.updatedAt,
});
this.logger.debug(
`Emitted friend request denied notification to user ${userId}`,
);
}
}
@OnEvent(FriendEvents.UNFRIENDED)
async handleUnfriended(payload: UnfriendedEvent) {
const { userId, friendId } = payload;
const socketId = await this.userSocketService.getSocket(userId);
// 1. Update cache for the user receiving the notification (userId)
if (socketId) {
const socket = this.io.sockets.sockets.get(
socketId,
) as AuthenticatedSocket;
if (socket && socket.data.friends) {
socket.data.friends.delete(friendId);
}
this.io.to(socketId).emit(WS_EVENT.UNFRIENDED, {
friendId,
});
this.logger.debug(`Emitted unfriended notification to user ${userId}`);
}
// 2. Update cache for the user initiating the unfriend (friendId)
const initiatorSocketId = await this.userSocketService.getSocket(friendId);
if (initiatorSocketId) {
const initiatorSocket = this.io.sockets.sockets.get(
initiatorSocketId,
) as AuthenticatedSocket;
if (initiatorSocket && initiatorSocket.data.friends) {
initiatorSocket.data.friends.delete(userId);
}
}
}
@OnEvent(DollEvents.DOLL_CREATED)
async handleDollCreated(payload: DollCreatedEvent) {
const { userId, doll } = payload;
const friendSockets = await this.userSocketService.getFriendsSockets([
userId,
]);
for (const { socketId } of friendSockets) {
this.io.to(socketId).emit(WS_EVENT.FRIEND_DOLL_CREATED, {
friendId: userId,
doll: {
id: doll.id,
name: doll.name,
configuration: doll.configuration,
createdAt: doll.createdAt,
updatedAt: doll.updatedAt,
},
});
}
}
@OnEvent(DollEvents.DOLL_UPDATED)
async handleDollUpdated(payload: DollUpdatedEvent) {
const { userId, doll } = payload;
const friendSockets = await this.userSocketService.getFriendsSockets([
userId,
]);
for (const { socketId } of friendSockets) {
this.io.to(socketId).emit(WS_EVENT.FRIEND_DOLL_UPDATED, {
friendId: userId,
doll: {
id: doll.id,
name: doll.name,
configuration: doll.configuration,
createdAt: doll.createdAt,
updatedAt: doll.updatedAt,
},
});
}
}
@OnEvent(DollEvents.DOLL_DELETED)
async handleDollDeleted(payload: DollDeletedEvent) {
const { userId, dollId } = payload;
const friendSockets = await this.userSocketService.getFriendsSockets([
userId,
]);
for (const { socketId } of friendSockets) {
this.io.to(socketId).emit(WS_EVENT.FRIEND_DOLL_DELETED, {
friendId: userId,
dollId,
});
}
}
@OnEvent(UserEvents.ACTIVE_DOLL_CHANGED)
async handleActiveDollChanged(payload: UserActiveDollChangedEvent) {
const { userId, dollId, doll } = payload;
// 1. Publish update to all instances via Redis so they can update local socket state
if (this.redisClient) {
await this.redisClient.publish(
REDIS_CHANNEL.ACTIVE_DOLL_UPDATE,
JSON.stringify({ userId, dollId }),
);
} else {
// Fallback for single instance (no redis) - update locally directly
// This mimics what handleActiveDollUpdateMessage does
const socketId = await this.userSocketService.getSocket(userId);
if (socketId) {
const userSocket = this.io.sockets.sockets.get(
this.wsNotificationService.emitToSocket(
socketId,
) as AuthenticatedSocket;
if (userSocket) {
userSocket.data.activeDollId = dollId;
}
WS_EVENT.FRIEND_CURSOR_POSITION,
payload,
);
}
}
// 2. Broadcast to friends
const friends = await this.prisma.friendship.findMany({
where: { userId },
select: { friendId: true },
});
const friendIds = friends.map((f) => f.friendId);
const friendSockets =
await this.userSocketService.getFriendsSockets(friendIds);
this.logger.log(
`Broadcasting friend-active-doll-changed for user ${userId}, doll: ${doll ? doll.id : 'null'} to ${friendSockets.length} friends`,
);
for (const { socketId } of friendSockets) {
this.io.to(socketId).emit(WS_EVENT.FRIEND_ACTIVE_DOLL_CHANGED, {
friendId: userId,
doll: doll
? {
id: doll.id,
name: doll.name,
configuration: doll.configuration,
createdAt: doll.createdAt,
updatedAt: doll.updatedAt,
}
: null,
});
}
}
}

20
src/ws/state/ws-events.ts Normal file
View File

@@ -0,0 +1,20 @@
export const WS_EVENT = {
CLIENT_INITIALIZE: 'client-initialize',
INITIALIZED: 'initialized',
CURSOR_REPORT_POSITION: 'cursor-report-position',
FRIEND_REQUEST_RECEIVED: 'friend-request-received',
FRIEND_REQUEST_ACCEPTED: 'friend-request-accepted',
FRIEND_REQUEST_DENIED: 'friend-request-denied',
UNFRIENDED: 'unfriended',
FRIEND_CURSOR_POSITION: 'friend-cursor-position',
FRIEND_DISCONNECTED: 'friend-disconnected',
FRIEND_DOLL_CREATED: 'friend-doll-created',
FRIEND_DOLL_UPDATED: 'friend-doll-updated',
FRIEND_DOLL_DELETED: 'friend-doll-deleted',
FRIEND_ACTIVE_DOLL_CHANGED: 'friend-active-doll-changed',
} as const;
export const REDIS_CHANNEL = {
ACTIVE_DOLL_UPDATE: 'active-doll-update',
FRIEND_CACHE_UPDATE: 'friend-cache-update',
} as const;

View File

@@ -0,0 +1,101 @@
import { Injectable, Logger, Inject } from '@nestjs/common';
import Redis from 'ioredis';
import { Server } from 'socket.io';
import { UserSocketService } from './user-socket.service';
import type { AuthenticatedSocket } from '../../types/socket';
import { REDIS_CLIENT } from '../../database/redis.module';
import { REDIS_CHANNEL } from './ws-events';
@Injectable()
export class WsNotificationService {
private readonly logger = new Logger(WsNotificationService.name);
private io: Server | null = null;
constructor(
private readonly userSocketService: UserSocketService,
@Inject(REDIS_CLIENT) private readonly redisClient: Redis | null,
) {}
setIo(io: Server) {
this.io = io;
}
async emitToUser(userId: string, event: string, payload: any) {
if (!this.io) return;
const socketId = await this.userSocketService.getSocket(userId);
if (socketId) {
this.io.to(socketId).emit(event, payload);
}
}
async emitToFriends(userIds: string[], event: string, payload: any) {
if (!this.io) return;
const friendSockets =
await this.userSocketService.getFriendsSockets(userIds);
for (const { socketId } of friendSockets) {
this.io.to(socketId).emit(event, payload);
}
}
emitToSocket(socketId: string, event: string, payload: any) {
if (!this.io) return;
this.io.to(socketId).emit(event, payload);
}
async updateFriendsCache(
userId: string,
friendId: string,
action: 'add' | 'delete',
) {
if (this.redisClient) {
await this.redisClient.publish(
REDIS_CHANNEL.FRIEND_CACHE_UPDATE,
JSON.stringify({ userId, friendId, action }),
);
} else {
// Fallback: update locally
await this.updateFriendsCacheLocal(userId, friendId, action);
}
}
async updateFriendsCacheLocal(
userId: string,
friendId: string,
action: 'add' | 'delete',
) {
if (!this.io) return;
const socketId = await this.userSocketService.getSocket(userId);
if (socketId) {
const socket = this.io.sockets.sockets.get(
socketId,
) as AuthenticatedSocket;
if (socket?.data?.friends) {
if (action === 'add') socket.data.friends.add(friendId);
else socket.data.friends.delete(friendId);
}
}
}
async updateActiveDollCache(userId: string, dollId: string | null) {
if (!this.io) return;
const socketId = await this.userSocketService.getSocket(userId);
if (socketId) {
const socket = this.io.sockets.sockets.get(
socketId,
) as AuthenticatedSocket;
if (socket) socket.data.activeDollId = dollId;
}
}
async publishActiveDollUpdate(userId: string, dollId: string | null) {
if (this.redisClient) {
await this.redisClient.publish(
REDIS_CHANNEL.ACTIVE_DOLL_UPDATE,
JSON.stringify({ userId, dollId }),
);
} else {
// Fallback: update locally
await this.updateActiveDollCache(userId, dollId);
}
}
}