Redis
This commit is contained in:
@@ -5,12 +5,17 @@ import { AuthenticatedSocket } from '../../types/socket';
|
||||
import { AuthService } from '../../auth/auth.service';
|
||||
import { JwtVerificationService } from '../../auth/services/jwt-verification.service';
|
||||
import { PrismaService } from '../../database/prisma.service';
|
||||
import { UserSocketService } from './user-socket.service';
|
||||
|
||||
interface MockSocket extends Partial<AuthenticatedSocket> {
|
||||
id: string;
|
||||
data: {
|
||||
user?: {
|
||||
keycloakSub: string;
|
||||
email?: string;
|
||||
name?: string;
|
||||
preferred_username?: string;
|
||||
picture?: string;
|
||||
};
|
||||
userId?: string;
|
||||
friends?: Set<string>;
|
||||
@@ -31,6 +36,7 @@ describe('StateGateway', () => {
|
||||
let mockAuthService: Partial<AuthService>;
|
||||
let mockJwtVerificationService: Partial<JwtVerificationService>;
|
||||
let mockPrismaService: Partial<PrismaService>;
|
||||
let mockUserSocketService: Partial<UserSocketService>;
|
||||
|
||||
beforeEach(async () => {
|
||||
mockServer = {
|
||||
@@ -66,6 +72,14 @@ describe('StateGateway', () => {
|
||||
},
|
||||
};
|
||||
|
||||
mockUserSocketService = {
|
||||
setSocket: jest.fn().mockResolvedValue(undefined),
|
||||
removeSocket: jest.fn().mockResolvedValue(undefined),
|
||||
getSocket: jest.fn().mockResolvedValue(null),
|
||||
isUserOnline: jest.fn().mockResolvedValue(false),
|
||||
getFriendsSockets: jest.fn().mockResolvedValue([]),
|
||||
};
|
||||
|
||||
const module: TestingModule = await Test.createTestingModule({
|
||||
providers: [
|
||||
StateGateway,
|
||||
@@ -75,6 +89,7 @@ describe('StateGateway', () => {
|
||||
useValue: mockJwtVerificationService,
|
||||
},
|
||||
{ provide: PrismaService, useValue: mockPrismaService },
|
||||
{ provide: UserSocketService, useValue: mockUserSocketService },
|
||||
],
|
||||
}).compile();
|
||||
|
||||
@@ -130,6 +145,10 @@ describe('StateGateway', () => {
|
||||
keycloakSub: 'test-sub',
|
||||
}),
|
||||
);
|
||||
expect(mockUserSocketService.setSocket).toHaveBeenCalledWith(
|
||||
'user-id',
|
||||
'client1',
|
||||
);
|
||||
expect(mockLoggerLog).toHaveBeenCalledWith(
|
||||
`Client id: ${mockClient.id} connected (user: test-sub)`,
|
||||
);
|
||||
@@ -165,35 +184,57 @@ describe('StateGateway', () => {
|
||||
});
|
||||
|
||||
describe('handleDisconnect', () => {
|
||||
it('should log client disconnection', () => {
|
||||
it('should log client disconnection', async () => {
|
||||
const mockClient: MockSocket = {
|
||||
id: 'client1',
|
||||
data: { user: { keycloakSub: 'test-sub' } },
|
||||
};
|
||||
|
||||
gateway.handleDisconnect(mockClient as unknown as AuthenticatedSocket);
|
||||
await gateway.handleDisconnect(mockClient as unknown as AuthenticatedSocket);
|
||||
|
||||
expect(mockLoggerLog).toHaveBeenCalledWith(
|
||||
`Client id: ${mockClient.id} disconnected (user: test-sub)`,
|
||||
);
|
||||
});
|
||||
|
||||
it('should handle disconnection when no user data', () => {
|
||||
it('should handle disconnection when no user data', async () => {
|
||||
const mockClient: MockSocket = {
|
||||
id: 'client1',
|
||||
data: {},
|
||||
};
|
||||
|
||||
gateway.handleDisconnect(mockClient as unknown as AuthenticatedSocket);
|
||||
await gateway.handleDisconnect(mockClient as unknown as AuthenticatedSocket);
|
||||
|
||||
expect(mockLoggerLog).toHaveBeenCalledWith(
|
||||
`Client id: ${mockClient.id} disconnected (user: unknown)`,
|
||||
);
|
||||
});
|
||||
|
||||
it('should remove socket if it matches', async () => {
|
||||
const mockClient: MockSocket = {
|
||||
id: 'client1',
|
||||
data: {
|
||||
user: { keycloakSub: 'test-sub' },
|
||||
userId: 'user-id',
|
||||
friends: new Set(['friend-1']),
|
||||
},
|
||||
};
|
||||
|
||||
(mockUserSocketService.getSocket as jest.Mock).mockResolvedValue('client1');
|
||||
(mockUserSocketService.getFriendsSockets as jest.Mock).mockResolvedValue([
|
||||
{ userId: 'friend-1', socketId: 'friend-socket-id' }
|
||||
]);
|
||||
|
||||
await gateway.handleDisconnect(mockClient as unknown as AuthenticatedSocket);
|
||||
|
||||
expect(mockUserSocketService.getSocket).toHaveBeenCalledWith('user-id');
|
||||
expect(mockUserSocketService.removeSocket).toHaveBeenCalledWith('user-id');
|
||||
expect(mockServer.to).toHaveBeenCalledWith('friend-socket-id');
|
||||
});
|
||||
});
|
||||
|
||||
describe('handleCursorReportPosition', () => {
|
||||
it('should emit cursor position to connected friends', () => {
|
||||
it('should emit cursor position to connected friends', async () => {
|
||||
const mockClient: MockSocket = {
|
||||
id: 'client1',
|
||||
data: {
|
||||
@@ -203,13 +244,14 @@ describe('StateGateway', () => {
|
||||
},
|
||||
};
|
||||
|
||||
// Setup the userSocketMap to simulate a connected friend
|
||||
// eslint-disable-next-line @typescript-eslint/no-unsafe-member-access
|
||||
(gateway as any).userSocketMap.set('friend-1', 'friend-socket-id');
|
||||
// Mock getFriendsSockets to return the friend's socket
|
||||
(mockUserSocketService.getFriendsSockets as jest.Mock).mockResolvedValue([
|
||||
{ userId: 'friend-1', socketId: 'friend-socket-id' },
|
||||
]);
|
||||
|
||||
const data: CursorPositionDto = { x: 100, y: 200 };
|
||||
|
||||
gateway.handleCursorReportPosition(
|
||||
await gateway.handleCursorReportPosition(
|
||||
mockClient as unknown as AuthenticatedSocket,
|
||||
data,
|
||||
);
|
||||
@@ -224,7 +266,7 @@ describe('StateGateway', () => {
|
||||
});
|
||||
});
|
||||
|
||||
it('should not emit when no friends are online', () => {
|
||||
it('should not emit when no friends are online', async () => {
|
||||
const mockClient: MockSocket = {
|
||||
id: 'client1',
|
||||
data: {
|
||||
@@ -234,10 +276,12 @@ describe('StateGateway', () => {
|
||||
},
|
||||
};
|
||||
|
||||
// Don't set up userSocketMap - friend is not online
|
||||
// Mock getFriendsSockets to return empty array
|
||||
(mockUserSocketService.getFriendsSockets as jest.Mock).mockResolvedValue([]);
|
||||
|
||||
const data: CursorPositionDto = { x: 100, y: 200 };
|
||||
|
||||
gateway.handleCursorReportPosition(
|
||||
await gateway.handleCursorReportPosition(
|
||||
mockClient as unknown as AuthenticatedSocket,
|
||||
data,
|
||||
);
|
||||
@@ -246,7 +290,7 @@ describe('StateGateway', () => {
|
||||
expect(mockServer.to).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should log warning when userId is missing', () => {
|
||||
it('should log warning when userId is missing', async () => {
|
||||
const mockClient: MockSocket = {
|
||||
id: 'client1',
|
||||
data: {
|
||||
@@ -258,7 +302,7 @@ describe('StateGateway', () => {
|
||||
|
||||
const data: CursorPositionDto = { x: 100, y: 200 };
|
||||
|
||||
gateway.handleCursorReportPosition(
|
||||
await gateway.handleCursorReportPosition(
|
||||
mockClient as unknown as AuthenticatedSocket,
|
||||
data,
|
||||
);
|
||||
@@ -271,19 +315,19 @@ describe('StateGateway', () => {
|
||||
expect(mockServer.to).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should throw exception when client is not authenticated', () => {
|
||||
it('should throw exception when client is not authenticated', async () => {
|
||||
const mockClient: MockSocket = {
|
||||
id: 'client1',
|
||||
data: {},
|
||||
};
|
||||
const data: CursorPositionDto = { x: 100, y: 200 };
|
||||
|
||||
expect(() => {
|
||||
await expect(
|
||||
gateway.handleCursorReportPosition(
|
||||
mockClient as unknown as AuthenticatedSocket,
|
||||
data,
|
||||
);
|
||||
}).toThrow('Unauthorized');
|
||||
),
|
||||
).rejects.toThrow('Unauthorized');
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -16,6 +16,7 @@ import { AuthService } from '../../auth/auth.service';
|
||||
import { JwtVerificationService } from '../../auth/services/jwt-verification.service';
|
||||
import { CursorPositionDto } from '../dto/cursor-position.dto';
|
||||
import { PrismaService } from '../../database/prisma.service';
|
||||
import { UserSocketService } from './user-socket.service';
|
||||
|
||||
import { FriendEvents } from '../../friends/events/friend.events';
|
||||
import type {
|
||||
@@ -45,7 +46,6 @@ export class StateGateway
|
||||
implements OnGatewayInit, OnGatewayConnection, OnGatewayDisconnect
|
||||
{
|
||||
private readonly logger = new Logger(StateGateway.name);
|
||||
private userSocketMap: Map<string, string> = new Map();
|
||||
private lastBroadcastMap: Map<string, number> = new Map();
|
||||
|
||||
@WebSocketServer() io: Server;
|
||||
@@ -54,6 +54,7 @@ export class StateGateway
|
||||
private readonly authService: AuthService,
|
||||
private readonly jwtVerificationService: JwtVerificationService,
|
||||
private readonly prisma: PrismaService,
|
||||
private readonly userSocketService: UserSocketService,
|
||||
) {}
|
||||
|
||||
afterInit() {
|
||||
@@ -94,7 +95,7 @@ export class StateGateway
|
||||
this.logger.log(`WebSocket authenticated: ${payload.sub}`);
|
||||
|
||||
const user = await this.authService.syncUserFromToken(client.data.user);
|
||||
this.userSocketMap.set(user.id, client.id);
|
||||
await this.userSocketService.setSocket(user.id, client.id);
|
||||
client.data.userId = user.id;
|
||||
|
||||
// Initialize friends cache using Prisma directly
|
||||
@@ -117,7 +118,7 @@ export class StateGateway
|
||||
}
|
||||
}
|
||||
|
||||
handleDisconnect(client: AuthenticatedSocket) {
|
||||
async handleDisconnect(client: AuthenticatedSocket) {
|
||||
const user = client.data.user;
|
||||
|
||||
if (user) {
|
||||
@@ -125,34 +126,28 @@ export class StateGateway
|
||||
|
||||
if (userId) {
|
||||
// Check if this socket is still the active one for the user
|
||||
const currentSocketId = this.userSocketMap.get(userId);
|
||||
const currentSocketId = await this.userSocketService.getSocket(userId);
|
||||
if (currentSocketId === client.id) {
|
||||
this.userSocketMap.delete(userId);
|
||||
await this.userSocketService.removeSocket(userId);
|
||||
this.lastBroadcastMap.delete(userId);
|
||||
|
||||
// Notify friends that this user has disconnected
|
||||
const friends = client.data.friends;
|
||||
if (friends) {
|
||||
for (const friendId of friends) {
|
||||
const friendSocketId = this.userSocketMap.get(friendId);
|
||||
if (friendSocketId) {
|
||||
this.io.to(friendSocketId).emit(WS_EVENT.FRIEND_DISCONNECTED, {
|
||||
userId: userId,
|
||||
});
|
||||
}
|
||||
const friendIds = Array.from(friends);
|
||||
const friendSockets = await this.userSocketService.getFriendsSockets(friendIds);
|
||||
|
||||
for (const { socketId } of friendSockets) {
|
||||
this.io.to(socketId).emit(WS_EVENT.FRIEND_DISCONNECTED, {
|
||||
userId: userId,
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Fallback for cases where client.data.userId might not be set
|
||||
for (const [uid, socketId] of this.userSocketMap.entries()) {
|
||||
if (socketId === client.id) {
|
||||
this.userSocketMap.delete(uid);
|
||||
this.lastBroadcastMap.delete(uid);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
// Note: We can't iterate over Redis keys easily to find socketId match without userId
|
||||
// The previous fallback loop over map entries is not efficient with Redis.
|
||||
// We rely on client.data.userId being set correctly during connection.
|
||||
}
|
||||
|
||||
this.logger.log(
|
||||
@@ -160,12 +155,12 @@ export class StateGateway
|
||||
);
|
||||
}
|
||||
|
||||
isUserOnline(userId: string): boolean {
|
||||
return this.userSocketMap.has(userId);
|
||||
async isUserOnline(userId: string): Promise<boolean> {
|
||||
return this.userSocketService.isUserOnline(userId);
|
||||
}
|
||||
|
||||
@SubscribeMessage(WS_EVENT.CURSOR_REPORT_POSITION)
|
||||
handleCursorReportPosition(
|
||||
async handleCursorReportPosition(
|
||||
client: AuthenticatedSocket,
|
||||
data: CursorPositionDto,
|
||||
) {
|
||||
@@ -192,25 +187,25 @@ export class StateGateway
|
||||
// Broadcast to online friends
|
||||
const friends = client.data.friends;
|
||||
if (friends) {
|
||||
for (const friendId of friends) {
|
||||
const friendSocketId = this.userSocketMap.get(friendId);
|
||||
if (friendSocketId) {
|
||||
const payload = {
|
||||
userId: currentUserId,
|
||||
position: data,
|
||||
};
|
||||
this.io
|
||||
.to(friendSocketId)
|
||||
.emit(WS_EVENT.FRIEND_CURSOR_POSITION, payload);
|
||||
}
|
||||
const friendIds = Array.from(friends);
|
||||
const friendSockets = await this.userSocketService.getFriendsSockets(friendIds);
|
||||
|
||||
for (const { socketId } of friendSockets) {
|
||||
const payload = {
|
||||
userId: currentUserId,
|
||||
position: data,
|
||||
};
|
||||
this.io
|
||||
.to(socketId)
|
||||
.emit(WS_EVENT.FRIEND_CURSOR_POSITION, payload);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@OnEvent(FriendEvents.REQUEST_RECEIVED)
|
||||
handleFriendRequestReceived(payload: FriendRequestReceivedEvent) {
|
||||
async handleFriendRequestReceived(payload: FriendRequestReceivedEvent) {
|
||||
const { userId, friendRequest } = payload;
|
||||
const socketId = this.userSocketMap.get(userId);
|
||||
const socketId = await this.userSocketService.getSocket(userId);
|
||||
if (socketId) {
|
||||
this.io.to(socketId).emit(WS_EVENT.FRIEND_REQUEST_RECEIVED, {
|
||||
id: friendRequest.id,
|
||||
@@ -229,10 +224,10 @@ export class StateGateway
|
||||
}
|
||||
|
||||
@OnEvent(FriendEvents.REQUEST_ACCEPTED)
|
||||
handleFriendRequestAccepted(payload: FriendRequestAcceptedEvent) {
|
||||
async handleFriendRequestAccepted(payload: FriendRequestAcceptedEvent) {
|
||||
const { userId, friendRequest } = payload;
|
||||
|
||||
const socketId = this.userSocketMap.get(userId);
|
||||
const socketId = await this.userSocketService.getSocket(userId);
|
||||
|
||||
// 1. Update cache for the user who sent the request (userId / friendRequest.senderId)
|
||||
if (socketId) {
|
||||
@@ -259,7 +254,7 @@ export class StateGateway
|
||||
}
|
||||
|
||||
// 2. Update cache for the user who accepted the request (friendRequest.receiverId)
|
||||
const receiverSocketId = this.userSocketMap.get(friendRequest.receiverId);
|
||||
const receiverSocketId = await this.userSocketService.getSocket(friendRequest.receiverId);
|
||||
if (receiverSocketId) {
|
||||
const receiverSocket = this.io.sockets.sockets.get(
|
||||
receiverSocketId,
|
||||
@@ -271,9 +266,9 @@ export class StateGateway
|
||||
}
|
||||
|
||||
@OnEvent(FriendEvents.REQUEST_DENIED)
|
||||
handleFriendRequestDenied(payload: FriendRequestDeniedEvent) {
|
||||
async handleFriendRequestDenied(payload: FriendRequestDeniedEvent) {
|
||||
const { userId, friendRequest } = payload;
|
||||
const socketId = this.userSocketMap.get(userId);
|
||||
const socketId = await this.userSocketService.getSocket(userId);
|
||||
if (socketId) {
|
||||
this.io.to(socketId).emit(WS_EVENT.FRIEND_REQUEST_DENIED, {
|
||||
id: friendRequest.id,
|
||||
@@ -292,10 +287,10 @@ export class StateGateway
|
||||
}
|
||||
|
||||
@OnEvent(FriendEvents.UNFRIENDED)
|
||||
handleUnfriended(payload: UnfriendedEvent) {
|
||||
async handleUnfriended(payload: UnfriendedEvent) {
|
||||
const { userId, friendId } = payload;
|
||||
|
||||
const socketId = this.userSocketMap.get(userId);
|
||||
const socketId = await this.userSocketService.getSocket(userId);
|
||||
|
||||
// 1. Update cache for the user receiving the notification (userId)
|
||||
if (socketId) {
|
||||
@@ -313,7 +308,7 @@ export class StateGateway
|
||||
}
|
||||
|
||||
// 2. Update cache for the user initiating the unfriend (friendId)
|
||||
const initiatorSocketId = this.userSocketMap.get(friendId);
|
||||
const initiatorSocketId = await this.userSocketService.getSocket(friendId);
|
||||
if (initiatorSocketId) {
|
||||
const initiatorSocket = this.io.sockets.sockets.get(
|
||||
initiatorSocketId,
|
||||
|
||||
113
src/ws/state/user-socket.service.ts
Normal file
113
src/ws/state/user-socket.service.ts
Normal file
@@ -0,0 +1,113 @@
|
||||
import { Injectable, Inject, Logger } from '@nestjs/common';
|
||||
import { REDIS_CLIENT } from '../../database/redis.module';
|
||||
import Redis from 'ioredis';
|
||||
|
||||
@Injectable()
|
||||
export class UserSocketService {
|
||||
private readonly logger = new Logger(UserSocketService.name);
|
||||
private localUserSocketMap: Map<string, string> = new Map();
|
||||
private readonly PREFIX = 'socket:user:';
|
||||
private readonly TTL = 86400; // 24 hours
|
||||
|
||||
constructor(
|
||||
@Inject(REDIS_CLIENT) private readonly redisClient: Redis | null,
|
||||
) {}
|
||||
|
||||
async setSocket(userId: string, socketId: string): Promise<void> {
|
||||
if (this.redisClient) {
|
||||
try {
|
||||
await this.redisClient.set(
|
||||
`${this.PREFIX}${userId}`,
|
||||
socketId,
|
||||
'EX',
|
||||
this.TTL,
|
||||
);
|
||||
} catch (error) {
|
||||
this.logger.error(
|
||||
`Failed to set socket for user ${userId} in Redis`,
|
||||
error,
|
||||
);
|
||||
// Fallback to local map on error? Or just log?
|
||||
// Let's use local map as backup if redis is down/null
|
||||
this.localUserSocketMap.set(userId, socketId);
|
||||
}
|
||||
} else {
|
||||
this.localUserSocketMap.set(userId, socketId);
|
||||
}
|
||||
}
|
||||
|
||||
async removeSocket(userId: string): Promise<void> {
|
||||
if (this.redisClient) {
|
||||
try {
|
||||
await this.redisClient.del(`${this.PREFIX}${userId}`);
|
||||
} catch (error) {
|
||||
this.logger.error(
|
||||
`Failed to remove socket for user ${userId} from Redis`,
|
||||
error,
|
||||
);
|
||||
}
|
||||
}
|
||||
this.localUserSocketMap.delete(userId);
|
||||
}
|
||||
|
||||
async getSocket(userId: string): Promise<string | null> {
|
||||
if (this.redisClient) {
|
||||
try {
|
||||
const socketId = await this.redisClient.get(`${this.PREFIX}${userId}`);
|
||||
return socketId;
|
||||
} catch (error) {
|
||||
this.logger.error(
|
||||
`Failed to get socket for user ${userId} from Redis`,
|
||||
error,
|
||||
);
|
||||
return this.localUserSocketMap.get(userId) || null;
|
||||
}
|
||||
}
|
||||
return this.localUserSocketMap.get(userId) || null;
|
||||
}
|
||||
|
||||
async isUserOnline(userId: string): Promise<boolean> {
|
||||
const socketId = await this.getSocket(userId);
|
||||
return !!socketId;
|
||||
}
|
||||
|
||||
async getFriendsSockets(friendIds: string[]): Promise<{ userId: string; socketId: string }[]> {
|
||||
if (friendIds.length === 0) {
|
||||
return [];
|
||||
}
|
||||
|
||||
if (this.redisClient) {
|
||||
try {
|
||||
// Use pipeline for batch fetching
|
||||
const pipeline = this.redisClient.pipeline();
|
||||
friendIds.forEach((id) => pipeline.get(`${this.PREFIX}${id}`));
|
||||
const results = await pipeline.exec();
|
||||
|
||||
const sockets: { userId: string; socketId: string }[] = [];
|
||||
|
||||
if (results) {
|
||||
results.forEach((result, index) => {
|
||||
const [err, socketId] = result;
|
||||
if (!err && socketId && typeof socketId === 'string') {
|
||||
sockets.push({ userId: friendIds[index], socketId });
|
||||
}
|
||||
});
|
||||
}
|
||||
return sockets;
|
||||
} catch (error) {
|
||||
this.logger.error('Failed to batch get friend sockets from Redis', error);
|
||||
// Fallback to local implementation
|
||||
}
|
||||
}
|
||||
|
||||
// Local fallback
|
||||
const sockets: { userId: string; socketId: string }[] = [];
|
||||
for (const friendId of friendIds) {
|
||||
const socketId = this.localUserSocketMap.get(friendId);
|
||||
if (socketId) {
|
||||
sockets.push({ userId: friendId, socketId });
|
||||
}
|
||||
}
|
||||
return sockets;
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user