Minor refinements & tuning to system structure

This commit is contained in:
2025-12-23 15:57:05 +08:00
parent 6c63f2d803
commit d06a58cf93
8 changed files with 394 additions and 87 deletions

View File

@@ -3,6 +3,7 @@ import { ConfigService } from '@nestjs/config';
import Redis from 'ioredis'; import Redis from 'ioredis';
export const REDIS_CLIENT = 'REDIS_CLIENT'; export const REDIS_CLIENT = 'REDIS_CLIENT';
export const REDIS_SUBSCRIBER_CLIENT = 'REDIS_SUBSCRIBER_CLIENT';
@Global() @Global()
@Module({ @Module({
@@ -46,7 +47,46 @@ export const REDIS_CLIENT = 'REDIS_CLIENT';
}, },
inject: [ConfigService], inject: [ConfigService],
}, },
{
provide: REDIS_SUBSCRIBER_CLIENT,
useFactory: (configService: ConfigService) => {
const logger = new Logger('RedisSubscriberModule');
const host = configService.get<string>('REDIS_HOST');
const port = configService.get<number>('REDIS_PORT');
const password = configService.get<string>('REDIS_PASSWORD');
if (!host) {
return null;
}
const client = new Redis({
host,
port: port || 6379,
password: password,
retryStrategy(times) {
const delay = Math.min(times * 50, 2000);
return delay;
},
});
client.on('error', (err) => {
// Suppress the known error that happens when ioredis tries to perform checks on a subscriber connection
if (
err.message &&
err.message.includes(
'Connection in subscriber mode, only subscriber commands may be used',
)
) {
return;
}
logger.error('Redis subscriber connection error', err);
});
return client;
},
inject: [ConfigService],
},
], ],
exports: [REDIS_CLIENT], exports: [REDIS_CLIENT, REDIS_SUBSCRIBER_CLIENT],
}) })
export class RedisModule {} export class RedisModule {}

View File

@@ -75,13 +75,6 @@ describe('DollsService', () => {
const createDto = { name: 'New Doll' }; const createDto = { name: 'New Doll' };
const userId = 'user-1'; const userId = 'user-1';
// Mock the transaction callback to return the doll
jest
.spyOn(prismaService, '$transaction')
.mockImplementation(async (callback) => {
return callback(prismaService);
});
await service.create(userId, createDto); await service.create(userId, createDto);
expect(prismaService.doll.create).toHaveBeenCalledWith({ expect(prismaService.doll.create).toHaveBeenCalledWith({

View File

@@ -0,0 +1,48 @@
import {
IsString,
IsNotEmpty,
IsOptional,
ValidateNested,
} from 'class-validator';
import { Type } from 'class-transformer';
import { ApiProperty, ApiPropertyOptional } from '@nestjs/swagger';
import { DollConfigurationDto } from './create-doll.dto';
export class ActiveDollDto {
@ApiProperty({
description: 'Unique identifier of the doll',
example: '550e8400-e29b-41d4-a716-446655440000',
})
@IsString()
@IsNotEmpty()
id: string;
@ApiProperty({
description: 'Display name of the doll',
example: 'My First Doll',
})
@IsString()
@IsNotEmpty()
name: string;
@ApiPropertyOptional({
description: 'Configuration for the doll',
type: DollConfigurationDto,
})
@IsOptional()
@ValidateNested()
@Type(() => DollConfigurationDto)
configuration?: DollConfigurationDto;
@ApiProperty({
description: 'Creation date of the doll',
example: '2023-01-01T00:00:00.000Z',
})
createdAt: Date;
@ApiProperty({
description: 'Last update date of the doll',
example: '2023-01-01T00:00:00.000Z',
})
updatedAt: Date;
}

View File

@@ -1,4 +1,5 @@
import { ApiProperty } from '@nestjs/swagger'; import { ApiProperty } from '@nestjs/swagger';
import { ActiveDollDto } from '../../dolls/dto/active-doll.dto';
export class UserBasicDto { export class UserBasicDto {
@ApiProperty({ @ApiProperty({
@@ -30,8 +31,9 @@ export class UserBasicDto {
@ApiProperty({ @ApiProperty({
description: "User's active doll", description: "User's active doll",
required: false, required: false,
type: ActiveDollDto,
}) })
activeDoll?: any; activeDoll?: ActiveDollDto;
} }
export class FriendRequestResponseDto { export class FriendRequestResponseDto {

View File

@@ -20,7 +20,7 @@ import {
ApiQuery, ApiQuery,
} from '@nestjs/swagger'; } from '@nestjs/swagger';
import { ThrottlerGuard, Throttle } from '@nestjs/throttler'; import { ThrottlerGuard, Throttle } from '@nestjs/throttler';
import { User, FriendRequest } from '@prisma/client'; import { User, FriendRequest, Prisma } from '@prisma/client';
import { FriendsService } from './friends.service'; import { FriendsService } from './friends.service';
import { JwtAuthGuard } from '../auth/guards/jwt-auth.guard'; import { JwtAuthGuard } from '../auth/guards/jwt-auth.guard';
import { import {
@@ -42,19 +42,13 @@ type FriendRequestWithRelations = FriendRequest & {
}; };
import { UsersService } from '../users/users.service'; import { UsersService } from '../users/users.service';
type FriendWithDoll = { type FriendshipWithFriendAndDoll = Prisma.FriendshipGetPayload<{
id: string; include: {
name: string; friend: {
username: string | null; include: { activeDoll: true };
picture: string | null; };
activeDoll?: { };
id: string; }>;
name: string;
configuration: any;
createdAt: Date;
updatedAt: Date;
} | null;
};
@ApiTags('friends') @ApiTags('friends')
@Controller('friends') @Controller('friends')
@@ -314,8 +308,10 @@ export class FriendsController {
const friendships = await this.friendsService.getFriends(user.id); const friendships = await this.friendsService.getFriends(user.id);
return friendships.map((friendship) => { return friendships.map((friendship) => {
// Need to cast to any because TS doesn't know about the included relation in the service method // Use Prisma generated type for safe casting
const friend = friendship.friend as unknown as FriendWithDoll; const typedFriendship =
friendship as unknown as FriendshipWithFriendAndDoll;
const friend = typedFriendship.friend;
return { return {
id: friendship.id, id: friendship.id,
@@ -328,7 +324,8 @@ export class FriendsController {
? { ? {
id: friend.activeDoll.id, id: friend.activeDoll.id,
name: friend.activeDoll.name, name: friend.activeDoll.name,
configuration: friend.activeDoll.configuration as unknown, // eslint-disable-next-line @typescript-eslint/no-unsafe-assignment
configuration: friend.activeDoll.configuration as any,
createdAt: friend.activeDoll.createdAt, createdAt: friend.activeDoll.createdAt,
updatedAt: friend.activeDoll.updatedAt, updatedAt: friend.activeDoll.updatedAt,
} }

View File

@@ -59,6 +59,16 @@ export class RedisIoAdapter extends IoAdapter {
this.logger.error('Redis Pub client error', err); this.logger.error('Redis Pub client error', err);
}); });
subClient.on('error', (err) => { subClient.on('error', (err) => {
// Suppress specific error about subscriber mode restrictions
// This is a known issue/behavior when ioredis performs internal checks (like info) on a subscriber connection
if (
err.message &&
err.message.includes(
'Connection in subscriber mode, only subscriber commands may be used',
)
) {
return;
}
this.logger.error('Redis Sub client error', err); this.logger.error('Redis Sub client error', err);
}); });

View File

@@ -18,10 +18,12 @@ interface MockSocket extends Partial<AuthenticatedSocket> {
picture?: string; picture?: string;
}; };
userId?: string; userId?: string;
activeDollId?: string | null;
friends?: Set<string>; friends?: Set<string>;
}; };
handshake?: any; handshake?: any;
disconnect?: jest.Mock; disconnect?: jest.Mock;
emit?: jest.Mock;
} }
describe('StateGateway', () => { describe('StateGateway', () => {
@@ -29,6 +31,7 @@ describe('StateGateway', () => {
let mockLoggerLog: jest.SpyInstance; let mockLoggerLog: jest.SpyInstance;
let mockLoggerDebug: jest.SpyInstance; let mockLoggerDebug: jest.SpyInstance;
let mockLoggerWarn: jest.SpyInstance; let mockLoggerWarn: jest.SpyInstance;
let mockLoggerError: jest.SpyInstance;
let mockServer: { let mockServer: {
sockets: { sockets: { size: number; get: jest.Mock } }; sockets: { sockets: { size: number; get: jest.Mock } };
to: jest.Mock; to: jest.Mock;
@@ -37,6 +40,8 @@ describe('StateGateway', () => {
let mockJwtVerificationService: Partial<JwtVerificationService>; let mockJwtVerificationService: Partial<JwtVerificationService>;
let mockPrismaService: Partial<PrismaService>; let mockPrismaService: Partial<PrismaService>;
let mockUserSocketService: Partial<UserSocketService>; let mockUserSocketService: Partial<UserSocketService>;
let mockRedisClient: { publish: jest.Mock };
let mockRedisSubscriber: { subscribe: jest.Mock; on: jest.Mock };
beforeEach(async () => { beforeEach(async () => {
mockServer = { mockServer = {
@@ -67,9 +72,12 @@ describe('StateGateway', () => {
}; };
mockPrismaService = { mockPrismaService = {
user: {
findUnique: jest.fn().mockResolvedValue({ activeDollId: 'doll-123' }),
} as any,
friendship: { friendship: {
findMany: jest.fn().mockResolvedValue([]), findMany: jest.fn().mockResolvedValue([]),
}, } as any,
}; };
mockUserSocketService = { mockUserSocketService = {
@@ -80,6 +88,15 @@ describe('StateGateway', () => {
getFriendsSockets: jest.fn().mockResolvedValue([]), getFriendsSockets: jest.fn().mockResolvedValue([]),
}; };
mockRedisClient = {
publish: jest.fn().mockResolvedValue(1),
};
mockRedisSubscriber = {
subscribe: jest.fn().mockResolvedValue(undefined),
on: jest.fn(),
};
const module: TestingModule = await Test.createTestingModule({ const module: TestingModule = await Test.createTestingModule({
providers: [ providers: [
StateGateway, StateGateway,
@@ -90,6 +107,8 @@ describe('StateGateway', () => {
}, },
{ provide: PrismaService, useValue: mockPrismaService }, { provide: PrismaService, useValue: mockPrismaService },
{ provide: UserSocketService, useValue: mockUserSocketService }, { provide: UserSocketService, useValue: mockUserSocketService },
{ provide: 'REDIS_CLIENT', useValue: mockRedisClient },
{ provide: 'REDIS_SUBSCRIBER_CLIENT', useValue: mockRedisSubscriber },
], ],
}).compile(); }).compile();
@@ -101,6 +120,9 @@ describe('StateGateway', () => {
.spyOn(gateway['logger'], 'debug') .spyOn(gateway['logger'], 'debug')
.mockImplementation(); .mockImplementation();
mockLoggerWarn = jest.spyOn(gateway['logger'], 'warn').mockImplementation(); mockLoggerWarn = jest.spyOn(gateway['logger'], 'warn').mockImplementation();
mockLoggerError = jest
.spyOn(gateway['logger'], 'error')
.mockImplementation();
}); });
afterEach(() => { afterEach(() => {
@@ -114,20 +136,27 @@ describe('StateGateway', () => {
describe('afterInit', () => { describe('afterInit', () => {
it('should log initialization message', () => { it('should log initialization message', () => {
gateway.afterInit(); gateway.afterInit();
expect(mockLoggerLog).toHaveBeenCalledWith('Initialized'); expect(mockLoggerLog).toHaveBeenCalledWith('Initialized');
}); });
it('should subscribe to redis channel', () => {
expect(mockRedisSubscriber.subscribe).toHaveBeenCalledWith(
'active-doll-update',
expect.any(Function),
);
});
}); });
describe('handleConnection', () => { describe('handleConnection', () => {
it('should log client connection and sync user when authenticated', async () => { it('should verify token and set basic user data (but NOT sync DB)', async () => {
const mockClient: MockSocket = { const mockClient: MockSocket = {
id: 'client1', id: 'client1',
data: { user: { keycloakSub: 'test-sub' } }, data: {},
handshake: { handshake: {
auth: { token: 'mock-token' }, auth: { token: 'mock-token' },
headers: {}, headers: {},
}, },
disconnect: jest.fn(),
}; };
await gateway.handleConnection( await gateway.handleConnection(
@@ -140,20 +169,21 @@ describe('StateGateway', () => {
expect(mockJwtVerificationService.verifyToken).toHaveBeenCalledWith( expect(mockJwtVerificationService.verifyToken).toHaveBeenCalledWith(
'mock-token', 'mock-token',
); );
expect(mockAuthService.syncUserFromToken).toHaveBeenCalledWith(
// Should NOT call these anymore in handleConnection
expect(mockAuthService.syncUserFromToken).not.toHaveBeenCalled();
expect(mockUserSocketService.setSocket).not.toHaveBeenCalled();
// Should set data on client
expect(mockClient.data.user).toEqual(
expect.objectContaining({ expect.objectContaining({
keycloakSub: 'test-sub', keycloakSub: 'test-sub',
}), }),
); );
expect(mockUserSocketService.setSocket).toHaveBeenCalledWith( expect(mockClient.data.activeDollId).toBeNull();
'user-id',
'client1',
);
expect(mockLoggerLog).toHaveBeenCalledWith( expect(mockLoggerLog).toHaveBeenCalledWith(
`Client id: ${mockClient.id} connected (user: test-sub)`, expect.stringContaining('WebSocket authenticated (Pending Init)'),
);
expect(mockLoggerDebug).toHaveBeenCalledWith(
'Number of connected clients: 5',
); );
}); });
@@ -183,6 +213,83 @@ describe('StateGateway', () => {
}); });
}); });
describe('handleClientInitialize', () => {
it('should sync user, fetch state, and emit initialized event', async () => {
const mockClient: MockSocket = {
id: 'client1',
data: {
user: { keycloakSub: 'test-sub' },
friends: new Set(),
},
emit: jest.fn(),
disconnect: jest.fn(),
};
// Mock Prisma responses
(mockPrismaService.user!.findUnique as jest.Mock).mockResolvedValue({
activeDollId: 'doll-123',
});
(mockPrismaService.friendship!.findMany as jest.Mock).mockResolvedValue([
{ friendId: 'friend-1' },
{ friendId: 'friend-2' },
]);
await gateway.handleClientInitialize(
mockClient as unknown as AuthenticatedSocket,
);
// 1. Sync User
expect(mockAuthService.syncUserFromToken).toHaveBeenCalledWith(
mockClient.data.user,
);
// 2. Set Socket
expect(mockUserSocketService.setSocket).toHaveBeenCalledWith(
'user-id',
'client1',
);
// 3. Fetch State (DB)
expect(mockPrismaService.user!.findUnique).toHaveBeenCalledWith({
where: { id: 'user-id' },
select: { activeDollId: true },
});
expect(mockPrismaService.friendship!.findMany).toHaveBeenCalledWith({
where: { userId: 'user-id' },
select: { friendId: true },
});
// 4. Update Client Data
expect(mockClient.data.userId).toBe('user-id');
expect(mockClient.data.activeDollId).toBe('doll-123');
expect(mockClient.data.friends).toContain('friend-1');
expect(mockClient.data.friends).toContain('friend-2');
// 5. Emit Initialized
expect(mockClient.emit).toHaveBeenCalledWith('initialized', {
userId: 'user-id',
activeDollId: 'doll-123',
});
});
it('should disconnect if no user data present on socket', async () => {
const mockClient: MockSocket = {
id: 'client1',
data: {}, // Missing user data
disconnect: jest.fn(),
};
await gateway.handleClientInitialize(
mockClient as unknown as AuthenticatedSocket,
);
expect(mockLoggerError).toHaveBeenCalledWith(
expect.stringContaining('Unauthorized: No user data found'),
);
expect(mockClient.disconnect).toHaveBeenCalled();
});
});
describe('handleDisconnect', () => { describe('handleDisconnect', () => {
it('should log client disconnection', async () => { it('should log client disconnection', async () => {
const mockClient: MockSocket = { const mockClient: MockSocket = {
@@ -250,6 +357,7 @@ describe('StateGateway', () => {
data: { data: {
user: { keycloakSub: 'test-sub' }, user: { keycloakSub: 'test-sub' },
userId: 'user-1', userId: 'user-1',
activeDollId: 'doll-1', // User must have active doll
friends: new Set(['friend-1']), friends: new Set(['friend-1']),
}, },
}; };
@@ -261,6 +369,9 @@ describe('StateGateway', () => {
const data: CursorPositionDto = { x: 100, y: 200 }; const data: CursorPositionDto = { x: 100, y: 200 };
// Force time to pass for throttle check if needed, or rely on first call passing
// The implementation uses lastBroadcastMap, initialized to empty, so first call should pass if now > 0
await gateway.handleCursorReportPosition( await gateway.handleCursorReportPosition(
mockClient as unknown as AuthenticatedSocket, mockClient as unknown as AuthenticatedSocket,
data, data,
@@ -276,21 +387,17 @@ describe('StateGateway', () => {
}); });
}); });
it('should not emit when no friends are online', async () => { it('should NOT emit if user has no active doll', async () => {
const mockClient: MockSocket = { const mockClient: MockSocket = {
id: 'client1', id: 'client1',
data: { data: {
user: { keycloakSub: 'test-sub' }, user: { keycloakSub: 'test-sub' },
userId: 'user-1', userId: 'user-1',
activeDollId: null, // No doll
friends: new Set(['friend-1']), friends: new Set(['friend-1']),
}, },
}; };
// Mock getFriendsSockets to return empty array
(mockUserSocketService.getFriendsSockets as jest.Mock).mockResolvedValue(
[],
);
const data: CursorPositionDto = { x: 100, y: 200 }; const data: CursorPositionDto = { x: 100, y: 200 };
await gateway.handleCursorReportPosition( await gateway.handleCursorReportPosition(
@@ -298,11 +405,10 @@ describe('StateGateway', () => {
data, data,
); );
// Verify that no message was emitted
expect(mockServer.to).not.toHaveBeenCalled(); expect(mockServer.to).not.toHaveBeenCalled();
}); });
it('should log warning when userId is missing', async () => { it('should return early when userId is missing (not initialized)', async () => {
const mockClient: MockSocket = { const mockClient: MockSocket = {
id: 'client1', id: 'client1',
data: { data: {
@@ -319,12 +425,9 @@ describe('StateGateway', () => {
data, data,
); );
// Verify that a warning was logged
expect(mockLoggerWarn).toHaveBeenCalledWith(
`Could not find user ID for client ${mockClient.id}`,
);
// Verify that no message was emitted // Verify that no message was emitted
expect(mockServer.to).not.toHaveBeenCalled(); expect(mockServer.to).not.toHaveBeenCalled();
// No explicit warning log expected in new implementation for just return
}); });
it('should throw exception when client is not authenticated', async () => { it('should throw exception when client is not authenticated', async () => {

View File

@@ -1,4 +1,4 @@
import { Logger } from '@nestjs/common'; import { Logger, Inject } from '@nestjs/common';
import { import {
OnGatewayConnection, OnGatewayConnection,
OnGatewayDisconnect, OnGatewayDisconnect,
@@ -9,8 +9,12 @@ import {
WsException, WsException,
} from '@nestjs/websockets'; } from '@nestjs/websockets';
import { OnEvent } from '@nestjs/event-emitter'; import { OnEvent } from '@nestjs/event-emitter';
import Redis from 'ioredis';
import type { Server } from 'socket.io'; import type { Server } from 'socket.io';
import {
REDIS_CLIENT,
REDIS_SUBSCRIBER_CLIENT,
} from '../../database/redis.module';
import type { AuthenticatedSocket } from '../../types/socket'; import type { AuthenticatedSocket } from '../../types/socket';
import { AuthService } from '../../auth/auth.service'; import { AuthService } from '../../auth/auth.service';
import { JwtVerificationService } from '../../auth/services/jwt-verification.service'; import { JwtVerificationService } from '../../auth/services/jwt-verification.service';
@@ -19,6 +23,7 @@ import { PrismaService } from '../../database/prisma.service';
import { UserSocketService } from './user-socket.service'; import { UserSocketService } from './user-socket.service';
import { FriendEvents } from '../../friends/events/friend.events'; import { FriendEvents } from '../../friends/events/friend.events';
import type { import type {
FriendRequestReceivedEvent, FriendRequestReceivedEvent,
FriendRequestAcceptedEvent, FriendRequestAcceptedEvent,
@@ -37,6 +42,8 @@ import { UserEvents } from '../../users/events/user.events';
import type { UserActiveDollChangedEvent } from '../../users/events/user.events'; import type { UserActiveDollChangedEvent } from '../../users/events/user.events';
const WS_EVENT = { const WS_EVENT = {
CLIENT_INITIALIZE: 'client-initialize',
INITIALIZED: 'initialized',
CURSOR_REPORT_POSITION: 'cursor-report-position', CURSOR_REPORT_POSITION: 'cursor-report-position',
FRIEND_REQUEST_RECEIVED: 'friend-request-received', FRIEND_REQUEST_RECEIVED: 'friend-request-received',
FRIEND_REQUEST_ACCEPTED: 'friend-request-accepted', FRIEND_REQUEST_ACCEPTED: 'friend-request-accepted',
@@ -50,6 +57,10 @@ const WS_EVENT = {
FRIEND_ACTIVE_DOLL_CHANGED: 'friend-active-doll-changed', FRIEND_ACTIVE_DOLL_CHANGED: 'friend-active-doll-changed',
} as const; } as const;
const REDIS_CHANNEL = {
ACTIVE_DOLL_UPDATE: 'active-doll-update',
} as const;
@WebSocketGateway({ @WebSocketGateway({
cors: { cors: {
origin: true, origin: true,
@@ -69,12 +80,78 @@ export class StateGateway
private readonly jwtVerificationService: JwtVerificationService, private readonly jwtVerificationService: JwtVerificationService,
private readonly prisma: PrismaService, private readonly prisma: PrismaService,
private readonly userSocketService: UserSocketService, private readonly userSocketService: UserSocketService,
) {} @Inject(REDIS_CLIENT) private readonly redisClient: Redis | null,
@Inject(REDIS_SUBSCRIBER_CLIENT)
private readonly redisSubscriber: Redis | null,
) {
// Setup Redis subscription for cross-instance communication
if (this.redisSubscriber) {
this.redisSubscriber
.subscribe(REDIS_CHANNEL.ACTIVE_DOLL_UPDATE, (err) => {
if (err) {
this.logger.error(
`Failed to subscribe to ${REDIS_CHANNEL.ACTIVE_DOLL_UPDATE}`,
err,
);
} else {
this.logger.log(
`Subscribed to ${REDIS_CHANNEL.ACTIVE_DOLL_UPDATE} channel`,
);
}
})
.catch((err) => {
this.logger.error(
`Error subscribing to ${REDIS_CHANNEL.ACTIVE_DOLL_UPDATE}`,
err,
);
});
this.redisSubscriber.on('message', (channel, message) => {
if (channel === REDIS_CHANNEL.ACTIVE_DOLL_UPDATE) {
// eslint-disable-next-line @typescript-eslint/no-floating-promises
this.handleActiveDollUpdateMessage(message);
}
});
}
}
afterInit() { afterInit() {
this.logger.log('Initialized'); this.logger.log('Initialized');
} }
private async handleActiveDollUpdateMessage(message: string) {
try {
const data = JSON.parse(message) as {
userId: string;
dollId: string | null;
};
const { userId, dollId } = data;
// Check if the user is connected to THIS instance
// Note: We need a local way to check if we hold the socket connection.
// io.sockets.sockets is a Map of all connected sockets on this server instance.
// We first get the socket ID from the shared store (UserSocketService)
// to see which socket ID belongs to the user.
const socketId = await this.userSocketService.getSocket(userId);
if (socketId) {
// Now check if we actually have this socket locally
const localSocket = this.io.sockets.sockets.get(socketId);
if (localSocket) {
// We own this connection! Update the local state.
const authSocket = localSocket as AuthenticatedSocket;
authSocket.data.activeDollId = dollId;
this.logger.debug(
`Updated activeDollId locally for user ${userId} to ${dollId}`,
);
}
}
} catch (error) {
this.logger.error('Error handling redis message', error);
}
}
async handleConnection(client: AuthenticatedSocket) { async handleConnection(client: AuthenticatedSocket) {
try { try {
this.logger.debug( this.logger.debug(
@@ -106,25 +183,12 @@ export class StateGateway
picture: payload.picture, picture: payload.picture,
}; };
this.logger.log(`WebSocket authenticated: ${payload.sub}`); // Initialize defaults
client.data.activeDollId = null;
client.data.friends = new Set();
// userId is not set yet, it will be set in handleClientInitialize
const user = await this.authService.syncUserFromToken(client.data.user); this.logger.log(`WebSocket authenticated (Pending Init): ${payload.sub}`);
await this.userSocketService.setSocket(user.id, client.id);
client.data.userId = user.id;
// Sync active doll state to socket
const userWithDoll = await this.prisma.user.findUnique({
where: { id: user.id },
select: { activeDollId: true },
});
client.data.activeDollId = userWithDoll?.activeDollId || null;
// Initialize friends cache using Prisma directly
const friends = await this.prisma.friendship.findMany({
where: { userId: user.id },
select: { friendId: true },
});
client.data.friends = new Set(friends.map((f) => f.friendId));
const { sockets } = this.io.sockets; const { sockets } = this.io.sockets;
this.logger.log( this.logger.log(
@@ -139,6 +203,49 @@ export class StateGateway
} }
} }
@SubscribeMessage(WS_EVENT.CLIENT_INITIALIZE)
async handleClientInitialize(client: AuthenticatedSocket) {
try {
const userTokenData = client.data.user;
if (!userTokenData) {
throw new WsException('Unauthorized: No user data found');
}
// 1. Sync user from token (DB Write/Read)
const user = await this.authService.syncUserFromToken(userTokenData);
// 2. Register socket mapping (Redis Write)
await this.userSocketService.setSocket(user.id, client.id);
client.data.userId = user.id;
// 3. Fetch initial state (DB Read)
const [userWithDoll, friends] = await Promise.all([
this.prisma.user.findUnique({
where: { id: user.id },
select: { activeDollId: true },
}),
this.prisma.friendship.findMany({
where: { userId: user.id },
select: { friendId: true },
}),
]);
client.data.activeDollId = userWithDoll?.activeDollId || null;
client.data.friends = new Set(friends.map((f) => f.friendId));
this.logger.log(`Client initialized: ${user.id} (${client.id})`);
// 4. Notify client
client.emit(WS_EVENT.INITIALIZED, {
userId: user.id,
activeDollId: client.data.activeDollId,
});
} catch (error) {
this.logger.error(`Initialization error: ${error}`);
client.disconnect();
}
}
async handleDisconnect(client: AuthenticatedSocket) { async handleDisconnect(client: AuthenticatedSocket) {
const user = client.data.user; const user = client.data.user;
@@ -167,9 +274,7 @@ export class StateGateway
} }
} }
} }
// Note: We can't iterate over Redis keys easily to find socketId match without userId // If userId is undefined, client never initialized, so no cleanup needed
// The previous fallback loop over map entries is not efficient with Redis.
// We rely on client.data.userId being set correctly during connection.
} }
this.logger.log( this.logger.log(
@@ -194,13 +299,13 @@ export class StateGateway
const currentUserId = client.data.userId; const currentUserId = client.data.userId;
// Do not broadcast cursor position if user has no active doll if (!currentUserId) {
if (!client.data.activeDollId) { // User has not initialized yet
return; return;
} }
if (!currentUserId) { // Do not broadcast cursor position if user has no active doll
this.logger.warn(`Could not find user ID for client ${client.id}`); if (!client.data.activeDollId) {
return; return;
} }
@@ -408,14 +513,23 @@ export class StateGateway
async handleActiveDollChanged(payload: UserActiveDollChangedEvent) { async handleActiveDollChanged(payload: UserActiveDollChangedEvent) {
const { userId, dollId, doll } = payload; const { userId, dollId, doll } = payload;
// 1. Update the user's socket data to reflect the change // 1. Publish update to all instances via Redis so they can update local socket state
const socketId = await this.userSocketService.getSocket(userId); if (this.redisClient) {
if (socketId) { await this.redisClient.publish(
const userSocket = this.io.sockets.sockets.get( REDIS_CHANNEL.ACTIVE_DOLL_UPDATE,
socketId, JSON.stringify({ userId, dollId }),
) as AuthenticatedSocket; );
if (userSocket) { } else {
userSocket.data.activeDollId = dollId; // Fallback for single instance (no redis) - update locally directly
// This mimics what handleActiveDollUpdateMessage does
const socketId = await this.userSocketService.getSocket(userId);
if (socketId) {
const userSocket = this.io.sockets.sockets.get(
socketId,
) as AuthenticatedSocket;
if (userSocket) {
userSocket.data.activeDollId = dollId;
}
} }
} }