diff --git a/.env.example b/.env.example index c0129c6..0948865 100644 --- a/.env.example +++ b/.env.example @@ -13,6 +13,12 @@ REDIS_REQUIRED=false REDIS_CONNECT_TIMEOUT_MS=5000 REDIS_STARTUP_RETRIES=10 +# Cache +CACHE_KEY_PREFIX=friendolls +CACHE_DEFAULT_TTL_SECONDS=60 +CACHE_MAX_TTL_SECONDS=86400 +CACHE_METRICS_LOG_INTERVAL_MS=60000 + # JWT Configuration JWT_SECRET=replace-with-strong-random-secret JWT_ISSUER=friendolls @@ -24,6 +30,10 @@ AUTH_CLEANUP_ENABLED=true AUTH_CLEANUP_INTERVAL_MS=900000 AUTH_SESSION_REVOKED_RETENTION_DAYS=7 +# Rate limiting +THROTTLE_TTL=1000 +THROTTLE_LIMIT=5 + # Google OAuth GOOGLE_CLIENT_ID="replace-with-google-client-id" GOOGLE_CLIENT_SECRET="replace-with-google-client-secret" diff --git a/src/app.module.ts b/src/app.module.ts index 2f66ac4..4626f2c 100644 --- a/src/app.module.ts +++ b/src/app.module.ts @@ -5,6 +5,7 @@ import { EventEmitterModule } from '@nestjs/event-emitter'; import { ThrottlerGuard, ThrottlerModule } from '@nestjs/throttler'; import { AppController } from './app.controller'; import { AppService } from './app.service'; +import { CacheModule, RedisThrottlerStorage } from './common/cache'; import { UsersModule } from './users/users.module'; import { AuthModule } from './auth/auth.module'; import { DatabaseModule } from './database/database.module'; @@ -12,7 +13,10 @@ import { RedisModule } from './database/redis.module'; import { WsModule } from './ws/ws.module'; import { FriendsModule } from './friends/friends.module'; import { DollsModule } from './dolls/dolls.module'; -import { parseRedisRequired } from './common/config/env.utils'; +import { + parsePositiveInteger, + parseRedisRequired, +} from './common/config/env.utils'; /** * Validates required environment variables. @@ -117,15 +121,33 @@ function validateOptionalProvider( envFilePath: '.env', validate: validateEnvironment, }), + CacheModule, ThrottlerModule.forRootAsync({ - imports: [ConfigModule], - inject: [ConfigService], - useFactory: (config: ConfigService) => [ - { - ttl: config.get('THROTTLE_TTL', 1000), - limit: config.get('THROTTLE_LIMIT', 5), - }, - ], + imports: [ConfigModule, CacheModule], + inject: [ConfigService, RedisThrottlerStorage], + useFactory: ( + config: ConfigService, + redisThrottlerStorage: RedisThrottlerStorage, + ) => { + const ttl = parsePositiveInteger( + config.get('THROTTLE_TTL'), + 1000, + ); + const limit = parsePositiveInteger( + config.get('THROTTLE_LIMIT'), + 5, + ); + + return { + storage: redisThrottlerStorage, + throttlers: [ + { + ttl, + limit, + }, + ], + }; + }, }), EventEmitterModule.forRoot(), DatabaseModule, diff --git a/src/common/cache/cache.module.ts b/src/common/cache/cache.module.ts new file mode 100644 index 0000000..53ed337 --- /dev/null +++ b/src/common/cache/cache.module.ts @@ -0,0 +1,12 @@ +import { Global, Module } from '@nestjs/common'; +import { RedisModule } from '../../database/redis.module'; +import { CacheService } from './cache.service'; +import { RedisThrottlerStorage } from './redis-throttler.storage'; + +@Global() +@Module({ + imports: [RedisModule], + providers: [CacheService, RedisThrottlerStorage], + exports: [CacheService, RedisThrottlerStorage], +}) +export class CacheModule {} diff --git a/src/common/cache/cache.service.ts b/src/common/cache/cache.service.ts new file mode 100644 index 0000000..d4de209 --- /dev/null +++ b/src/common/cache/cache.service.ts @@ -0,0 +1,185 @@ +import { Inject, Injectable, Logger, OnModuleDestroy } from '@nestjs/common'; +import { ConfigService } from '@nestjs/config'; +import Redis from 'ioredis'; +import { REDIS_CLIENT } from '../../database/redis.module'; +import { parsePositiveInteger } from '../config/env.utils'; + +const DEFAULT_CACHE_KEY_PREFIX = 'friendolls'; +const DEFAULT_CACHE_TTL_SECONDS = 60; +const DEFAULT_CACHE_MAX_TTL_SECONDS = 86_400; +const DEFAULT_CACHE_METRICS_LOG_INTERVAL_MS = 60_000; + +@Injectable() +export class CacheService implements OnModuleDestroy { + private readonly logger = new Logger(CacheService.name); + private readonly keyPrefix: string; + private readonly defaultTtlSeconds: number; + private readonly maxTtlSeconds: number; + private readonly metricsLogIntervalMs: number; + + private readonly metrics = { + getHits: 0, + getMisses: 0, + sets: 0, + deletes: 0, + errors: 0, + unavailable: 0, + }; + + private metricsTimer: NodeJS.Timeout | null = null; + + constructor( + private readonly configService: ConfigService, + @Inject(REDIS_CLIENT) private readonly redisClient: Redis | null, + ) { + this.keyPrefix = + this.configService.get('CACHE_KEY_PREFIX') || + DEFAULT_CACHE_KEY_PREFIX; + this.defaultTtlSeconds = parsePositiveInteger( + this.configService.get('CACHE_DEFAULT_TTL_SECONDS'), + DEFAULT_CACHE_TTL_SECONDS, + ); + this.maxTtlSeconds = parsePositiveInteger( + this.configService.get('CACHE_MAX_TTL_SECONDS'), + DEFAULT_CACHE_MAX_TTL_SECONDS, + ); + this.metricsLogIntervalMs = parsePositiveInteger( + this.configService.get('CACHE_METRICS_LOG_INTERVAL_MS'), + DEFAULT_CACHE_METRICS_LOG_INTERVAL_MS, + ); + + if (this.metricsLogIntervalMs > 0) { + this.metricsTimer = setInterval(() => { + this.flushMetrics(); + }, this.metricsLogIntervalMs); + this.metricsTimer.unref(); + } + } + + onModuleDestroy(): void { + if (this.metricsTimer) { + clearInterval(this.metricsTimer); + this.metricsTimer = null; + } + + this.flushMetrics(); + } + + getNamespacedKey(namespace: string, key: string): string { + return `${this.keyPrefix}:${namespace}:${key}`; + } + + resolveTtlSeconds(ttlSeconds?: number): number { + if (ttlSeconds === undefined) { + return this.defaultTtlSeconds; + } + + if (!Number.isFinite(ttlSeconds) || ttlSeconds <= 0) { + return this.defaultTtlSeconds; + } + + return Math.min(Math.floor(ttlSeconds), this.maxTtlSeconds); + } + + async get(key: string): Promise { + if (!this.redisClient) { + this.metrics.unavailable += 1; + return null; + } + + try { + const value = await this.redisClient.get(key); + if (value === null) { + this.metrics.getMisses += 1; + } else { + this.metrics.getHits += 1; + } + return value; + } catch (error) { + this.metrics.errors += 1; + this.logger.warn(`Cache get failed for key ${key}`, error as Error); + return null; + } + } + + async set(key: string, value: string, ttlSeconds?: number): Promise { + if (!this.redisClient) { + this.metrics.unavailable += 1; + return false; + } + + try { + const ttl = this.resolveTtlSeconds(ttlSeconds); + await this.redisClient.set(key, value, 'EX', ttl); + this.metrics.sets += 1; + return true; + } catch (error) { + this.metrics.errors += 1; + this.logger.warn(`Cache set failed for key ${key}`, error as Error); + return false; + } + } + + async del(key: string): Promise { + if (!this.redisClient) { + this.metrics.unavailable += 1; + return false; + } + + try { + await this.redisClient.del(key); + this.metrics.deletes += 1; + return true; + } catch (error) { + this.metrics.errors += 1; + this.logger.warn(`Cache delete failed for key ${key}`, error as Error); + return false; + } + } + + getRedisClient(): Redis | null { + return this.redisClient; + } + + recordError(operation: string, key: string, error: unknown): void { + this.metrics.errors += 1; + this.logger.warn( + `Cache ${operation} failed for key ${key}`, + error as Error, + ); + } + + recordUnavailable(): void { + this.metrics.unavailable += 1; + } + + private flushMetrics(): void { + const totalReads = this.metrics.getHits + this.metrics.getMisses; + + if ( + totalReads === 0 && + this.metrics.sets === 0 && + this.metrics.deletes === 0 && + this.metrics.errors === 0 && + this.metrics.unavailable === 0 + ) { + return; + } + + const hitRate = + totalReads === 0 + ? '0.00' + : ((this.metrics.getHits / totalReads) * 100).toFixed(2); + + this.logger.log( + `metrics reads=${totalReads} hits=${this.metrics.getHits} misses=${this.metrics.getMisses} hitRate=${hitRate}% sets=${this.metrics.sets} deletes=${this.metrics.deletes} errors=${this.metrics.errors} unavailable=${this.metrics.unavailable}`, + ); + + this.metrics.getHits = 0; + this.metrics.getMisses = 0; + this.metrics.sets = 0; + this.metrics.deletes = 0; + this.metrics.errors = 0; + this.metrics.unavailable = 0; + } +} diff --git a/src/common/cache/index.ts b/src/common/cache/index.ts new file mode 100644 index 0000000..eeb7604 --- /dev/null +++ b/src/common/cache/index.ts @@ -0,0 +1,3 @@ +export { CacheModule } from './cache.module'; +export { CacheService } from './cache.service'; +export { RedisThrottlerStorage } from './redis-throttler.storage'; diff --git a/src/common/cache/redis-throttler.storage.ts b/src/common/cache/redis-throttler.storage.ts new file mode 100644 index 0000000..cceb247 --- /dev/null +++ b/src/common/cache/redis-throttler.storage.ts @@ -0,0 +1,239 @@ +import { Injectable } from '@nestjs/common'; +import { ThrottlerStorage } from '@nestjs/throttler'; +import { CacheService } from './cache.service'; + +interface RedisThrottlerStorageRecord { + totalHits: number; + timeToExpire: number; + isBlocked: boolean; + timeToBlockExpire: number; +} + +@Injectable() +export class RedisThrottlerStorage implements ThrottlerStorage { + private static readonly IN_MEMORY_CLEANUP_INTERVAL = 500; + private readonly inMemoryStorage = new Map< + string, + { totalHits: number; expiresAt: number; blockExpiresAt: number } + >(); + private inMemoryOperationCount = 0; + + constructor(private readonly cacheService: CacheService) {} + + async increment( + key: string, + ttl: number, + limit: number, + blockDuration: number, + throttlerName: string, + ): Promise { + const safeLimit = Math.max(0, Math.floor(limit)); + const ttlMilliseconds = this.normalizeDurationMs(ttl); + const blockDurationMilliseconds = this.normalizeDurationMs(blockDuration); + const counterKey = this.cacheService.getNamespacedKey( + 'throttle:counter', + `${throttlerName}:${key}`, + ); + const blockKey = this.cacheService.getNamespacedKey( + 'throttle:block', + `${throttlerName}:${key}`, + ); + + const redisClient = this.cacheService.getRedisClient(); + if (!redisClient) { + this.cacheService.recordUnavailable(); + return this.incrementInMemory( + counterKey, + ttlMilliseconds, + safeLimit, + blockDurationMilliseconds, + ); + } + + try { + const initialized = await redisClient.set( + counterKey, + '1', + 'PX', + ttlMilliseconds, + 'NX', + ); + + if (initialized === 'OK') { + const existingBlockTtlRemainingMs = await redisClient.pttl(blockKey); + return { + totalHits: 1, + timeToExpire: Math.ceil(ttlMilliseconds / 1000), + isBlocked: existingBlockTtlRemainingMs > 0, + timeToBlockExpire: + existingBlockTtlRemainingMs > 0 + ? this.toSecondsFromPttl( + existingBlockTtlRemainingMs, + blockDurationMilliseconds, + ) + : 0, + }; + } + + const [ + existingBlockTtlRemainingMs, + ttlRemainingBeforeHitMs, + currentCount, + ] = await Promise.all([ + redisClient.pttl(blockKey), + redisClient.pttl(counterKey), + redisClient.get(counterKey), + ]); + + if (existingBlockTtlRemainingMs > 0) { + const totalHits = Number(currentCount ?? '0'); + return { + totalHits: Number.isFinite(totalHits) ? totalHits : 0, + timeToExpire: this.toSecondsFromPttl( + ttlRemainingBeforeHitMs, + ttlMilliseconds, + ), + isBlocked: true, + timeToBlockExpire: this.toSecondsFromPttl( + existingBlockTtlRemainingMs, + blockDurationMilliseconds, + ), + }; + } + + const count = await redisClient.incr(counterKey); + if (count === 1) { + await redisClient.pexpire(counterKey, ttlMilliseconds); + } + + const [ttlRemainingMs, blockTtlRemainingMs] = await Promise.all([ + redisClient.pttl(counterKey), + redisClient.pttl(blockKey), + ]); + + let isBlocked = blockTtlRemainingMs > 0; + + if (!isBlocked && safeLimit > 0 && count > safeLimit) { + await redisClient.set(blockKey, '1', 'PX', blockDurationMilliseconds); + isBlocked = true; + } + + const refreshedBlockTtlRemainingMs = isBlocked + ? await redisClient.pttl(blockKey) + : -1; + + return { + totalHits: count, + timeToExpire: this.toSecondsFromPttl(ttlRemainingMs, ttlMilliseconds), + isBlocked, + timeToBlockExpire: isBlocked + ? this.toSecondsFromPttl( + refreshedBlockTtlRemainingMs, + blockDurationMilliseconds, + ) + : 0, + }; + } catch (error) { + this.cacheService.recordError('throttler increment', counterKey, error); + + return this.incrementInMemory( + counterKey, + ttlMilliseconds, + safeLimit, + blockDurationMilliseconds, + ); + } + } + + private incrementInMemory( + key: string, + ttlMilliseconds: number, + limit: number, + blockDurationMilliseconds: number, + ): RedisThrottlerStorageRecord { + const now = Date.now(); + this.inMemoryOperationCount += 1; + if ( + this.inMemoryOperationCount % + RedisThrottlerStorage.IN_MEMORY_CLEANUP_INTERVAL === + 0 + ) { + this.cleanupExpiredInMemory(now); + } + + const existing = this.inMemoryStorage.get(key); + + if (existing && existing.blockExpiresAt > now) { + return { + totalHits: existing.totalHits, + timeToExpire: Math.max(1, Math.ceil((existing.expiresAt - now) / 1000)), + isBlocked: true, + timeToBlockExpire: Math.max( + 1, + Math.ceil((existing.blockExpiresAt - now) / 1000), + ), + }; + } + + let totalHits = 1; + let expiresAt = now + ttlMilliseconds; + let blockExpiresAt = 0; + + if (existing && existing.expiresAt > now) { + totalHits = existing.totalHits + 1; + expiresAt = existing.expiresAt; + } + + if (blockExpiresAt <= now) { + blockExpiresAt = 0; + } + + let isBlocked = blockExpiresAt > now; + if (!isBlocked && limit > 0 && totalHits > limit) { + blockExpiresAt = now + blockDurationMilliseconds; + isBlocked = true; + } + + this.inMemoryStorage.set(key, { + totalHits, + expiresAt, + blockExpiresAt, + }); + + return { + totalHits, + timeToExpire: Math.max(1, Math.ceil((expiresAt - now) / 1000)), + isBlocked, + timeToBlockExpire: isBlocked + ? Math.max(1, Math.ceil((blockExpiresAt - now) / 1000)) + : 0, + }; + } + + private normalizeDurationMs(value: number): number { + if (!Number.isFinite(value) || value <= 0) { + return 1000; + } + + return Math.max(1, Math.floor(value)); + } + + private toSecondsFromPttl(pttlMs: number, fallbackMs: number): number { + if (pttlMs > 0) { + return Math.max(1, Math.ceil(pttlMs / 1000)); + } + + return Math.max(1, Math.ceil(fallbackMs / 1000)); + } + + private cleanupExpiredInMemory(now: number): void { + for (const [mapKey, value] of this.inMemoryStorage) { + const counterExpired = value.expiresAt <= now; + const blockExpired = value.blockExpiresAt <= now; + + if (counterExpired && blockExpired) { + this.inMemoryStorage.delete(mapKey); + } + } + } +}