diff --git a/.dockerignore b/.dockerignore new file mode 100644 index 0000000..9724dc2 --- /dev/null +++ b/.dockerignore @@ -0,0 +1,13 @@ +.git +.github +.vscode +.idea +.DS_Store +node_modules +dist +coverage +*.log +.env +.env.* +test +README.md diff --git a/.env.example b/.env.example index 4e2e5bf..c0129c6 100644 --- a/.env.example +++ b/.env.example @@ -9,6 +9,9 @@ DATABASE_URL="postgresql://postgres:postgres@localhost:5432/friendolls_dev?schem # Redis REDIS_HOST=localhost REDIS_PORT=6379 +REDIS_REQUIRED=false +REDIS_CONNECT_TIMEOUT_MS=5000 +REDIS_STARTUP_RETRIES=10 # JWT Configuration JWT_SECRET=replace-with-strong-random-secret @@ -16,6 +19,11 @@ JWT_ISSUER=friendolls JWT_AUDIENCE=friendolls-api JWT_EXPIRES_IN_SECONDS=3600 +# Auth cleanup +AUTH_CLEANUP_ENABLED=true +AUTH_CLEANUP_INTERVAL_MS=900000 +AUTH_SESSION_REVOKED_RETENTION_DAYS=7 + # Google OAuth GOOGLE_CLIENT_ID="replace-with-google-client-id" GOOGLE_CLIENT_SECRET="replace-with-google-client-secret" diff --git a/Dockerfile b/Dockerfile index d0868a5..b2f7c96 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,16 +1,28 @@ -FROM node:20-alpine AS builder +FROM node:24-alpine AS base +ENV PNPM_HOME="/pnpm" +ENV PATH="$PNPM_HOME:$PATH" +RUN corepack enable + +FROM base AS deps WORKDIR /app COPY package.json pnpm-lock.yaml ./ -RUN npm i -g pnpm && pnpm install --frozen-lockfile +RUN pnpm install --frozen-lockfile + +FROM deps AS builder +WORKDIR /app COPY . . RUN DATABASE_URL="postgresql://dummy:dummy@localhost:5432/dummy" pnpm prisma:generate RUN pnpm build -FROM node:20-alpine +FROM base AS runner WORKDIR /app -COPY --from=builder /app/dist ./dist -COPY --from=builder /app/node_modules ./node_modules +ENV NODE_ENV=production +RUN addgroup -S nodejs && adduser -S nestjs -G nodejs +COPY package.json pnpm-lock.yaml ./ COPY --from=builder /app/prisma ./prisma COPY --from=builder /app/prisma.config.ts ./prisma.config.ts -COPY --from=builder /app/package.json ./package.json +RUN pnpm install --frozen-lockfile --prod +COPY --from=builder /app/dist ./dist +COPY --from=builder /app/node_modules/.prisma ./node_modules/.prisma +USER nestjs CMD ["node", "dist/src/main.js"] diff --git a/package.json b/package.json index 85066a0..4300d18 100644 --- a/package.json +++ b/package.json @@ -47,6 +47,7 @@ "dotenv": "^17.2.3", "ioredis": "^5.8.2", "jsonwebtoken": "^9.0.2", + "helmet": "^8.1.0", "passport": "^0.7.0", "passport-discord": "^0.1.4", "passport-google-oauth20": "^2.0.0", diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index 7dafdcd..e9731e4 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -62,6 +62,9 @@ importers: dotenv: specifier: ^17.2.3 version: 17.2.3 + helmet: + specifier: ^8.1.0 + version: 8.1.0 ioredis: specifier: ^5.8.2 version: 5.8.2 @@ -2298,6 +2301,10 @@ packages: resolution: {integrity: sha512-0hJU9SCPvmMzIBdZFqNPXWa6dqh7WdH0cII9y+CyS8rG3nL48Bclra9HmKhVVUHyPWNH5Y7xDwAB7bfgSjkUMQ==} engines: {node: '>= 0.4'} + helmet@8.1.0: + resolution: {integrity: sha512-jOiHyAZsmnr8LqoPGmCjYAaiuWwjAPLgY8ZX2XrmHawt99/u1y6RgrZMTeoPfpUbV96HOalYgz1qzkRbw54Pmg==} + engines: {node: '>=18.0.0'} + hono@4.7.10: resolution: {integrity: sha512-QkACju9MiN59CKSY5JsGZCYmPZkA6sIW6OFCUp7qDjZu6S6KHtJHhAc9Uy9mV9F8PJ1/HQ3ybZF2yjCa/73fvQ==} engines: {node: '>=16.9.0'} @@ -6227,6 +6234,8 @@ snapshots: dependencies: function-bind: 1.1.2 + helmet@8.1.0: {} + hono@4.7.10: {} html-escaper@2.0.2: {} diff --git a/src/app.module.ts b/src/app.module.ts index 4ddacf7..0ac0a54 100644 --- a/src/app.module.ts +++ b/src/app.module.ts @@ -12,13 +12,24 @@ import { RedisModule } from './database/redis.module'; import { WsModule } from './ws/ws.module'; import { FriendsModule } from './friends/friends.module'; import { DollsModule } from './dolls/dolls.module'; +import { parseRedisRequired } from './common/config/env.utils'; /** * Validates required environment variables. * Throws an error if any required variables are missing or invalid. * Returns the validated config. */ -function validateEnvironment(config: Record): Record { +function getOptionalEnvString( + config: Record, + key: string, +): string | undefined { + const value = config[key]; + return typeof value === 'string' ? value : undefined; +} + +function validateEnvironment( + config: Record, +): Record { const requiredVars = ['JWT_SECRET', 'DATABASE_URL']; const missingVars = requiredVars.filter((varName) => !config[varName]); @@ -34,6 +45,40 @@ function validateEnvironment(config: Record): Record { throw new Error('PORT must be a valid number'); } + if (config.NODE_ENV === 'production') { + if ( + typeof config.JWT_SECRET !== 'string' || + config.JWT_SECRET.length < 32 + ) { + throw new Error( + 'JWT_SECRET must be at least 32 characters in production', + ); + } + } + + const redisRequired = parseRedisRequired({ + nodeEnv: getOptionalEnvString(config, 'NODE_ENV'), + redisRequired: getOptionalEnvString(config, 'REDIS_REQUIRED'), + }); + + if (redisRequired && !config.REDIS_HOST) { + throw new Error( + 'REDIS_REQUIRED is enabled but REDIS_HOST is not configured', + ); + } + + const redisConnectTimeout = getOptionalEnvString( + config, + 'REDIS_CONNECT_TIMEOUT_MS', + ); + if ( + redisConnectTimeout !== undefined && + (!Number.isFinite(Number(redisConnectTimeout)) || + Number(redisConnectTimeout) <= 0) + ) { + throw new Error('REDIS_CONNECT_TIMEOUT_MS must be a positive number'); + } + validateOptionalProvider(config, 'GOOGLE'); validateOptionalProvider(config, 'DISCORD'); @@ -41,7 +86,7 @@ function validateEnvironment(config: Record): Record { } function validateOptionalProvider( - config: Record, + config: Record, provider: 'GOOGLE' | 'DISCORD', ): void { const vars = [ diff --git a/src/auth/auth.module.ts b/src/auth/auth.module.ts index 4bf688a..3723de0 100644 --- a/src/auth/auth.module.ts +++ b/src/auth/auth.module.ts @@ -10,6 +10,7 @@ import { UsersModule } from '../users/users.module'; import { AuthController } from './auth.controller'; import { GoogleAuthGuard } from './guards/google-auth.guard'; import { DiscordAuthGuard } from './guards/discord-auth.guard'; +import { AuthCleanupService } from './services/auth-cleanup.service'; @Module({ imports: [ @@ -26,6 +27,7 @@ import { DiscordAuthGuard } from './guards/discord-auth.guard'; DiscordAuthGuard, AuthService, JwtVerificationService, + AuthCleanupService, ], exports: [AuthService, PassportModule, JwtVerificationService], }) diff --git a/src/auth/services/auth-cleanup.service.ts b/src/auth/services/auth-cleanup.service.ts new file mode 100644 index 0000000..303fe2c --- /dev/null +++ b/src/auth/services/auth-cleanup.service.ts @@ -0,0 +1,159 @@ +import { + Injectable, + Inject, + Logger, + OnModuleDestroy, + OnModuleInit, +} from '@nestjs/common'; +import { ConfigService } from '@nestjs/config'; +import { PrismaService } from '../../database/prisma.service'; +import Redis from 'ioredis'; +import { + parseBoolean, + parsePositiveInteger, +} from '../../common/config/env.utils'; +import { REDIS_CLIENT } from '../../database/redis.module'; + +const MIN_CLEANUP_INTERVAL_MS = 60_000; +const DEFAULT_CLEANUP_INTERVAL_MS = 15 * 60_000; +const DEFAULT_REVOKED_RETENTION_DAYS = 7; +const CLEANUP_LOCK_KEY = 'lock:auth:cleanup'; +const CLEANUP_LOCK_TTL_MS = 55_000; + +@Injectable() +export class AuthCleanupService implements OnModuleInit, OnModuleDestroy { + private readonly logger = new Logger(AuthCleanupService.name); + private cleanupTimer: NodeJS.Timeout | null = null; + private isCleanupRunning = false; + + constructor( + private readonly prisma: PrismaService, + private readonly configService: ConfigService, + @Inject(REDIS_CLIENT) private readonly redisClient: Redis | null, + ) {} + + onModuleInit(): void { + const enabled = parseBoolean( + this.configService.get('AUTH_CLEANUP_ENABLED'), + true, + ); + + if (!enabled) { + this.logger.log('Auth cleanup task disabled'); + return; + } + + const configuredInterval = parsePositiveInteger( + this.configService.get('AUTH_CLEANUP_INTERVAL_MS'), + DEFAULT_CLEANUP_INTERVAL_MS, + ); + const cleanupIntervalMs = Math.max( + configuredInterval, + MIN_CLEANUP_INTERVAL_MS, + ); + + this.cleanupTimer = setInterval(() => { + void this.cleanupExpiredAuthData(); + }, cleanupIntervalMs); + this.cleanupTimer.unref(); + + void this.cleanupExpiredAuthData(); + this.logger.log(`Auth cleanup task scheduled every ${cleanupIntervalMs}ms`); + } + + onModuleDestroy(): void { + if (!this.cleanupTimer) { + return; + } + + clearInterval(this.cleanupTimer); + this.cleanupTimer = null; + } + + private async cleanupExpiredAuthData(): Promise { + if (this.isCleanupRunning) { + this.logger.warn( + 'Skipping auth cleanup run because previous run is still in progress', + ); + return; + } + + this.isCleanupRunning = true; + const lockToken = `${process.pid}-${Date.now()}-${Math.random().toString(36).slice(2)}`; + let lockAcquired = false; + + try { + if (this.redisClient) { + try { + const lockResult = await this.redisClient.set( + CLEANUP_LOCK_KEY, + lockToken, + 'PX', + CLEANUP_LOCK_TTL_MS, + 'NX', + ); + if (lockResult !== 'OK') { + return; + } + lockAcquired = true; + } catch (error) { + this.logger.warn( + 'Failed to acquire auth cleanup lock; running cleanup without distributed lock', + error as Error, + ); + } + } + + const now = new Date(); + const revokedRetentionDays = parsePositiveInteger( + this.configService.get('AUTH_SESSION_REVOKED_RETENTION_DAYS'), + DEFAULT_REVOKED_RETENTION_DAYS, + ); + const revokedCutoff = new Date( + now.getTime() - revokedRetentionDays * 24 * 60 * 60 * 1000, + ); + + const [codes, sessions] = await Promise.all([ + this.prisma.authExchangeCode.deleteMany({ + where: { + OR: [{ expiresAt: { lt: now } }, { consumedAt: { not: null } }], + }, + }), + this.prisma.authSession.deleteMany({ + where: { + OR: [ + { expiresAt: { lt: now } }, + { revokedAt: { lt: revokedCutoff } }, + ], + }, + }), + ]); + + const totalDeleted = codes.count + sessions.count; + + if (totalDeleted > 0) { + this.logger.log( + `Auth cleanup removed ${totalDeleted} records (${codes.count} exchange codes, ${sessions.count} sessions)`, + ); + } + } catch (error) { + this.logger.error('Auth cleanup task failed', error as Error); + } finally { + if (lockAcquired && this.redisClient) { + try { + const currentLockValue = await this.redisClient.get(CLEANUP_LOCK_KEY); + if (currentLockValue === lockToken) { + await this.redisClient.del(CLEANUP_LOCK_KEY); + } + } catch (error) { + this.logger.warn( + 'Failed to release auth cleanup lock', + error as Error, + ); + } + } + + this.isCleanupRunning = false; + } + } +} diff --git a/src/common/config/env.utils.ts b/src/common/config/env.utils.ts new file mode 100644 index 0000000..ac8eba2 --- /dev/null +++ b/src/common/config/env.utils.ts @@ -0,0 +1,66 @@ +export function parseBoolean( + value: string | undefined, + fallback: boolean, +): boolean { + if (value === undefined) { + return fallback; + } + + const normalized = value.trim().toLowerCase(); + if (['true', '1', 'yes', 'y', 'on'].includes(normalized)) { + return true; + } + + if (['false', '0', 'no', 'n', 'off'].includes(normalized)) { + return false; + } + + return fallback; +} + +export function parsePositiveInteger( + value: string | undefined, + fallback: number, +): number { + if (!value) { + return fallback; + } + + const parsed = Number(value); + if (!Number.isFinite(parsed) || parsed <= 0) { + return fallback; + } + + return Math.floor(parsed); +} + +export function parseCsvList(value: string | undefined): string[] { + if (!value) { + return []; + } + + return value + .split(',') + .map((item) => item.trim()) + .filter((item) => item.length > 0); +} + +export function isLikelyHttpOrigin(origin: string): boolean { + try { + const parsed = new URL(origin); + return parsed.protocol === 'http:' || parsed.protocol === 'https:'; + } catch { + return false; + } +} + +export function parseRedisRequired(config: { + nodeEnv?: string; + redisRequired?: string; +}): boolean { + if (config.redisRequired === undefined) { + return config.nodeEnv === 'production'; + } + + return parseBoolean(config.redisRequired, false); +} diff --git a/src/database/prisma.service.ts b/src/database/prisma.service.ts index 4aee782..051603b 100644 --- a/src/database/prisma.service.ts +++ b/src/database/prisma.service.ts @@ -40,6 +40,7 @@ export class PrismaService implements OnModuleInit, OnModuleDestroy { private readonly logger = new Logger(PrismaService.name); + private readonly pool: Pool; constructor(private configService: ConfigService) { const databaseUrl = configService.get('DATABASE_URL'); @@ -62,6 +63,8 @@ export class PrismaService ], }); + this.pool = pool; + // Log database queries in development mode if (process.env.NODE_ENV === 'development') { this.$on('query' as never, (e: QueryEvent) => { @@ -101,6 +104,7 @@ export class PrismaService async onModuleDestroy() { try { await this.$disconnect(); + await this.pool.end(); this.logger.log('Successfully disconnected from database'); } catch (error) { this.logger.error('Error disconnecting from database', error); diff --git a/src/database/redis.module.ts b/src/database/redis.module.ts index 6a4d488..091032b 100644 --- a/src/database/redis.module.ts +++ b/src/database/redis.module.ts @@ -1,35 +1,108 @@ -import { Module, Global, Logger } from '@nestjs/common'; +import { + Inject, + Injectable, + Logger, + Module, + Global, + OnModuleDestroy, +} from '@nestjs/common'; import { ConfigService } from '@nestjs/config'; import Redis from 'ioredis'; +import { + parsePositiveInteger, + parseRedisRequired, +} from '../common/config/env.utils'; export const REDIS_CLIENT = 'REDIS_CLIENT'; export const REDIS_SUBSCRIBER_CLIENT = 'REDIS_SUBSCRIBER_CLIENT'; +const DEFAULT_REDIS_STARTUP_RETRIES = 10; + +@Injectable() +class RedisLifecycleService implements OnModuleDestroy { + private readonly logger = new Logger(RedisLifecycleService.name); + + constructor( + @Inject(REDIS_CLIENT) private readonly redisClient: Redis | null, + @Inject(REDIS_SUBSCRIBER_CLIENT) + private readonly redisSubscriber: Redis | null, + ) {} + + async onModuleDestroy(): Promise { + const clients = [this.redisClient, this.redisSubscriber].filter( + (client): client is Redis => client !== null, + ); + + if (clients.length === 0) { + return; + } + + await Promise.all( + clients.map(async (client) => { + try { + await client.quit(); + } catch (error) { + this.logger.warn( + 'Redis quit failed, forcing disconnect', + error as Error, + ); + client.disconnect(); + } + }), + ); + } +} + @Global() @Module({ providers: [ { provide: REDIS_CLIENT, - useFactory: (configService: ConfigService) => { + useFactory: async (configService: ConfigService) => { const logger = new Logger('RedisModule'); const host = configService.get('REDIS_HOST'); - const port = configService.get('REDIS_PORT'); + const port = parsePositiveInteger( + configService.get('REDIS_PORT'), + 6379, + ); const password = configService.get('REDIS_PASSWORD'); + const connectTimeout = parsePositiveInteger( + configService.get('REDIS_CONNECT_TIMEOUT_MS'), + 5000, + ); + const redisRequired = parseRedisRequired({ + nodeEnv: configService.get('NODE_ENV'), + redisRequired: configService.get('REDIS_REQUIRED'), + }); + const startupRetries = parsePositiveInteger( + configService.get('REDIS_STARTUP_RETRIES'), + DEFAULT_REDIS_STARTUP_RETRIES, + ); - // Fallback or "disabled" mode if no host is provided if (!host) { - logger.warn( - 'REDIS_HOST not defined. Redis features will be disabled or fall back to local memory.', - ); + if (redisRequired) { + throw new Error( + 'REDIS_REQUIRED is enabled but REDIS_HOST is not configured', + ); + } + + logger.warn('REDIS_HOST not defined. Redis features are disabled.'); return null; } const client = new Redis({ host, - port: port || 6379, - password: password, - // Retry strategy: keep trying to reconnect + port, + password, + lazyConnect: true, + connectTimeout, + maxRetriesPerRequest: 1, + enableOfflineQueue: false, retryStrategy(times) { + if (times > startupRetries) { + return null; + } + const delay = Math.min(times * 50, 2000); return delay; }, @@ -40,20 +113,51 @@ export const REDIS_SUBSCRIBER_CLIENT = 'REDIS_SUBSCRIBER_CLIENT'; }); client.on('connect', () => { - logger.log(`Connected to Redis at ${host}:${port || 6379}`); + logger.log(`Connected to Redis at ${host}:${port}`); }); + try { + await client.connect(); + await client.ping(); + } catch { + client.disconnect(); + + if (redisRequired) { + throw new Error( + `Failed to connect to required Redis at ${host}:${port}`, + ); + } + + logger.warn('Redis connection failed; Redis features are disabled.'); + return null; + } + return client; }, inject: [ConfigService], }, { provide: REDIS_SUBSCRIBER_CLIENT, - useFactory: (configService: ConfigService) => { + useFactory: async (configService: ConfigService) => { const logger = new Logger('RedisSubscriberModule'); const host = configService.get('REDIS_HOST'); - const port = configService.get('REDIS_PORT'); + const port = parsePositiveInteger( + configService.get('REDIS_PORT'), + 6379, + ); const password = configService.get('REDIS_PASSWORD'); + const connectTimeout = parsePositiveInteger( + configService.get('REDIS_CONNECT_TIMEOUT_MS'), + 5000, + ); + const redisRequired = parseRedisRequired({ + nodeEnv: configService.get('NODE_ENV'), + redisRequired: configService.get('REDIS_REQUIRED'), + }); + const startupRetries = parsePositiveInteger( + configService.get('REDIS_STARTUP_RETRIES'), + DEFAULT_REDIS_STARTUP_RETRIES, + ); if (!host) { return null; @@ -61,9 +165,17 @@ export const REDIS_SUBSCRIBER_CLIENT = 'REDIS_SUBSCRIBER_CLIENT'; const client = new Redis({ host, - port: port || 6379, - password: password, + port, + password, + lazyConnect: true, + connectTimeout, + maxRetriesPerRequest: 1, + enableOfflineQueue: false, retryStrategy(times) { + if (times > startupRetries) { + return null; + } + const delay = Math.min(times * 50, 2000); return delay; }, @@ -82,10 +194,29 @@ export const REDIS_SUBSCRIBER_CLIENT = 'REDIS_SUBSCRIBER_CLIENT'; logger.error('Redis subscriber connection error', err); }); + try { + await client.connect(); + await client.ping(); + } catch { + client.disconnect(); + + if (redisRequired) { + throw new Error( + `Failed to connect to required Redis subscriber at ${host}:${port}`, + ); + } + + logger.warn( + 'Redis subscriber connection failed; cross-instance subscriptions are disabled.', + ); + return null; + } + return client; }, inject: [ConfigService], }, + RedisLifecycleService, ], exports: [REDIS_CLIENT, REDIS_SUBSCRIBER_CLIENT], }) diff --git a/src/main.ts b/src/main.ts index 3494653..e4d48ec 100644 --- a/src/main.ts +++ b/src/main.ts @@ -2,6 +2,7 @@ import { NestFactory } from '@nestjs/core'; import { ValidationPipe, Logger } from '@nestjs/common'; import { ConfigService } from '@nestjs/config'; import { DocumentBuilder, SwaggerModule } from '@nestjs/swagger'; +import helmet from 'helmet'; import { AppModule } from './app.module'; import { AllExceptionsFilter } from './common/filters/all-exceptions.filter'; import { RedisIoAdapter } from './ws/redis-io.adapter'; @@ -10,12 +11,28 @@ async function bootstrap() { const logger = new Logger('Bootstrap'); const app = await NestFactory.create(AppModule); const configService = app.get(ConfigService); + const nodeEnv = configService.get('NODE_ENV') || 'development'; + const isProduction = nodeEnv === 'production'; + + app.enableShutdownHooks(); + + app.use( + helmet({ + contentSecurityPolicy: false, + crossOriginEmbedderPolicy: false, + }), + ); // Configure Redis Adapter for horizontal scaling (if enabled) const redisIoAdapter = new RedisIoAdapter(app, configService); await redisIoAdapter.connectToRedis(); app.useWebSocketAdapter(redisIoAdapter); + app.enableCors({ + origin: true, + credentials: true, + }); + // Enable global exception filter for consistent error responses app.useGlobalFilters(new AllExceptionsFilter()); @@ -29,43 +46,54 @@ async function bootstrap() { // Automatically transform payloads to DTO instances transform: true, // Provide detailed error messages - disableErrorMessages: false, + disableErrorMessages: isProduction, }), ); - // Configure Swagger documentation - const config = new DocumentBuilder() - .setTitle('Friendolls API') - .setDescription( - 'API for managing users in Friendolls application.\n\n' + - 'Authentication is handled via Passport.js social sign-in for desktop clients.\n' + - 'Desktop clients exchange one-time SSO codes for Friendolls JWT tokens.\n\n' + - 'Include the JWT token in the Authorization header as: `Bearer `', - ) - .setVersion('1.0') - .addBearerAuth( - { - type: 'http', - scheme: 'bearer', - bearerFormat: 'JWT', - name: 'Authorization', - description: 'Enter Friendolls JWT access token', - in: 'header', - }, - 'bearer', - ) - .addTag('users', 'User profile management endpoints') - .build(); + if (!isProduction) { + const config = new DocumentBuilder() + .setTitle('Friendolls API') + .setDescription( + 'API for managing users in Friendolls application.\n\n' + + 'Authentication is handled via Passport.js social sign-in for desktop clients.\n' + + 'Desktop clients exchange one-time SSO codes for Friendolls JWT tokens.\n\n' + + 'Include the JWT token in the Authorization header as: `Bearer `', + ) + .setVersion('1.0') + .addBearerAuth( + { + type: 'http', + scheme: 'bearer', + bearerFormat: 'JWT', + name: 'Authorization', + description: 'Enter Friendolls JWT access token', + in: 'header', + }, + 'bearer', + ) + .addTag('users', 'User profile management endpoints') + .build(); - const document = SwaggerModule.createDocument(app, config); - SwaggerModule.setup('api', app, document); + const document = SwaggerModule.createDocument(app, config); + SwaggerModule.setup('api', app, document); + } const host = process.env.HOST ?? 'localhost'; const port = process.env.PORT ?? 3000; await app.listen(port); + const httpServer = app.getHttpServer() as { + once?: (event: 'close', listener: () => void) => void; + } | null; + httpServer?.once?.('close', () => { + void redisIoAdapter.close(); + }); logger.log(`Application is running on: http://${host}:${port}`); - logger.log(`Swagger documentation available at: http://${host}:${port}/api`); + if (!isProduction) { + logger.log( + `Swagger documentation available at: http://${host}:${port}/api`, + ); + } } void bootstrap(); diff --git a/src/types/socket.d.ts b/src/types/socket.d.ts index aa96222..7d1c7c1 100644 --- a/src/types/socket.d.ts +++ b/src/types/socket.d.ts @@ -9,6 +9,8 @@ export type AuthenticatedSocket = BaseSocket< { user?: AuthenticatedUser; userId?: string; + senderName?: string; + senderNameCachedAt?: number; activeDollId?: string | null; friends?: Set; // Set of friend user IDs } diff --git a/src/users/users.controller.spec.ts b/src/users/users.controller.spec.ts index 53280d6..35fbe62 100644 --- a/src/users/users.controller.spec.ts +++ b/src/users/users.controller.spec.ts @@ -22,6 +22,7 @@ describe('UsersController', () => { const mockAuthUser: AuthenticatedUser = { userId: 'uuid-123', email: 'test@example.com', + tokenType: 'access', roles: ['user'], }; diff --git a/src/ws/redis-io.adapter.ts b/src/ws/redis-io.adapter.ts index 206cfc3..f35b924 100644 --- a/src/ws/redis-io.adapter.ts +++ b/src/ws/redis-io.adapter.ts @@ -4,10 +4,18 @@ import { createAdapter } from '@socket.io/redis-adapter'; import Redis from 'ioredis'; import { ConfigService } from '@nestjs/config'; import { INestApplicationContext, Logger } from '@nestjs/common'; +import { + parsePositiveInteger, + parseRedisRequired, +} from '../common/config/env.utils'; + +const DEFAULT_REDIS_STARTUP_RETRIES = 10; export class RedisIoAdapter extends IoAdapter { private adapterConstructor: ReturnType; private readonly logger = new Logger(RedisIoAdapter.name); + private pubClient: Redis | null = null; + private subClient: Redis | null = null; constructor( private app: INestApplicationContext, @@ -18,41 +26,63 @@ export class RedisIoAdapter extends IoAdapter { async connectToRedis(): Promise { const host = this.configService.get('REDIS_HOST'); - const port = this.configService.get('REDIS_PORT'); + const port = parsePositiveInteger( + this.configService.get('REDIS_PORT'), + 6379, + ); const password = this.configService.get('REDIS_PASSWORD'); + const startupRetries = parsePositiveInteger( + this.configService.get('REDIS_STARTUP_RETRIES'), + DEFAULT_REDIS_STARTUP_RETRIES, + ); + const redisRequired = parseRedisRequired({ + nodeEnv: this.configService.get('NODE_ENV'), + redisRequired: this.configService.get('REDIS_REQUIRED'), + }); - // Only set up Redis adapter if host is configured if (!host) { + if (redisRequired) { + throw new Error( + 'REDIS_REQUIRED is enabled but REDIS_HOST is not configured', + ); + } + this.logger.log('Redis adapter disabled (REDIS_HOST not set)'); return; } - this.logger.log(`Connecting Redis adapter to ${host}:${port || 6379}`); + this.logger.log(`Connecting Redis adapter to ${host}:${port}`); try { + const connectTimeout = parsePositiveInteger( + this.configService.get('REDIS_CONNECT_TIMEOUT_MS'), + 5000, + ); const pubClient = new Redis({ host, - port: port || 6379, - password: password, + port, + password, + lazyConnect: true, + connectTimeout, + maxRetriesPerRequest: 1, + enableOfflineQueue: false, retryStrategy(times) { - // Retry connecting but don't crash if Redis is temporarily down during startup + if (times > startupRetries) { + return null; + } + return Math.min(times * 50, 2000); }, }); const subClient = pubClient.duplicate(); - // Wait for connection to ensure it's valid - await new Promise((resolve, reject) => { - pubClient.once('connect', () => { - this.logger.log('Redis Pub client connected'); - resolve(); - }); - pubClient.once('error', (err) => { - this.logger.error('Redis Pub client error', err); - reject(err); - }); - }); + await pubClient.connect(); + await subClient.connect(); + await pubClient.ping(); + await subClient.ping(); + + this.logger.log('Redis Pub/Sub clients connected'); // Handle subsequent errors gracefully pubClient.on('error', (err) => { @@ -73,21 +103,53 @@ export class RedisIoAdapter extends IoAdapter { }); this.adapterConstructor = createAdapter(pubClient, subClient); + this.pubClient = pubClient; + this.subClient = subClient; this.logger.log('Redis adapter initialized successfully'); } catch (error) { + await this.close(); this.logger.error('Failed to initialize Redis adapter', error); - // We don't throw here to allow the app to start without Redis if connection fails, - // though functionality will be degraded if multiple instances are running. + + if (redisRequired) { + throw error; + } } } createIOServer(port: number, options?: ServerOptions): any { + const cors = { + origin: true, + credentials: true, + }; + // eslint-disable-next-line @typescript-eslint/no-unsafe-assignment - const server = super.createIOServer(port, options); + const server = super.createIOServer(port, { + ...(options ?? {}), + cors, + }); if (this.adapterConstructor) { // eslint-disable-next-line @typescript-eslint/no-unsafe-call, @typescript-eslint/no-unsafe-member-access server.adapter(this.adapterConstructor); } return server; } + + async close(): Promise { + const clients = [this.pubClient, this.subClient].filter( + (client): client is Redis => client !== null, + ); + + await Promise.all( + clients.map(async (client) => { + try { + await client.quit(); + } catch { + client.disconnect(); + } + }), + ); + + this.pubClient = null; + this.subClient = null; + } } diff --git a/src/ws/state/connection/handler.ts b/src/ws/state/connection/handler.ts index 450eabf..6dce449 100644 --- a/src/ws/state/connection/handler.ts +++ b/src/ws/state/connection/handler.ts @@ -6,13 +6,11 @@ import { PrismaService } from '../../../database/prisma.service'; import { UserSocketService } from '../user-socket.service'; import { WsNotificationService } from '../ws-notification.service'; import { WS_EVENT } from '../ws-events'; -import { UsersService } from '../../../users/users.service'; export class ConnectionHandler { constructor( private readonly jwtVerificationService: JwtVerificationService, private readonly prisma: PrismaService, - private readonly usersService: UsersService, private readonly userSocketService: UserSocketService, private readonly wsNotificationService: WsNotificationService, private readonly logger: Logger, @@ -43,6 +41,7 @@ export class ConnectionHandler { // Initialize defaults client.data.activeDollId = null; client.data.friends = new Set(); + client.data.senderName = undefined; // userId is not set yet, it will be set in handleClientInitialize this.logger.log(`WebSocket authenticated (Pending Init): ${payload.sub}`); @@ -94,42 +93,42 @@ export class ConnectionHandler { this.logger.log( `WebSocket authenticated via initialize fallback (Pending Init): ${payload.sub}`, ); - - this.logger.log( - `WebSocket authenticated via initialize fallback (Pending Init): ${payload.sub}`, - ); } if (!userTokenData) { throw new WsException('Unauthorized: No user data found'); } - const user = await this.usersService.findOne(userTokenData.userId); - - // 2. Register socket mapping (Redis Write) - await this.userSocketService.setSocket(user.id, client.id); - client.data.userId = user.id; - - // 3. Fetch initial state (DB Read) - const [userWithDoll, friends] = await Promise.all([ + // 2. Fetch initial state (DB Read) + const [userState, friends] = await Promise.all([ this.prisma.user.findUnique({ - where: { id: user.id }, - select: { activeDollId: true }, + where: { id: userTokenData.userId }, + select: { id: true, name: true, username: true, activeDollId: true }, }), this.prisma.friendship.findMany({ - where: { userId: user.id }, + where: { userId: userTokenData.userId }, select: { friendId: true }, }), ]); - client.data.activeDollId = userWithDoll?.activeDollId || null; - client.data.friends = new Set(friends.map((f) => f.friendId)); + if (!userState) { + throw new WsException('Unauthorized: No user data found'); + } - this.logger.log(`Client initialized: ${user.id} (${client.id})`); + // 3. Register socket mapping (Redis Write) + await this.userSocketService.setSocket(userState.id, client.id); + client.data.userId = userState.id; + + client.data.activeDollId = userState.activeDollId || null; + client.data.friends = new Set(friends.map((f) => f.friendId)); + client.data.senderName = userState.name || userState.username; + client.data.senderNameCachedAt = Date.now(); + + this.logger.log(`Client initialized: ${userState.id} (${client.id})`); // 4. Notify client client.emit(WS_EVENT.INITIALIZED, { - userId: user.id, + userId: userState.id, activeDollId: client.data.activeDollId, }); } catch (error) { @@ -157,7 +156,9 @@ export class ConnectionHandler { // Notify friends that this user has disconnected const friends = client.data.friends; if (friends) { - const friendIds = Array.from(friends); + const friendIds = Array.from(friends).filter( + (friendId): friendId is string => typeof friendId === 'string', + ); const friendSockets = await this.userSocketService.getFriendsSockets(friendIds); @@ -179,9 +180,5 @@ export class ConnectionHandler { this.logger.log( `Client id: ${client.id} disconnected (user: ${user?.userId || 'unknown'})`, ); - - this.logger.log( - `Client id: ${client.id} disconnected (user: ${user?.userId || 'unknown'})`, - ); } } diff --git a/src/ws/state/interaction/handler.ts b/src/ws/state/interaction/handler.ts index ea5619e..2ea837c 100644 --- a/src/ws/state/interaction/handler.ts +++ b/src/ws/state/interaction/handler.ts @@ -1,14 +1,16 @@ import { Logger } from '@nestjs/common'; import { WsException } from '@nestjs/websockets'; import type { AuthenticatedSocket } from '../../../types/socket'; +import { PrismaService } from '../../../database/prisma.service'; import { SendInteractionDto } from '../../dto/send-interaction.dto'; import { InteractionPayloadDto } from '../../dto/interaction-payload.dto'; -import { PrismaService } from '../../../database/prisma.service'; import { UserSocketService } from '../user-socket.service'; import { WsNotificationService } from '../ws-notification.service'; import { WS_EVENT } from '../ws-events'; import { Validator } from '../utils/validation'; +const SENDER_NAME_CACHE_TTL_MS = 10 * 60 * 1000; + export class InteractionHandler { private readonly logger = new Logger(InteractionHandler.name); @@ -18,6 +20,32 @@ export class InteractionHandler { private readonly wsNotificationService: WsNotificationService, ) {} + private async resolveSenderName( + client: AuthenticatedSocket, + userId: string, + ): Promise { + const cachedName = client.data.senderName; + const cachedAt = client.data.senderNameCachedAt; + const cacheIsFresh = + cachedName && + typeof cachedAt === 'number' && + Date.now() - cachedAt < SENDER_NAME_CACHE_TTL_MS; + + if (cacheIsFresh) { + return cachedName; + } + + const sender = await this.prisma.user.findUnique({ + where: { id: userId }, + select: { name: true, username: true }, + }); + + const senderName = sender?.name || sender?.username || 'Unknown'; + client.data.senderName = senderName; + client.data.senderNameCachedAt = Date.now(); + return senderName; + } + async handleSendInteraction( client: AuthenticatedSocket, data: SendInteractionDto, @@ -61,11 +89,7 @@ export class InteractionHandler { } // 3. Construct payload - const sender = await this.prisma.user.findUnique({ - where: { id: currentUserId }, - select: { name: true, username: true }, - }); - const senderName = sender?.name || sender?.username || 'Unknown'; + const senderName = await this.resolveSenderName(client, currentUserId); const payload: InteractionPayloadDto = { senderUserId: currentUserId, diff --git a/src/ws/state/state.gateway.spec.ts b/src/ws/state/state.gateway.spec.ts index daf0735..559a2f3 100644 --- a/src/ws/state/state.gateway.spec.ts +++ b/src/ws/state/state.gateway.spec.ts @@ -3,7 +3,6 @@ import { Test, TestingModule } from '@nestjs/testing'; import { StateGateway } from './state.gateway'; import { AuthenticatedSocket } from '../../types/socket'; import { JwtVerificationService } from '../../auth/services/jwt-verification.service'; -import { UsersService } from '../../users/users.service'; import { PrismaService } from '../../database/prisma.service'; import { UserSocketService } from './user-socket.service'; import { WsNotificationService } from './ws-notification.service'; @@ -39,7 +38,6 @@ describe('StateGateway', () => { sockets: { sockets: { size: number; get: jest.Mock } }; to: jest.Mock; }; - let mockUsersService: Partial; let mockJwtVerificationService: Partial; let mockPrismaService: Partial; let mockUserSocketService: Partial; @@ -67,12 +65,6 @@ describe('StateGateway', () => { }), }; - mockUsersService = { - findOne: jest.fn().mockResolvedValue({ - id: 'user-id', - }), - }; - mockJwtVerificationService = { extractToken: jest.fn((handshake) => handshake.auth?.token), verifyToken: jest.fn().mockReturnValue({ @@ -83,7 +75,12 @@ describe('StateGateway', () => { mockPrismaService = { user: { - findUnique: jest.fn().mockResolvedValue({ activeDollId: 'doll-123' }), + findUnique: jest.fn().mockResolvedValue({ + id: 'user-id', + name: 'Test User', + username: 'test-user', + activeDollId: 'doll-123', + }), } as any, friendship: { findMany: jest.fn().mockResolvedValue([]), @@ -119,7 +116,6 @@ describe('StateGateway', () => { const module: TestingModule = await Test.createTestingModule({ providers: [ StateGateway, - { provide: UsersService, useValue: mockUsersService }, { provide: JwtVerificationService, useValue: mockJwtVerificationService, @@ -190,7 +186,6 @@ describe('StateGateway', () => { ); // Should NOT call these anymore in handleConnection - expect(mockUsersService.findOne).not.toHaveBeenCalled(); expect(mockUserSocketService.setSocket).not.toHaveBeenCalled(); // Should set data on client @@ -244,6 +239,9 @@ describe('StateGateway', () => { // Mock Prisma responses (mockPrismaService.user!.findUnique as jest.Mock).mockResolvedValue({ + id: 'user-id', + name: 'Test User', + username: 'test-user', activeDollId: 'doll-123', }); (mockPrismaService.friendship!.findMany as jest.Mock).mockResolvedValue([ @@ -255,32 +253,29 @@ describe('StateGateway', () => { mockClient as unknown as AuthenticatedSocket, ); - // 1. Load User - expect(mockUsersService.findOne).toHaveBeenCalledWith('test-sub'); - - // 2. Set Socket + // 1. Set Socket expect(mockUserSocketService.setSocket).toHaveBeenCalledWith( 'user-id', 'client1', ); - // 3. Fetch State (DB) + // 2. Fetch State (DB) expect(mockPrismaService.user!.findUnique).toHaveBeenCalledWith({ - where: { id: 'user-id' }, - select: { activeDollId: true }, + where: { id: 'test-sub' }, + select: { id: true, name: true, username: true, activeDollId: true }, }); expect(mockPrismaService.friendship!.findMany).toHaveBeenCalledWith({ - where: { userId: 'user-id' }, + where: { userId: 'test-sub' }, select: { friendId: true }, }); - // 4. Update Client Data + // 3. Update Client Data expect(mockClient.data.userId).toBe('user-id'); expect(mockClient.data.activeDollId).toBe('doll-123'); expect(mockClient.data.friends).toContain('friend-1'); expect(mockClient.data.friends).toContain('friend-2'); - // 5. Emit Initialized + // 4. Emit Initialized expect(mockClient.emit).toHaveBeenCalledWith('initialized', { userId: 'user-id', activeDollId: 'doll-123', diff --git a/src/ws/state/state.gateway.ts b/src/ws/state/state.gateway.ts index 102a890..29719f7 100644 --- a/src/ws/state/state.gateway.ts +++ b/src/ws/state/state.gateway.ts @@ -1,4 +1,4 @@ -import { Logger, Inject } from '@nestjs/common'; +import { Logger, Inject, OnModuleDestroy } from '@nestjs/common'; import { OnGatewayConnection, OnGatewayDisconnect, @@ -22,7 +22,6 @@ import { PrismaService } from '../../database/prisma.service'; import { UserSocketService } from './user-socket.service'; import { WsNotificationService } from './ws-notification.service'; import { WS_EVENT, REDIS_CHANNEL } from './ws-events'; -import { UsersService } from '../../users/users.service'; import { ConnectionHandler } from './connection/handler'; import { CursorHandler } from './cursor/handler'; import { StatusHandler } from './status/handler'; @@ -31,14 +30,13 @@ import { RedisHandler } from './utils/redis-handler'; import { Broadcaster } from './utils/broadcasting'; import { Throttler } from './utils/throttling'; -@WebSocketGateway({ - cors: { - origin: true, - credentials: true, - }, -}) +@WebSocketGateway() export class StateGateway - implements OnGatewayInit, OnGatewayConnection, OnGatewayDisconnect + implements + OnGatewayInit, + OnGatewayConnection, + OnGatewayDisconnect, + OnModuleDestroy { private readonly logger = new Logger(StateGateway.name); @@ -55,7 +53,6 @@ export class StateGateway constructor( private readonly jwtVerificationService: JwtVerificationService, private readonly prisma: PrismaService, - private readonly usersService: UsersService, private readonly userSocketService: UserSocketService, private readonly wsNotificationService: WsNotificationService, @Inject(REDIS_CLIENT) private readonly redisClient: Redis | null, @@ -70,7 +67,6 @@ export class StateGateway this.connectionHandler = new ConnectionHandler( this.jwtVerificationService, this.prisma, - this.usersService, this.userSocketService, this.wsNotificationService, this.logger, @@ -136,6 +132,12 @@ export class StateGateway } } + onModuleDestroy() { + if (this.redisSubscriber) { + this.redisSubscriber.removeAllListeners('message'); + } + } + async isUserOnline(userId: string): Promise { return this.userSocketService.isUserOnline(userId); } diff --git a/src/ws/state/ws-notification.service.ts b/src/ws/state/ws-notification.service.ts index 3533de8..34a3ad4 100644 --- a/src/ws/state/ws-notification.service.ts +++ b/src/ws/state/ws-notification.service.ts @@ -48,13 +48,27 @@ export class WsNotificationService { action: 'add' | 'delete', ) { if (this.redisClient) { - await this.redisClient.publish( - REDIS_CHANNEL.FRIEND_CACHE_UPDATE, - JSON.stringify({ userId, friendId, action }), - ); - } else { - // Fallback: update locally + try { + await this.redisClient.publish( + REDIS_CHANNEL.FRIEND_CACHE_UPDATE, + JSON.stringify({ userId, friendId, action }), + ); + return; + } catch (error) { + this.logger.warn( + 'Redis publish failed for friend cache update; applying local cache update only', + error as Error, + ); + } + } + + try { await this.updateFriendsCacheLocal(userId, friendId, action); + } catch (error) { + this.logger.error( + 'Failed to apply local friend cache update', + error as Error, + ); } } @@ -89,13 +103,27 @@ export class WsNotificationService { async publishActiveDollUpdate(userId: string, dollId: string | null) { if (this.redisClient) { - await this.redisClient.publish( - REDIS_CHANNEL.ACTIVE_DOLL_UPDATE, - JSON.stringify({ userId, dollId }), - ); - } else { - // Fallback: update locally + try { + await this.redisClient.publish( + REDIS_CHANNEL.ACTIVE_DOLL_UPDATE, + JSON.stringify({ userId, dollId }), + ); + return; + } catch (error) { + this.logger.warn( + 'Redis publish failed for active doll update; applying local cache update only', + error as Error, + ); + } + } + + try { await this.updateActiveDollCache(userId, dollId); + } catch (error) { + this.logger.error( + 'Failed to apply local active doll cache update', + error as Error, + ); } } }