Compare commits

...

9 Commits

12 changed files with 353 additions and 95 deletions

View File

@@ -19,6 +19,11 @@ JWT_ISSUER=friendolls
JWT_AUDIENCE=friendolls-api JWT_AUDIENCE=friendolls-api
JWT_EXPIRES_IN_SECONDS=3600 JWT_EXPIRES_IN_SECONDS=3600
# Auth cleanup
AUTH_CLEANUP_ENABLED=true
AUTH_CLEANUP_INTERVAL_MS=900000
AUTH_SESSION_REVOKED_RETENTION_DAYS=7
# Google OAuth # Google OAuth
GOOGLE_CLIENT_ID="replace-with-google-client-id" GOOGLE_CLIENT_ID="replace-with-google-client-id"
GOOGLE_CLIENT_SECRET="replace-with-google-client-secret" GOOGLE_CLIENT_SECRET="replace-with-google-client-secret"

View File

@@ -1,6 +1,6 @@
{ {
"name": "friendolls-server", "name": "friendolls-server",
"version": "0.0.1", "version": "0.1.0",
"description": "", "description": "",
"author": "", "author": "",
"private": true, "private": true,
@@ -49,6 +49,7 @@
"jsonwebtoken": "^9.0.2", "jsonwebtoken": "^9.0.2",
"passport": "^0.7.0", "passport": "^0.7.0",
"passport-discord": "^0.1.4", "passport-discord": "^0.1.4",
"helmet": "^8.1.0",
"passport-google-oauth20": "^2.0.0", "passport-google-oauth20": "^2.0.0",
"passport-jwt": "^4.0.1", "passport-jwt": "^4.0.1",
"pg": "^8.16.3", "pg": "^8.16.3",

9
pnpm-lock.yaml generated
View File

@@ -62,6 +62,9 @@ importers:
dotenv: dotenv:
specifier: ^17.2.3 specifier: ^17.2.3
version: 17.2.3 version: 17.2.3
helmet:
specifier: ^8.1.0
version: 8.1.0
ioredis: ioredis:
specifier: ^5.8.2 specifier: ^5.8.2
version: 5.8.2 version: 5.8.2
@@ -2298,6 +2301,10 @@ packages:
resolution: {integrity: sha512-0hJU9SCPvmMzIBdZFqNPXWa6dqh7WdH0cII9y+CyS8rG3nL48Bclra9HmKhVVUHyPWNH5Y7xDwAB7bfgSjkUMQ==} resolution: {integrity: sha512-0hJU9SCPvmMzIBdZFqNPXWa6dqh7WdH0cII9y+CyS8rG3nL48Bclra9HmKhVVUHyPWNH5Y7xDwAB7bfgSjkUMQ==}
engines: {node: '>= 0.4'} engines: {node: '>= 0.4'}
helmet@8.1.0:
resolution: {integrity: sha512-jOiHyAZsmnr8LqoPGmCjYAaiuWwjAPLgY8ZX2XrmHawt99/u1y6RgrZMTeoPfpUbV96HOalYgz1qzkRbw54Pmg==}
engines: {node: '>=18.0.0'}
hono@4.7.10: hono@4.7.10:
resolution: {integrity: sha512-QkACju9MiN59CKSY5JsGZCYmPZkA6sIW6OFCUp7qDjZu6S6KHtJHhAc9Uy9mV9F8PJ1/HQ3ybZF2yjCa/73fvQ==} resolution: {integrity: sha512-QkACju9MiN59CKSY5JsGZCYmPZkA6sIW6OFCUp7qDjZu6S6KHtJHhAc9Uy9mV9F8PJ1/HQ3ybZF2yjCa/73fvQ==}
engines: {node: '>=16.9.0'} engines: {node: '>=16.9.0'}
@@ -6227,6 +6234,8 @@ snapshots:
dependencies: dependencies:
function-bind: 1.1.2 function-bind: 1.1.2
helmet@8.1.0: {}
hono@4.7.10: {} hono@4.7.10: {}
html-escaper@2.0.2: {} html-escaper@2.0.2: {}

View File

@@ -49,10 +49,10 @@ model User {
activeDollId String? @map("active_doll_id") activeDollId String? @map("active_doll_id")
activeDoll Doll? @relation("ActiveDoll", fields: [activeDollId], references: [id]) activeDoll Doll? @relation("ActiveDoll", fields: [activeDollId], references: [id])
sentFriendRequests FriendRequest[] @relation("SentFriendRequests") sentFriendRequests FriendRequest[] @relation("SentFriendRequests")
receivedFriendRequests FriendRequest[] @relation("ReceivedFriendRequests") receivedFriendRequests FriendRequest[] @relation("ReceivedFriendRequests")
userFriendships Friendship[] @relation("UserFriendships") userFriendships Friendship[] @relation("UserFriendships")
friendFriendships Friendship[] @relation("FriendFriendships") friendFriendships Friendship[] @relation("FriendFriendships")
dolls Doll[] dolls Doll[]
authIdentities AuthIdentity[] authIdentities AuthIdentity[]
authSessions AuthSession[] authSessions AuthSession[]
@@ -62,17 +62,17 @@ model User {
} }
model AuthIdentity { model AuthIdentity {
id String @id @default(uuid()) id String @id @default(uuid())
provider AuthProvider provider AuthProvider
providerSubject String @map("provider_subject") providerSubject String @map("provider_subject")
providerEmail String? @map("provider_email") providerEmail String? @map("provider_email")
providerName String? @map("provider_name") providerName String? @map("provider_name")
providerUsername String? @map("provider_username") providerUsername String? @map("provider_username")
providerPicture String? @map("provider_picture") providerPicture String? @map("provider_picture")
emailVerified Boolean @default(false) @map("email_verified") emailVerified Boolean @default(false) @map("email_verified")
createdAt DateTime @default(now()) @map("created_at") createdAt DateTime @default(now()) @map("created_at")
updatedAt DateTime @updatedAt @map("updated_at") updatedAt DateTime @updatedAt @map("updated_at")
userId String @map("user_id") userId String @map("user_id")
user User @relation(fields: [userId], references: [id], onDelete: Cascade) user User @relation(fields: [userId], references: [id], onDelete: Cascade)
@@ -82,14 +82,14 @@ model AuthIdentity {
} }
model AuthSession { model AuthSession {
id String @id @default(uuid()) id String @id @default(uuid())
provider AuthProvider? provider AuthProvider?
refreshTokenHash String @unique @map("refresh_token_hash") refreshTokenHash String @unique @map("refresh_token_hash")
expiresAt DateTime @map("expires_at") expiresAt DateTime @map("expires_at")
revokedAt DateTime? @map("revoked_at") revokedAt DateTime? @map("revoked_at")
createdAt DateTime @default(now()) @map("created_at") createdAt DateTime @default(now()) @map("created_at")
updatedAt DateTime @updatedAt @map("updated_at") updatedAt DateTime @updatedAt @map("updated_at")
userId String @map("user_id") userId String @map("user_id")
user User @relation(fields: [userId], references: [id], onDelete: Cascade) user User @relation(fields: [userId], references: [id], onDelete: Cascade)
@@ -98,13 +98,13 @@ model AuthSession {
} }
model AuthExchangeCode { model AuthExchangeCode {
id String @id @default(uuid()) id String @id @default(uuid())
provider AuthProvider provider AuthProvider
codeHash String @unique @map("code_hash") codeHash String @unique @map("code_hash")
expiresAt DateTime @map("expires_at") expiresAt DateTime @map("expires_at")
consumedAt DateTime? @map("consumed_at") consumedAt DateTime? @map("consumed_at")
createdAt DateTime @default(now()) @map("created_at") createdAt DateTime @default(now()) @map("created_at")
userId String @map("user_id") userId String @map("user_id")
user User @relation(fields: [userId], references: [id], onDelete: Cascade) user User @relation(fields: [userId], references: [id], onDelete: Cascade)

View File

@@ -10,6 +10,7 @@ import { UsersModule } from '../users/users.module';
import { AuthController } from './auth.controller'; import { AuthController } from './auth.controller';
import { GoogleAuthGuard } from './guards/google-auth.guard'; import { GoogleAuthGuard } from './guards/google-auth.guard';
import { DiscordAuthGuard } from './guards/discord-auth.guard'; import { DiscordAuthGuard } from './guards/discord-auth.guard';
import { AuthCleanupService } from './services/auth-cleanup.service';
@Module({ @Module({
imports: [ imports: [
@@ -26,6 +27,7 @@ import { DiscordAuthGuard } from './guards/discord-auth.guard';
DiscordAuthGuard, DiscordAuthGuard,
AuthService, AuthService,
JwtVerificationService, JwtVerificationService,
AuthCleanupService,
], ],
exports: [AuthService, PassportModule, JwtVerificationService], exports: [AuthService, PassportModule, JwtVerificationService],
}) })

View File

@@ -0,0 +1,159 @@
import {
Injectable,
Inject,
Logger,
OnModuleDestroy,
OnModuleInit,
} from '@nestjs/common';
import { ConfigService } from '@nestjs/config';
import { PrismaService } from '../../database/prisma.service';
import Redis from 'ioredis';
import {
parseBoolean,
parsePositiveInteger,
} from '../../common/config/env.utils';
import { REDIS_CLIENT } from '../../database/redis.module';
const MIN_CLEANUP_INTERVAL_MS = 60_000;
const DEFAULT_CLEANUP_INTERVAL_MS = 15 * 60_000;
const DEFAULT_REVOKED_RETENTION_DAYS = 7;
const CLEANUP_LOCK_KEY = 'lock:auth:cleanup';
const CLEANUP_LOCK_TTL_MS = 55_000;
@Injectable()
export class AuthCleanupService implements OnModuleInit, OnModuleDestroy {
private readonly logger = new Logger(AuthCleanupService.name);
private cleanupTimer: NodeJS.Timeout | null = null;
private isCleanupRunning = false;
constructor(
private readonly prisma: PrismaService,
private readonly configService: ConfigService,
@Inject(REDIS_CLIENT) private readonly redisClient: Redis | null,
) {}
onModuleInit(): void {
const enabled = parseBoolean(
this.configService.get<string>('AUTH_CLEANUP_ENABLED'),
true,
);
if (!enabled) {
this.logger.log('Auth cleanup task disabled');
return;
}
const configuredInterval = parsePositiveInteger(
this.configService.get<string>('AUTH_CLEANUP_INTERVAL_MS'),
DEFAULT_CLEANUP_INTERVAL_MS,
);
const cleanupIntervalMs = Math.max(
configuredInterval,
MIN_CLEANUP_INTERVAL_MS,
);
this.cleanupTimer = setInterval(() => {
void this.cleanupExpiredAuthData();
}, cleanupIntervalMs);
this.cleanupTimer.unref();
void this.cleanupExpiredAuthData();
this.logger.log(`Auth cleanup task scheduled every ${cleanupIntervalMs}ms`);
}
onModuleDestroy(): void {
if (!this.cleanupTimer) {
return;
}
clearInterval(this.cleanupTimer);
this.cleanupTimer = null;
}
private async cleanupExpiredAuthData(): Promise<void> {
if (this.isCleanupRunning) {
this.logger.warn(
'Skipping auth cleanup run because previous run is still in progress',
);
return;
}
this.isCleanupRunning = true;
const lockToken = `${process.pid}-${Date.now()}-${Math.random().toString(36).slice(2)}`;
let lockAcquired = false;
try {
if (this.redisClient) {
try {
const lockResult = await this.redisClient.set(
CLEANUP_LOCK_KEY,
lockToken,
'PX',
CLEANUP_LOCK_TTL_MS,
'NX',
);
if (lockResult !== 'OK') {
return;
}
lockAcquired = true;
} catch (error) {
this.logger.warn(
'Failed to acquire auth cleanup lock; running cleanup without distributed lock',
error as Error,
);
}
}
const now = new Date();
const revokedRetentionDays = parsePositiveInteger(
this.configService.get<string>('AUTH_SESSION_REVOKED_RETENTION_DAYS'),
DEFAULT_REVOKED_RETENTION_DAYS,
);
const revokedCutoff = new Date(
now.getTime() - revokedRetentionDays * 24 * 60 * 60 * 1000,
);
const [codes, sessions] = await Promise.all([
this.prisma.authExchangeCode.deleteMany({
where: {
OR: [{ expiresAt: { lt: now } }, { consumedAt: { not: null } }],
},
}),
this.prisma.authSession.deleteMany({
where: {
OR: [
{ expiresAt: { lt: now } },
{ revokedAt: { lt: revokedCutoff } },
],
},
}),
]);
const totalDeleted = codes.count + sessions.count;
if (totalDeleted > 0) {
this.logger.log(
`Auth cleanup removed ${totalDeleted} records (${codes.count} exchange codes, ${sessions.count} sessions)`,
);
}
} catch (error) {
this.logger.error('Auth cleanup task failed', error as Error);
} finally {
if (lockAcquired && this.redisClient) {
try {
const currentLockValue = await this.redisClient.get(CLEANUP_LOCK_KEY);
if (currentLockValue === lockToken) {
await this.redisClient.del(CLEANUP_LOCK_KEY);
}
} catch (error) {
this.logger.warn(
'Failed to release auth cleanup lock',
error as Error,
);
}
}
this.isCleanupRunning = false;
}
}
}

View File

@@ -40,6 +40,7 @@ export class PrismaService
implements OnModuleInit, OnModuleDestroy implements OnModuleInit, OnModuleDestroy
{ {
private readonly logger = new Logger(PrismaService.name); private readonly logger = new Logger(PrismaService.name);
private readonly pool: Pool;
constructor(private configService: ConfigService) { constructor(private configService: ConfigService) {
const databaseUrl = configService.get<string>('DATABASE_URL'); const databaseUrl = configService.get<string>('DATABASE_URL');
@@ -62,6 +63,8 @@ export class PrismaService
], ],
}); });
this.pool = pool;
// Log database queries in development mode // Log database queries in development mode
if (process.env.NODE_ENV === 'development') { if (process.env.NODE_ENV === 'development') {
this.$on('query' as never, (e: QueryEvent) => { this.$on('query' as never, (e: QueryEvent) => {
@@ -101,6 +104,7 @@ export class PrismaService
async onModuleDestroy() { async onModuleDestroy() {
try { try {
await this.$disconnect(); await this.$disconnect();
await this.pool.end();
this.logger.log('Successfully disconnected from database'); this.logger.log('Successfully disconnected from database');
} catch (error) { } catch (error) {
this.logger.error('Error disconnecting from database', error); this.logger.error('Error disconnecting from database', error);

View File

@@ -2,6 +2,7 @@ import { NestFactory } from '@nestjs/core';
import { ValidationPipe, Logger } from '@nestjs/common'; import { ValidationPipe, Logger } from '@nestjs/common';
import { ConfigService } from '@nestjs/config'; import { ConfigService } from '@nestjs/config';
import { DocumentBuilder, SwaggerModule } from '@nestjs/swagger'; import { DocumentBuilder, SwaggerModule } from '@nestjs/swagger';
import helmet from 'helmet';
import { AppModule } from './app.module'; import { AppModule } from './app.module';
import { AllExceptionsFilter } from './common/filters/all-exceptions.filter'; import { AllExceptionsFilter } from './common/filters/all-exceptions.filter';
import { RedisIoAdapter } from './ws/redis-io.adapter'; import { RedisIoAdapter } from './ws/redis-io.adapter';
@@ -10,12 +11,28 @@ async function bootstrap() {
const logger = new Logger('Bootstrap'); const logger = new Logger('Bootstrap');
const app = await NestFactory.create(AppModule); const app = await NestFactory.create(AppModule);
const configService = app.get(ConfigService); const configService = app.get(ConfigService);
const nodeEnv = configService.get<string>('NODE_ENV') || 'development';
const isProduction = nodeEnv === 'production';
app.enableShutdownHooks();
app.use(
helmet({
contentSecurityPolicy: false,
crossOriginEmbedderPolicy: false,
}),
);
// Configure Redis Adapter for horizontal scaling (if enabled) // Configure Redis Adapter for horizontal scaling (if enabled)
const redisIoAdapter = new RedisIoAdapter(app, configService); const redisIoAdapter = new RedisIoAdapter(app, configService);
await redisIoAdapter.connectToRedis(); await redisIoAdapter.connectToRedis();
app.useWebSocketAdapter(redisIoAdapter); app.useWebSocketAdapter(redisIoAdapter);
app.enableCors({
origin: true,
credentials: true,
});
// Enable global exception filter for consistent error responses // Enable global exception filter for consistent error responses
app.useGlobalFilters(new AllExceptionsFilter()); app.useGlobalFilters(new AllExceptionsFilter());
@@ -29,43 +46,54 @@ async function bootstrap() {
// Automatically transform payloads to DTO instances // Automatically transform payloads to DTO instances
transform: true, transform: true,
// Provide detailed error messages // Provide detailed error messages
disableErrorMessages: false, disableErrorMessages: isProduction,
}), }),
); );
// Configure Swagger documentation if (!isProduction) {
const config = new DocumentBuilder() const config = new DocumentBuilder()
.setTitle('Friendolls API') .setTitle('Friendolls API')
.setDescription( .setDescription(
'API for managing users in Friendolls application.\n\n' + 'API for managing users in Friendolls application.\n\n' +
'Authentication is handled via Passport.js social sign-in for desktop clients.\n' + 'Authentication is handled via Passport.js social sign-in for desktop clients.\n' +
'Desktop clients exchange one-time SSO codes for Friendolls JWT tokens.\n\n' + 'Desktop clients exchange one-time SSO codes for Friendolls JWT tokens.\n\n' +
'Include the JWT token in the Authorization header as: `Bearer <token>`', 'Include the JWT token in the Authorization header as: `Bearer <token>`',
) )
.setVersion('1.0') .setVersion('1.0')
.addBearerAuth( .addBearerAuth(
{ {
type: 'http', type: 'http',
scheme: 'bearer', scheme: 'bearer',
bearerFormat: 'JWT', bearerFormat: 'JWT',
name: 'Authorization', name: 'Authorization',
description: 'Enter Friendolls JWT access token', description: 'Enter Friendolls JWT access token',
in: 'header', in: 'header',
}, },
'bearer', 'bearer',
) )
.addTag('users', 'User profile management endpoints') .addTag('users', 'User profile management endpoints')
.build(); .build();
const document = SwaggerModule.createDocument(app, config); const document = SwaggerModule.createDocument(app, config);
SwaggerModule.setup('api', app, document); SwaggerModule.setup('api', app, document);
}
const host = process.env.HOST ?? 'localhost'; const host = process.env.HOST ?? 'localhost';
const port = process.env.PORT ?? 3000; const port = process.env.PORT ?? 3000;
await app.listen(port); await app.listen(port);
const httpServer = app.getHttpServer() as {
once?: (event: 'close', listener: () => void) => void;
} | null;
httpServer?.once?.('close', () => {
void redisIoAdapter.close();
});
logger.log(`Application is running on: http://${host}:${port}`); logger.log(`Application is running on: http://${host}:${port}`);
logger.log(`Swagger documentation available at: http://${host}:${port}/api`); if (!isProduction) {
logger.log(
`Swagger documentation available at: http://${host}:${port}/api`,
);
}
} }
void bootstrap(); void bootstrap();

View File

@@ -22,6 +22,7 @@ describe('UsersController', () => {
const mockAuthUser: AuthenticatedUser = { const mockAuthUser: AuthenticatedUser = {
userId: 'uuid-123', userId: 'uuid-123',
email: 'test@example.com', email: 'test@example.com',
tokenType: 'access',
roles: ['user'], roles: ['user'],
}; };

View File

@@ -9,6 +9,8 @@ import { WsNotificationService } from '../ws-notification.service';
import { WS_EVENT } from '../ws-events'; import { WS_EVENT } from '../ws-events';
import { Validator } from '../utils/validation'; import { Validator } from '../utils/validation';
const SENDER_NAME_CACHE_TTL_MS = 10 * 60 * 1000;
export class InteractionHandler { export class InteractionHandler {
private readonly logger = new Logger(InteractionHandler.name); private readonly logger = new Logger(InteractionHandler.name);
@@ -18,6 +20,32 @@ export class InteractionHandler {
private readonly wsNotificationService: WsNotificationService, private readonly wsNotificationService: WsNotificationService,
) {} ) {}
private async resolveSenderName(
client: AuthenticatedSocket,
userId: string,
): Promise<string> {
const cachedName = client.data.senderName;
const cachedAt = client.data.senderNameCachedAt;
const cacheIsFresh =
cachedName &&
typeof cachedAt === 'number' &&
Date.now() - cachedAt < SENDER_NAME_CACHE_TTL_MS;
if (cacheIsFresh) {
return cachedName;
}
const sender = await this.prisma.user.findUnique({
where: { id: userId },
select: { name: true, username: true },
});
const senderName = sender?.name || sender?.username || 'Unknown';
client.data.senderName = senderName;
client.data.senderNameCachedAt = Date.now();
return senderName;
}
async handleSendInteraction( async handleSendInteraction(
client: AuthenticatedSocket, client: AuthenticatedSocket,
data: SendInteractionDto, data: SendInteractionDto,
@@ -61,11 +89,7 @@ export class InteractionHandler {
} }
// 3. Construct payload // 3. Construct payload
const sender = await this.prisma.user.findUnique({ const senderName = await this.resolveSenderName(client, currentUserId);
where: { id: currentUserId },
select: { name: true, username: true },
});
const senderName = sender?.name || sender?.username || 'Unknown';
const payload: InteractionPayloadDto = { const payload: InteractionPayloadDto = {
senderUserId: currentUserId, senderUserId: currentUserId,

View File

@@ -3,7 +3,6 @@ import { Test, TestingModule } from '@nestjs/testing';
import { StateGateway } from './state.gateway'; import { StateGateway } from './state.gateway';
import { AuthenticatedSocket } from '../../types/socket'; import { AuthenticatedSocket } from '../../types/socket';
import { JwtVerificationService } from '../../auth/services/jwt-verification.service'; import { JwtVerificationService } from '../../auth/services/jwt-verification.service';
import { UsersService } from '../../users/users.service';
import { PrismaService } from '../../database/prisma.service'; import { PrismaService } from '../../database/prisma.service';
import { UserSocketService } from './user-socket.service'; import { UserSocketService } from './user-socket.service';
import { WsNotificationService } from './ws-notification.service'; import { WsNotificationService } from './ws-notification.service';
@@ -23,6 +22,8 @@ type MockSocket = {
userId?: string; userId?: string;
activeDollId?: string | null; activeDollId?: string | null;
friends?: Set<string>; friends?: Set<string>;
senderName?: string;
senderNameCachedAt?: number;
}; };
handshake?: any; handshake?: any;
disconnect?: jest.Mock; disconnect?: jest.Mock;
@@ -39,7 +40,6 @@ describe('StateGateway', () => {
sockets: { sockets: { size: number; get: jest.Mock } }; sockets: { sockets: { size: number; get: jest.Mock } };
to: jest.Mock; to: jest.Mock;
}; };
let mockUsersService: Partial<UsersService>;
let mockJwtVerificationService: Partial<JwtVerificationService>; let mockJwtVerificationService: Partial<JwtVerificationService>;
let mockPrismaService: Partial<PrismaService>; let mockPrismaService: Partial<PrismaService>;
let mockUserSocketService: Partial<UserSocketService>; let mockUserSocketService: Partial<UserSocketService>;
@@ -67,12 +67,6 @@ describe('StateGateway', () => {
}), }),
}; };
mockUsersService = {
findOne: jest.fn().mockResolvedValue({
id: 'user-id',
}),
};
mockJwtVerificationService = { mockJwtVerificationService = {
extractToken: jest.fn((handshake) => handshake.auth?.token), extractToken: jest.fn((handshake) => handshake.auth?.token),
verifyToken: jest.fn().mockReturnValue({ verifyToken: jest.fn().mockReturnValue({
@@ -83,7 +77,12 @@ describe('StateGateway', () => {
mockPrismaService = { mockPrismaService = {
user: { user: {
findUnique: jest.fn().mockResolvedValue({ activeDollId: 'doll-123' }), findUnique: jest.fn().mockResolvedValue({
id: 'user-id',
name: 'Test User',
username: 'test-user',
activeDollId: 'doll-123',
}),
} as any, } as any,
friendship: { friendship: {
findMany: jest.fn().mockResolvedValue([]), findMany: jest.fn().mockResolvedValue([]),
@@ -119,7 +118,6 @@ describe('StateGateway', () => {
const module: TestingModule = await Test.createTestingModule({ const module: TestingModule = await Test.createTestingModule({
providers: [ providers: [
StateGateway, StateGateway,
{ provide: UsersService, useValue: mockUsersService },
{ {
provide: JwtVerificationService, provide: JwtVerificationService,
useValue: mockJwtVerificationService, useValue: mockJwtVerificationService,
@@ -190,7 +188,6 @@ describe('StateGateway', () => {
); );
// Should NOT call these anymore in handleConnection // Should NOT call these anymore in handleConnection
expect(mockUsersService.findOne).not.toHaveBeenCalled();
expect(mockUserSocketService.setSocket).not.toHaveBeenCalled(); expect(mockUserSocketService.setSocket).not.toHaveBeenCalled();
// Should set data on client // Should set data on client
@@ -244,6 +241,9 @@ describe('StateGateway', () => {
// Mock Prisma responses // Mock Prisma responses
(mockPrismaService.user!.findUnique as jest.Mock).mockResolvedValue({ (mockPrismaService.user!.findUnique as jest.Mock).mockResolvedValue({
id: 'user-id',
name: 'Test User',
username: 'test-user',
activeDollId: 'doll-123', activeDollId: 'doll-123',
}); });
(mockPrismaService.friendship!.findMany as jest.Mock).mockResolvedValue([ (mockPrismaService.friendship!.findMany as jest.Mock).mockResolvedValue([
@@ -255,32 +255,29 @@ describe('StateGateway', () => {
mockClient as unknown as AuthenticatedSocket, mockClient as unknown as AuthenticatedSocket,
); );
// 1. Load User // 1. Set Socket
expect(mockUsersService.findOne).toHaveBeenCalledWith('test-sub');
// 2. Set Socket
expect(mockUserSocketService.setSocket).toHaveBeenCalledWith( expect(mockUserSocketService.setSocket).toHaveBeenCalledWith(
'user-id', 'user-id',
'client1', 'client1',
); );
// 3. Fetch State (DB) // 2. Fetch State (DB)
expect(mockPrismaService.user!.findUnique).toHaveBeenCalledWith({ expect(mockPrismaService.user!.findUnique).toHaveBeenCalledWith({
where: { id: 'user-id' }, where: { id: 'test-sub' },
select: { activeDollId: true }, select: { id: true, name: true, username: true, activeDollId: true },
}); });
expect(mockPrismaService.friendship!.findMany).toHaveBeenCalledWith({ expect(mockPrismaService.friendship!.findMany).toHaveBeenCalledWith({
where: { userId: 'user-id' }, where: { userId: 'test-sub' },
select: { friendId: true }, select: { friendId: true },
}); });
// 4. Update Client Data // 3. Update Client Data
expect(mockClient.data.userId).toBe('user-id'); expect(mockClient.data.userId).toBe('user-id');
expect(mockClient.data.activeDollId).toBe('doll-123'); expect(mockClient.data.activeDollId).toBe('doll-123');
expect(mockClient.data.friends).toContain('friend-1'); expect(mockClient.data.friends).toContain('friend-1');
expect(mockClient.data.friends).toContain('friend-2'); expect(mockClient.data.friends).toContain('friend-2');
// 5. Emit Initialized // 4. Emit Initialized
expect(mockClient.emit).toHaveBeenCalledWith('initialized', { expect(mockClient.emit).toHaveBeenCalledWith('initialized', {
userId: 'user-id', userId: 'user-id',
activeDollId: 'doll-123', activeDollId: 'doll-123',

View File

@@ -48,13 +48,27 @@ export class WsNotificationService {
action: 'add' | 'delete', action: 'add' | 'delete',
) { ) {
if (this.redisClient) { if (this.redisClient) {
await this.redisClient.publish( try {
REDIS_CHANNEL.FRIEND_CACHE_UPDATE, await this.redisClient.publish(
JSON.stringify({ userId, friendId, action }), REDIS_CHANNEL.FRIEND_CACHE_UPDATE,
); JSON.stringify({ userId, friendId, action }),
} else { );
// Fallback: update locally return;
} catch (error) {
this.logger.warn(
'Redis publish failed for friend cache update; applying local cache update only',
error as Error,
);
}
}
try {
await this.updateFriendsCacheLocal(userId, friendId, action); await this.updateFriendsCacheLocal(userId, friendId, action);
} catch (error) {
this.logger.error(
'Failed to apply local friend cache update',
error as Error,
);
} }
} }
@@ -89,13 +103,27 @@ export class WsNotificationService {
async publishActiveDollUpdate(userId: string, dollId: string | null) { async publishActiveDollUpdate(userId: string, dollId: string | null) {
if (this.redisClient) { if (this.redisClient) {
await this.redisClient.publish( try {
REDIS_CHANNEL.ACTIVE_DOLL_UPDATE, await this.redisClient.publish(
JSON.stringify({ userId, dollId }), REDIS_CHANNEL.ACTIVE_DOLL_UPDATE,
); JSON.stringify({ userId, dollId }),
} else { );
// Fallback: update locally return;
} catch (error) {
this.logger.warn(
'Redis publish failed for active doll update; applying local cache update only',
error as Error,
);
}
}
try {
await this.updateActiveDollCache(userId, dollId); await this.updateActiveDollCache(userId, dollId);
} catch (error) {
this.logger.error(
'Failed to apply local active doll cache update',
error as Error,
);
} }
} }
} }