Compare commits

..

9 Commits

39 changed files with 1812 additions and 125 deletions

View File

@@ -12,6 +12,17 @@ REDIS_PORT=6379
REDIS_REQUIRED=false
REDIS_CONNECT_TIMEOUT_MS=5000
REDIS_STARTUP_RETRIES=10
# Stale presence cleanup threshold and interval
PRESENCE_STALE_AGE_MS=604800000
PRESENCE_CLEANUP_INTERVAL_MS=300000
# Cache
CACHE_KEY_PREFIX=friendolls
CACHE_DEFAULT_TTL_SECONDS=60
CACHE_MAX_TTL_SECONDS=86400
CACHE_METRICS_LOG_INTERVAL_MS=60000
# Max number of cache keys tracked per invalidation tag
CACHE_TAG_MAX_ENTRIES=5000
# JWT Configuration
JWT_SECRET=replace-with-strong-random-secret
@@ -24,6 +35,10 @@ AUTH_CLEANUP_ENABLED=true
AUTH_CLEANUP_INTERVAL_MS=900000
AUTH_SESSION_REVOKED_RETENTION_DAYS=7
# Rate limiting
THROTTLE_TTL=1000
THROTTLE_LIMIT=5
# Google OAuth
GOOGLE_CLIENT_ID="replace-with-google-client-id"
GOOGLE_CLIENT_SECRET="replace-with-google-client-secret"

View File

@@ -8,9 +8,9 @@ Backend server for Friendolls.
## Commands
- **Error Checks**: `pnpm check`
- **Format/Lint**: `pnpm format`, `pnpm lint`
- **Test**: `pnpm test` (Unit), `pnpm test:e2e` (E2E)
- **Lint/Check for errors**: `pnpm lint`
- **Format**: `pnpm format`
- **Test**: `pnpm test`
- **Single Test**: `pnpm test -- -t "test name"` or `pnpm test -- src/path/to/file.spec.ts`
- **Database**: `npx prisma generate`, `npx prisma migrate dev`
@@ -29,4 +29,4 @@ Backend server for Friendolls.
## Note
Do not run the project yourself. Run error checks and lints to detect issues.
Do not run the project yourself. Run lints and tests to detect issues after each final changes.

9
pnpm-lock.yaml generated
View File

@@ -62,6 +62,9 @@ importers:
dotenv:
specifier: ^17.2.3
version: 17.2.3
helmet:
specifier: ^8.1.0
version: 8.1.0
ioredis:
specifier: ^5.8.2
version: 5.8.2
@@ -2298,6 +2301,10 @@ packages:
resolution: {integrity: sha512-0hJU9SCPvmMzIBdZFqNPXWa6dqh7WdH0cII9y+CyS8rG3nL48Bclra9HmKhVVUHyPWNH5Y7xDwAB7bfgSjkUMQ==}
engines: {node: '>= 0.4'}
helmet@8.1.0:
resolution: {integrity: sha512-jOiHyAZsmnr8LqoPGmCjYAaiuWwjAPLgY8ZX2XrmHawt99/u1y6RgrZMTeoPfpUbV96HOalYgz1qzkRbw54Pmg==}
engines: {node: '>=18.0.0'}
hono@4.7.10:
resolution: {integrity: sha512-QkACju9MiN59CKSY5JsGZCYmPZkA6sIW6OFCUp7qDjZu6S6KHtJHhAc9Uy9mV9F8PJ1/HQ3ybZF2yjCa/73fvQ==}
engines: {node: '>=16.9.0'}
@@ -6227,6 +6234,8 @@ snapshots:
dependencies:
function-bind: 1.1.2
helmet@8.1.0: {}
hono@4.7.10: {}
html-escaper@2.0.2: {}

View File

@@ -3,7 +3,6 @@
generator client {
provider = "prisma-client-js"
output = "../node_modules/.prisma/client"
}
datasource db {

View File

@@ -5,6 +5,7 @@ import { EventEmitterModule } from '@nestjs/event-emitter';
import { ThrottlerGuard, ThrottlerModule } from '@nestjs/throttler';
import { AppController } from './app.controller';
import { AppService } from './app.service';
import { CacheModule, RedisThrottlerStorage } from './common/cache';
import { UsersModule } from './users/users.module';
import { AuthModule } from './auth/auth.module';
import { DatabaseModule } from './database/database.module';
@@ -12,7 +13,10 @@ import { RedisModule } from './database/redis.module';
import { WsModule } from './ws/ws.module';
import { FriendsModule } from './friends/friends.module';
import { DollsModule } from './dolls/dolls.module';
import { parseRedisRequired } from './common/config/env.utils';
import {
parsePositiveInteger,
parseRedisRequired,
} from './common/config/env.utils';
/**
* Validates required environment variables.
@@ -79,6 +83,15 @@ function validateEnvironment(
throw new Error('REDIS_CONNECT_TIMEOUT_MS must be a positive number');
}
validateOptionalPositiveNumber(config, 'THROTTLE_TTL');
validateOptionalPositiveNumber(config, 'THROTTLE_LIMIT');
validateOptionalPositiveNumber(config, 'CACHE_DEFAULT_TTL_SECONDS');
validateOptionalPositiveNumber(config, 'CACHE_MAX_TTL_SECONDS');
validateOptionalPositiveNumber(config, 'CACHE_METRICS_LOG_INTERVAL_MS');
validateOptionalPositiveNumber(config, 'CACHE_TAG_MAX_ENTRIES');
validateOptionalPositiveNumber(config, 'PRESENCE_STALE_AGE_MS');
validateOptionalPositiveNumber(config, 'PRESENCE_CLEANUP_INTERVAL_MS');
validateOptionalProvider(config, 'GOOGLE');
validateOptionalProvider(config, 'DISCORD');
@@ -105,6 +118,20 @@ function validateOptionalProvider(
}
}
function validateOptionalPositiveNumber(
config: Record<string, unknown>,
key: string,
): void {
const value = config[key];
if (value === undefined || value === null || value === '') {
return;
}
if (!Number.isFinite(Number(value)) || Number(value) <= 0) {
throw new Error(`${key} must be a positive number`);
}
}
/**
* Root Application Module
*
@@ -117,15 +144,33 @@ function validateOptionalProvider(
envFilePath: '.env',
validate: validateEnvironment,
}),
CacheModule,
ThrottlerModule.forRootAsync({
imports: [ConfigModule],
inject: [ConfigService],
useFactory: (config: ConfigService) => [
imports: [ConfigModule, CacheModule],
inject: [ConfigService, RedisThrottlerStorage],
useFactory: (
config: ConfigService,
redisThrottlerStorage: RedisThrottlerStorage,
) => {
const ttl = parsePositiveInteger(
config.get<string>('THROTTLE_TTL'),
1000,
);
const limit = parsePositiveInteger(
config.get<string>('THROTTLE_LIMIT'),
5,
);
return {
storage: redisThrottlerStorage,
throttlers: [
{
ttl: config.get('THROTTLE_TTL', 1000),
limit: config.get('THROTTLE_LIMIT', 5),
ttl,
limit,
},
],
};
},
}),
EventEmitterModule.forRoot(),
DatabaseModule,

View File

@@ -4,8 +4,11 @@ import {
UnauthorizedException,
} from '@nestjs/common';
import { ConfigService } from '@nestjs/config';
import { EventEmitter2 } from '@nestjs/event-emitter';
import { Test, TestingModule } from '@nestjs/testing';
import { decode, sign } from 'jsonwebtoken';
import { CacheService } from '../common/cache/cache.service';
import { CacheTagsService } from '../common/cache/cache-tags.service';
import { PrismaService } from '../database/prisma.service';
import { AuthService } from './auth.service';
import { sha256 } from './auth.utils';
@@ -58,6 +61,27 @@ describe('AuthService', () => {
$transaction: jest.fn(),
};
const mockEventEmitter = {
emit: jest.fn(),
};
const mockCacheService = {
get: jest.fn().mockResolvedValue(null),
set: jest.fn().mockResolvedValue(true),
del: jest.fn().mockResolvedValue(true),
getNamespacedKey: jest
.fn()
.mockImplementation(
(namespace: string, key: string) => `friendolls:${namespace}:${key}`,
),
recordError: jest.fn(),
};
const mockCacheTagsService = {
rememberKeyForTag: jest.fn().mockResolvedValue(undefined),
invalidateTag: jest.fn().mockResolvedValue(undefined),
};
const socialProfile: SocialAuthProfile = {
provider: 'google',
providerSubject: 'google-user-123',
@@ -94,6 +118,9 @@ describe('AuthService', () => {
AuthService,
{ provide: PrismaService, useValue: mockPrismaService },
{ provide: ConfigService, useValue: mockConfigService },
{ provide: EventEmitter2, useValue: mockEventEmitter },
{ provide: CacheService, useValue: mockCacheService },
{ provide: CacheTagsService, useValue: mockCacheTagsService },
],
}).compile();
@@ -135,6 +162,9 @@ describe('AuthService', () => {
const localService = new AuthService(
mockPrismaService as unknown as PrismaService,
mockConfigService as unknown as ConfigService,
mockEventEmitter as unknown as EventEmitter2,
mockCacheService as unknown as CacheService,
mockCacheTagsService as unknown as CacheTagsService,
);
expect(() =>

View File

@@ -13,6 +13,7 @@ import {
verify,
} from 'jsonwebtoken';
import { PrismaService } from '../database/prisma.service';
import { EventEmitter2 } from '@nestjs/event-emitter';
import type { SocialAuthProfile } from './types/social-auth-profile';
import type {
AuthTokens,
@@ -35,6 +36,15 @@ import {
usernameFromEmail,
} from './auth.utils';
import type { SsoProvider } from './dto/sso-provider';
import { UserEvents } from '../users/events/user.events';
import { CacheService } from '../common/cache/cache.service';
import {
authSessionUserTag,
authSessionCacheKey,
CACHE_NAMESPACE,
CACHE_TTL_SECONDS,
} from '../common/cache/cache-keys';
import { CacheTagsService } from '../common/cache/cache-tags.service';
interface SsoStateClaims {
provider: SsoProvider;
@@ -43,6 +53,28 @@ interface SsoStateClaims {
typ: 'sso_state';
}
interface AuthSessionWithUser {
id: string;
refresh_token_hash: string;
expires_at: Date;
revoked_at: Date | null;
provider: 'GOOGLE' | 'DISCORD' | null;
user_id: string;
email: string;
roles: string[];
}
interface CachedAuthSessionWithUser {
id: string;
refresh_token_hash: string;
expires_at: string;
revoked_at: string | null;
provider: 'GOOGLE' | 'DISCORD' | null;
user_id: string;
email: string;
roles: string[];
}
@Injectable()
export class AuthService {
private readonly logger = new Logger(AuthService.name);
@@ -56,6 +88,9 @@ export class AuthService {
constructor(
private readonly prisma: PrismaService,
private readonly configService: ConfigService,
private readonly eventEmitter: EventEmitter2,
private readonly cacheService: CacheService,
private readonly cacheTagsService: CacheTagsService,
) {
this.jwtSecret = this.configService.get<string>('JWT_SECRET') || '';
this.jwtIssuer =
@@ -159,7 +194,7 @@ export class AuthService {
}
if (session.refresh_token_hash !== refreshTokenHash) {
await this.revokeSessionOnReplay(session.id);
await this.revokeSessionOnReplay(session.id, session.user_id);
throw new UnauthorizedException('Invalid refresh token');
}
@@ -171,7 +206,7 @@ export class AuthService {
);
if (!updated) {
await this.revokeSessionOnReplay(session.id);
await this.revokeSessionOnReplay(session.id, session.user_id);
throw new UnauthorizedException('Invalid refresh token');
}
@@ -254,6 +289,11 @@ export class AuthService {
},
});
this.eventEmitter.emit(UserEvents.SEARCH_INDEX_INVALIDATED, {
userId: user.id,
});
this.eventEmitter.emit(UserEvents.PROFILE_UPDATED, { userId: user.id });
return user;
}
@@ -273,7 +313,7 @@ export class AuthService {
);
}
return this.prisma.$transaction(async (tx) => {
const user = await this.prisma.$transaction(async (tx) => {
let user = await tx.user.findUnique({
where: { email },
});
@@ -311,6 +351,13 @@ export class AuthService {
return user;
});
this.eventEmitter.emit(UserEvents.SEARCH_INDEX_INVALIDATED, {
userId: user.id,
});
this.eventEmitter.emit(UserEvents.PROFILE_UPDATED, { userId: user.id });
return user;
}
private async resolveUsername(
@@ -547,28 +594,34 @@ export class AuthService {
return rows[0] ?? null;
}
private async getSessionWithUser(sessionId: string): Promise<{
id: string;
refresh_token_hash: string;
expires_at: Date;
revoked_at: Date | null;
provider: 'GOOGLE' | 'DISCORD' | null;
user_id: string;
email: string;
roles: string[];
} | null> {
const rows = await this.prisma.$queryRaw<
Array<{
id: string;
refresh_token_hash: string;
expires_at: Date;
revoked_at: Date | null;
provider: 'GOOGLE' | 'DISCORD' | null;
user_id: string;
email: string;
roles: string[];
}>
>`
private async getSessionWithUser(
sessionId: string,
): Promise<AuthSessionWithUser | null> {
const sessionCacheKey = this.getAuthSessionCacheKey(sessionId);
const cachedSessionRaw = await this.cacheService.get(sessionCacheKey);
if (cachedSessionRaw) {
try {
const cachedSession = JSON.parse(
cachedSessionRaw,
) as CachedAuthSessionWithUser;
return {
...cachedSession,
expires_at: new Date(cachedSession.expires_at),
revoked_at: cachedSession.revoked_at
? new Date(cachedSession.revoked_at)
: null,
};
} catch (error) {
this.cacheService.recordError(
'auth session parse',
sessionCacheKey,
error,
);
}
}
const rows = await this.prisma.$queryRaw<Array<AuthSessionWithUser>>`
SELECT s.id, s.refresh_token_hash, s.expires_at, s.revoked_at, s.provider, s.user_id, u.email, u.roles
FROM auth_sessions AS s
INNER JOIN users AS u ON u.id = s.user_id
@@ -576,7 +629,29 @@ export class AuthService {
LIMIT 1
`;
return rows[0] ?? null;
const session = rows[0] ?? null;
if (!session) {
return null;
}
const cachePayload: CachedAuthSessionWithUser = {
...session,
expires_at: session.expires_at.toISOString(),
revoked_at: session.revoked_at ? session.revoked_at.toISOString() : null,
};
await this.cacheService.set(
sessionCacheKey,
JSON.stringify(cachePayload),
CACHE_TTL_SECONDS.AUTH_SESSION,
);
await this.cacheTagsService.rememberKeyForTag(
CACHE_NAMESPACE.AUTH_SESSION,
authSessionUserTag(session.user_id),
authSessionCacheKey(session.id),
);
return session;
}
private async rotateRefreshSession(
@@ -584,6 +659,8 @@ export class AuthService {
refreshTokenHash: string,
nextRefreshToken: string,
): Promise<boolean> {
await this.cacheService.del(this.getAuthSessionCacheKey(sessionId));
const rows = await this.prisma.$queryRaw<Array<{ id: string }>>`
UPDATE auth_sessions
SET refresh_token_hash = ${sha256(nextRefreshToken)},
@@ -597,6 +674,10 @@ export class AuthService {
RETURNING id
`;
if (rows.length === 1) {
await this.cacheService.del(this.getAuthSessionCacheKey(sessionId));
}
return rows.length === 1;
}
@@ -604,6 +685,8 @@ export class AuthService {
sessionId: string,
refreshTokenHash: string,
): Promise<boolean> {
await this.cacheService.del(this.getAuthSessionCacheKey(sessionId));
const rows = await this.prisma.$queryRaw<Array<{ id: string }>>`
UPDATE auth_sessions
SET revoked_at = NOW(),
@@ -615,17 +698,41 @@ export class AuthService {
RETURNING id
`;
if (rows.length === 1) {
await this.cacheService.del(this.getAuthSessionCacheKey(sessionId));
}
return rows.length === 1;
}
private async revokeSessionOnReplay(sessionId: string): Promise<void> {
private async revokeSessionOnReplay(
sessionId: string,
userId: string,
): Promise<void> {
await this.cacheService.del(this.getAuthSessionCacheKey(sessionId));
await this.revokeAllUserSessions(userId);
}
private async revokeAllUserSessions(userId: string): Promise<void> {
await this.prisma.$queryRaw<Array<{ id: string }>>`
UPDATE auth_sessions
SET revoked_at = NOW(),
updated_at = NOW()
WHERE id = ${sessionId}
WHERE user_id = ${userId}
AND revoked_at IS NULL
RETURNING id
`;
await this.cacheTagsService.invalidateTag(
CACHE_NAMESPACE.AUTH_SESSION,
authSessionUserTag(userId),
);
}
private getAuthSessionCacheKey(sessionId: string): string {
return this.cacheService.getNamespacedKey(
CACHE_NAMESPACE.AUTH_SESSION,
authSessionCacheKey(sessionId),
);
}
}

View File

@@ -20,6 +20,14 @@ const DEFAULT_REVOKED_RETENTION_DAYS = 7;
const CLEANUP_LOCK_KEY = 'lock:auth:cleanup';
const CLEANUP_LOCK_TTL_MS = 55_000;
const RELEASE_LOCK_SCRIPT = `
if redis.call("get", KEYS[1]) == ARGV[1] then
return redis.call("del", KEYS[1])
else
return 0
end
`;
@Injectable()
export class AuthCleanupService implements OnModuleInit, OnModuleDestroy {
private readonly logger = new Logger(AuthCleanupService.name);
@@ -141,10 +149,12 @@ export class AuthCleanupService implements OnModuleInit, OnModuleDestroy {
} finally {
if (lockAcquired && this.redisClient) {
try {
const currentLockValue = await this.redisClient.get(CLEANUP_LOCK_KEY);
if (currentLockValue === lockToken) {
await this.redisClient.del(CLEANUP_LOCK_KEY);
}
await this.redisClient.eval(
RELEASE_LOCK_SCRIPT,
1,
CLEANUP_LOCK_KEY,
lockToken,
);
} catch (error) {
this.logger.warn(
'Failed to release auth cleanup lock',

84
src/common/cache/cache-keys.ts vendored Normal file
View File

@@ -0,0 +1,84 @@
const EMPTY_VALUE_TOKEN = '_';
export const CACHE_NAMESPACE = {
FRIENDS_LIST: 'friends-list',
DOLLS_LIST: 'dolls-list',
USERS_SEARCH: 'users-search',
FRIENDSHIP_CHECK: 'friendship-check',
AUTH_SESSION: 'auth-session',
} as const;
function normalizeKeyPart(value: string | undefined): string {
if (!value) {
return EMPTY_VALUE_TOKEN;
}
return encodeURIComponent(value);
}
export const CACHE_TTL_SECONDS = {
FRIENDS_LIST: 30,
DOLLS_LIST: 30,
USERS_SEARCH: 20,
FRIENDSHIP_CHECK: 120,
AUTH_SESSION: 30,
} as const;
export function friendsListCacheKey(userId: string): string {
return normalizeKeyPart(userId);
}
export function friendsListOwnerTag(userId: string): string {
return `owner:${normalizeKeyPart(userId)}`;
}
export function friendsListDependsOnUserTag(userId: string): string {
return `depends-on:${normalizeKeyPart(userId)}`;
}
export function dollsListCacheKey(
ownerId: string,
requesterId: string,
): string {
return `${normalizeKeyPart(ownerId)}:${normalizeKeyPart(requesterId)}`;
}
export function dollsListOwnerTag(ownerId: string): string {
return `owner:${normalizeKeyPart(ownerId)}`;
}
export function dollsListViewerTag(viewerId: string): string {
return `viewer:${normalizeKeyPart(viewerId)}`;
}
export function usersSearchCacheKey(
username: string | undefined,
excludeUserId: string | undefined,
): string {
return `${normalizeKeyPart(username?.trim().toLowerCase())}:${normalizeKeyPart(excludeUserId)}`;
}
export const USERS_SEARCH_GLOBAL_TAG = 'global';
export function friendshipCheckCacheKey(
userId: string,
friendId: string,
): string {
return `${normalizeKeyPart(userId)}:${normalizeKeyPart(friendId)}`;
}
export function friendshipCheckUserTag(userId: string): string {
return `user:${normalizeKeyPart(userId)}`;
}
export function authSessionCacheKey(sessionId: string): string {
return normalizeKeyPart(sessionId);
}
export function authSessionUserTag(userId: string): string {
return `user:${normalizeKeyPart(userId)}`;
}
export function usersSearchUserTag(userId: string): string {
return `user:${normalizeKeyPart(userId)}`;
}

106
src/common/cache/cache-tags.service.ts vendored Normal file
View File

@@ -0,0 +1,106 @@
import { Injectable } from '@nestjs/common';
import { ConfigService } from '@nestjs/config';
import { CacheService } from './cache.service';
import { parsePositiveInteger } from '../config/env.utils';
const CACHE_TAG_SET_TTL_SECONDS = 86_400;
const DEFAULT_CACHE_TAG_MAX_ENTRIES = 5_000;
@Injectable()
export class CacheTagsService {
private readonly cacheTagMaxEntries: number;
constructor(
private readonly cacheService: CacheService,
private readonly configService: ConfigService,
) {
this.cacheTagMaxEntries = parsePositiveInteger(
this.configService.get<string>('CACHE_TAG_MAX_ENTRIES'),
DEFAULT_CACHE_TAG_MAX_ENTRIES,
);
}
async rememberKeyForTag(
namespace: string,
tag: string,
cacheKey: string,
): Promise<void> {
const redisClient = this.cacheService.getRedisClient();
if (!redisClient) {
return;
}
const tagSetKey = this.getTagSetKey(namespace, tag);
const keyWithNamespace = this.cacheService.getNamespacedKey(
namespace,
cacheKey,
);
try {
await Promise.all([
redisClient.sadd(tagSetKey, keyWithNamespace),
redisClient.expire(tagSetKey, CACHE_TAG_SET_TTL_SECONDS),
]);
const size = await redisClient.scard(tagSetKey);
if (size > this.cacheTagMaxEntries) {
await this.trimTagSet(tagSetKey, size - this.cacheTagMaxEntries);
}
} catch (error) {
this.cacheService.recordError('tag remember', tagSetKey, error);
}
}
async invalidateTag(namespace: string, tag: string): Promise<void> {
const redisClient = this.cacheService.getRedisClient();
if (!redisClient) {
return;
}
const tagSetKey = this.getTagSetKey(namespace, tag);
try {
const keys = await redisClient.smembers(tagSetKey);
if (keys.length === 0) {
await redisClient.del(tagSetKey);
return;
}
const pipeline = redisClient.pipeline();
keys.forEach((key) => pipeline.del(key));
pipeline.del(tagSetKey);
await pipeline.exec();
} catch (error) {
this.cacheService.recordError('tag invalidate', tagSetKey, error);
}
}
private getTagSetKey(namespace: string, tag: string): string {
return this.cacheService.getNamespacedKey(
'cache-tag',
`${namespace}:${tag}`,
);
}
private async trimTagSet(
tagSetKey: string,
countToDrop: number,
): Promise<void> {
const redisClient = this.cacheService.getRedisClient();
if (!redisClient || countToDrop <= 0) {
return;
}
try {
const sample = await redisClient.srandmember(tagSetKey, countToDrop);
const members = Array.isArray(sample) ? sample : [sample].filter(Boolean);
if (members.length === 0) {
return;
}
await redisClient.srem(tagSetKey, ...members);
} catch (error) {
this.cacheService.recordError('tag trim', tagSetKey, error);
}
}
}

13
src/common/cache/cache.module.ts vendored Normal file
View File

@@ -0,0 +1,13 @@
import { Global, Module } from '@nestjs/common';
import { RedisModule } from '../../database/redis.module';
import { CacheTagsService } from './cache-tags.service';
import { CacheService } from './cache.service';
import { RedisThrottlerStorage } from './redis-throttler.storage';
@Global()
@Module({
imports: [RedisModule],
providers: [CacheService, CacheTagsService, RedisThrottlerStorage],
exports: [CacheService, CacheTagsService, RedisThrottlerStorage],
})
export class CacheModule {}

185
src/common/cache/cache.service.ts vendored Normal file
View File

@@ -0,0 +1,185 @@
import { Inject, Injectable, Logger, OnModuleDestroy } from '@nestjs/common';
import { ConfigService } from '@nestjs/config';
import Redis from 'ioredis';
import { REDIS_CLIENT } from '../../database/redis.module';
import { parsePositiveInteger } from '../config/env.utils';
const DEFAULT_CACHE_KEY_PREFIX = 'friendolls';
const DEFAULT_CACHE_TTL_SECONDS = 60;
const DEFAULT_CACHE_MAX_TTL_SECONDS = 86_400;
const DEFAULT_CACHE_METRICS_LOG_INTERVAL_MS = 60_000;
@Injectable()
export class CacheService implements OnModuleDestroy {
private readonly logger = new Logger(CacheService.name);
private readonly keyPrefix: string;
private readonly defaultTtlSeconds: number;
private readonly maxTtlSeconds: number;
private readonly metricsLogIntervalMs: number;
private readonly metrics = {
getHits: 0,
getMisses: 0,
sets: 0,
deletes: 0,
errors: 0,
unavailable: 0,
};
private metricsTimer: NodeJS.Timeout | null = null;
constructor(
private readonly configService: ConfigService,
@Inject(REDIS_CLIENT) private readonly redisClient: Redis | null,
) {
this.keyPrefix =
this.configService.get<string>('CACHE_KEY_PREFIX') ||
DEFAULT_CACHE_KEY_PREFIX;
this.defaultTtlSeconds = parsePositiveInteger(
this.configService.get<string>('CACHE_DEFAULT_TTL_SECONDS'),
DEFAULT_CACHE_TTL_SECONDS,
);
this.maxTtlSeconds = parsePositiveInteger(
this.configService.get<string>('CACHE_MAX_TTL_SECONDS'),
DEFAULT_CACHE_MAX_TTL_SECONDS,
);
this.metricsLogIntervalMs = parsePositiveInteger(
this.configService.get<string>('CACHE_METRICS_LOG_INTERVAL_MS'),
DEFAULT_CACHE_METRICS_LOG_INTERVAL_MS,
);
if (this.metricsLogIntervalMs > 0) {
this.metricsTimer = setInterval(() => {
this.flushMetrics();
}, this.metricsLogIntervalMs);
this.metricsTimer.unref();
}
}
onModuleDestroy(): void {
if (this.metricsTimer) {
clearInterval(this.metricsTimer);
this.metricsTimer = null;
}
this.flushMetrics();
}
getNamespacedKey(namespace: string, key: string): string {
return `${this.keyPrefix}:${namespace}:${key}`;
}
resolveTtlSeconds(ttlSeconds?: number): number {
if (ttlSeconds === undefined) {
return this.defaultTtlSeconds;
}
if (!Number.isFinite(ttlSeconds) || ttlSeconds <= 0) {
return this.defaultTtlSeconds;
}
return Math.min(Math.floor(ttlSeconds), this.maxTtlSeconds);
}
async get(key: string): Promise<string | null> {
if (!this.redisClient) {
this.metrics.unavailable += 1;
return null;
}
try {
const value = await this.redisClient.get(key);
if (value === null) {
this.metrics.getMisses += 1;
} else {
this.metrics.getHits += 1;
}
return value;
} catch (error) {
this.metrics.errors += 1;
this.logger.warn(`Cache get failed for key ${key}`, error as Error);
return null;
}
}
async set(key: string, value: string, ttlSeconds?: number): Promise<boolean> {
if (!this.redisClient) {
this.metrics.unavailable += 1;
return false;
}
try {
const ttl = this.resolveTtlSeconds(ttlSeconds);
await this.redisClient.set(key, value, 'EX', ttl);
this.metrics.sets += 1;
return true;
} catch (error) {
this.metrics.errors += 1;
this.logger.warn(`Cache set failed for key ${key}`, error as Error);
return false;
}
}
async del(key: string): Promise<boolean> {
if (!this.redisClient) {
this.metrics.unavailable += 1;
return false;
}
try {
await this.redisClient.del(key);
this.metrics.deletes += 1;
return true;
} catch (error) {
this.metrics.errors += 1;
this.logger.warn(`Cache delete failed for key ${key}`, error as Error);
return false;
}
}
getRedisClient(): Redis | null {
return this.redisClient;
}
recordError(operation: string, key: string, error: unknown): void {
this.metrics.errors += 1;
this.logger.warn(
`Cache ${operation} failed for key ${key}`,
error as Error,
);
}
recordUnavailable(): void {
this.metrics.unavailable += 1;
}
private flushMetrics(): void {
const totalReads = this.metrics.getHits + this.metrics.getMisses;
if (
totalReads === 0 &&
this.metrics.sets === 0 &&
this.metrics.deletes === 0 &&
this.metrics.errors === 0 &&
this.metrics.unavailable === 0
) {
return;
}
const hitRate =
totalReads === 0
? '0.00'
: ((this.metrics.getHits / totalReads) * 100).toFixed(2);
this.logger.log(
`metrics reads=${totalReads} hits=${this.metrics.getHits} misses=${this.metrics.getMisses} hitRate=${hitRate}% sets=${this.metrics.sets} deletes=${this.metrics.deletes} errors=${this.metrics.errors} unavailable=${this.metrics.unavailable}`,
);
this.metrics.getHits = 0;
this.metrics.getMisses = 0;
this.metrics.sets = 0;
this.metrics.deletes = 0;
this.metrics.errors = 0;
this.metrics.unavailable = 0;
}
}

4
src/common/cache/index.ts vendored Normal file
View File

@@ -0,0 +1,4 @@
export { CacheModule } from './cache.module';
export { CacheService } from './cache.service';
export { CacheTagsService } from './cache-tags.service';
export { RedisThrottlerStorage } from './redis-throttler.storage';

View File

@@ -0,0 +1,239 @@
import { Injectable } from '@nestjs/common';
import { ThrottlerStorage } from '@nestjs/throttler';
import { CacheService } from './cache.service';
interface RedisThrottlerStorageRecord {
totalHits: number;
timeToExpire: number;
isBlocked: boolean;
timeToBlockExpire: number;
}
@Injectable()
export class RedisThrottlerStorage implements ThrottlerStorage {
private static readonly IN_MEMORY_CLEANUP_INTERVAL = 500;
private readonly inMemoryStorage = new Map<
string,
{ totalHits: number; expiresAt: number; blockExpiresAt: number }
>();
private inMemoryOperationCount = 0;
constructor(private readonly cacheService: CacheService) {}
async increment(
key: string,
ttl: number,
limit: number,
blockDuration: number,
throttlerName: string,
): Promise<RedisThrottlerStorageRecord> {
const safeLimit = Math.max(0, Math.floor(limit));
const ttlMilliseconds = this.normalizeDurationMs(ttl);
const blockDurationMilliseconds = this.normalizeDurationMs(blockDuration);
const counterKey = this.cacheService.getNamespacedKey(
'throttle:counter',
`${throttlerName}:${key}`,
);
const blockKey = this.cacheService.getNamespacedKey(
'throttle:block',
`${throttlerName}:${key}`,
);
const redisClient = this.cacheService.getRedisClient();
if (!redisClient) {
this.cacheService.recordUnavailable();
return this.incrementInMemory(
counterKey,
ttlMilliseconds,
safeLimit,
blockDurationMilliseconds,
);
}
try {
const initialized = await redisClient.set(
counterKey,
'1',
'PX',
ttlMilliseconds,
'NX',
);
if (initialized === 'OK') {
const existingBlockTtlRemainingMs = await redisClient.pttl(blockKey);
return {
totalHits: 1,
timeToExpire: Math.ceil(ttlMilliseconds / 1000),
isBlocked: existingBlockTtlRemainingMs > 0,
timeToBlockExpire:
existingBlockTtlRemainingMs > 0
? this.toSecondsFromPttl(
existingBlockTtlRemainingMs,
blockDurationMilliseconds,
)
: 0,
};
}
const [
existingBlockTtlRemainingMs,
ttlRemainingBeforeHitMs,
currentCount,
] = await Promise.all([
redisClient.pttl(blockKey),
redisClient.pttl(counterKey),
redisClient.get(counterKey),
]);
if (existingBlockTtlRemainingMs > 0) {
const totalHits = Number(currentCount ?? '0');
return {
totalHits: Number.isFinite(totalHits) ? totalHits : 0,
timeToExpire: this.toSecondsFromPttl(
ttlRemainingBeforeHitMs,
ttlMilliseconds,
),
isBlocked: true,
timeToBlockExpire: this.toSecondsFromPttl(
existingBlockTtlRemainingMs,
blockDurationMilliseconds,
),
};
}
const count = await redisClient.incr(counterKey);
if (count === 1) {
await redisClient.pexpire(counterKey, ttlMilliseconds);
}
const [ttlRemainingMs, blockTtlRemainingMs] = await Promise.all([
redisClient.pttl(counterKey),
redisClient.pttl(blockKey),
]);
let isBlocked = blockTtlRemainingMs > 0;
if (!isBlocked && safeLimit > 0 && count > safeLimit) {
await redisClient.set(blockKey, '1', 'PX', blockDurationMilliseconds);
isBlocked = true;
}
const refreshedBlockTtlRemainingMs = isBlocked
? await redisClient.pttl(blockKey)
: -1;
return {
totalHits: count,
timeToExpire: this.toSecondsFromPttl(ttlRemainingMs, ttlMilliseconds),
isBlocked,
timeToBlockExpire: isBlocked
? this.toSecondsFromPttl(
refreshedBlockTtlRemainingMs,
blockDurationMilliseconds,
)
: 0,
};
} catch (error) {
this.cacheService.recordError('throttler increment', counterKey, error);
return this.incrementInMemory(
counterKey,
ttlMilliseconds,
safeLimit,
blockDurationMilliseconds,
);
}
}
private incrementInMemory(
key: string,
ttlMilliseconds: number,
limit: number,
blockDurationMilliseconds: number,
): RedisThrottlerStorageRecord {
const now = Date.now();
this.inMemoryOperationCount += 1;
if (
this.inMemoryOperationCount %
RedisThrottlerStorage.IN_MEMORY_CLEANUP_INTERVAL ===
0
) {
this.cleanupExpiredInMemory(now);
}
const existing = this.inMemoryStorage.get(key);
if (existing && existing.blockExpiresAt > now) {
return {
totalHits: existing.totalHits,
timeToExpire: Math.max(1, Math.ceil((existing.expiresAt - now) / 1000)),
isBlocked: true,
timeToBlockExpire: Math.max(
1,
Math.ceil((existing.blockExpiresAt - now) / 1000),
),
};
}
let totalHits = 1;
let expiresAt = now + ttlMilliseconds;
let blockExpiresAt = 0;
if (existing && existing.expiresAt > now) {
totalHits = existing.totalHits + 1;
expiresAt = existing.expiresAt;
}
if (blockExpiresAt <= now) {
blockExpiresAt = 0;
}
let isBlocked = blockExpiresAt > now;
if (!isBlocked && limit > 0 && totalHits > limit) {
blockExpiresAt = now + blockDurationMilliseconds;
isBlocked = true;
}
this.inMemoryStorage.set(key, {
totalHits,
expiresAt,
blockExpiresAt,
});
return {
totalHits,
timeToExpire: Math.max(1, Math.ceil((expiresAt - now) / 1000)),
isBlocked,
timeToBlockExpire: isBlocked
? Math.max(1, Math.ceil((blockExpiresAt - now) / 1000))
: 0,
};
}
private normalizeDurationMs(value: number): number {
if (!Number.isFinite(value) || value <= 0) {
return 1000;
}
return Math.max(1, Math.floor(value));
}
private toSecondsFromPttl(pttlMs: number, fallbackMs: number): number {
if (pttlMs > 0) {
return Math.max(1, Math.ceil(pttlMs / 1000));
}
return Math.max(1, Math.ceil(fallbackMs / 1000));
}
private cleanupExpiredInMemory(now: number): void {
for (const [mapKey, value] of this.inMemoryStorage) {
const counterExpired = value.expiresAt <= now;
const blockExpired = value.blockExpiresAt <= now;
if (counterExpired && blockExpired) {
this.inMemoryStorage.delete(mapKey);
}
}
}
}

View File

@@ -0,0 +1,37 @@
import { Injectable } from '@nestjs/common';
import { OnEvent } from '@nestjs/event-emitter';
import { CacheTagsService } from '../common/cache/cache-tags.service';
import { CACHE_NAMESPACE, dollsListOwnerTag } from '../common/cache/cache-keys';
import { DollEvents } from './events/doll.events';
import {
type DollCreatedEvent,
type DollDeletedEvent,
type DollUpdatedEvent,
} from './events/doll.events';
@Injectable()
export class DollsCacheInvalidationService {
constructor(private readonly cacheTagsService: CacheTagsService) {}
@OnEvent(DollEvents.DOLL_CREATED)
async handleDollCreated(payload: DollCreatedEvent): Promise<void> {
await this.invalidateOwnerLists(payload.userId);
}
@OnEvent(DollEvents.DOLL_UPDATED)
async handleDollUpdated(payload: DollUpdatedEvent): Promise<void> {
await this.invalidateOwnerLists(payload.userId);
}
@OnEvent(DollEvents.DOLL_DELETED)
async handleDollDeleted(payload: DollDeletedEvent): Promise<void> {
await this.invalidateOwnerLists(payload.userId);
}
private async invalidateOwnerLists(ownerId: string): Promise<void> {
await this.cacheTagsService.invalidateTag(
CACHE_NAMESPACE.DOLLS_LIST,
dollsListOwnerTag(ownerId),
);
}
}

View File

@@ -1,15 +1,26 @@
import { Module, forwardRef } from '@nestjs/common';
import { DollsService } from './dolls.service';
import { DollsController } from './dolls.controller';
import { DollsCacheInvalidationService } from './dolls-cache-invalidation.service';
import { DollsNotificationService } from './dolls-notification.service';
import { DatabaseModule } from '../database/database.module';
import { AuthModule } from '../auth/auth.module';
import { WsModule } from '../ws/ws.module';
import { FriendsModule } from '../friends/friends.module';
@Module({
imports: [DatabaseModule, AuthModule, forwardRef(() => WsModule)],
imports: [
DatabaseModule,
AuthModule,
FriendsModule,
forwardRef(() => WsModule),
],
controllers: [DollsController],
providers: [DollsService, DollsNotificationService],
providers: [
DollsService,
DollsNotificationService,
DollsCacheInvalidationService,
],
exports: [DollsService],
})
export class DollsModule {}

View File

@@ -4,6 +4,9 @@ import { DollsService } from './dolls.service';
import { PrismaService } from '../database/prisma.service';
import { NotFoundException, ForbiddenException } from '@nestjs/common';
import { Doll } from '@prisma/client';
import { CacheService } from '../common/cache/cache.service';
import { CacheTagsService } from '../common/cache/cache-tags.service';
import { FriendsService } from '../friends/friends.service';
describe('DollsService', () => {
let service: DollsService;
@@ -31,9 +34,6 @@ describe('DollsService', () => {
findFirst: jest.fn().mockResolvedValue(mockDoll),
update: jest.fn().mockResolvedValue(mockDoll),
},
friendship: {
findMany: jest.fn().mockResolvedValue([]),
},
$transaction: jest.fn((callback) => {
// eslint-disable-next-line @typescript-eslint/no-unsafe-return
return callback(mockPrismaService);
@@ -47,6 +47,25 @@ describe('DollsService', () => {
emit: jest.fn(),
};
const mockCacheService = {
get: jest.fn().mockResolvedValue(null),
set: jest.fn().mockResolvedValue(true),
getNamespacedKey: jest
.fn()
.mockImplementation(
(namespace: string, key: string) => `friendolls:${namespace}:${key}`,
),
recordError: jest.fn(),
};
const mockCacheTagsService = {
rememberKeyForTag: jest.fn().mockResolvedValue(undefined),
};
const mockFriendsService = {
areFriends: jest.fn().mockResolvedValue(false),
};
beforeEach(async () => {
const module: TestingModule = await Test.createTestingModule({
providers: [
@@ -59,6 +78,18 @@ describe('DollsService', () => {
provide: EventEmitter2,
useValue: mockEventEmitter,
},
{
provide: CacheService,
useValue: mockCacheService,
},
{
provide: CacheTagsService,
useValue: mockCacheTagsService,
},
{
provide: FriendsService,
useValue: mockFriendsService,
},
],
}).compile();
@@ -112,10 +143,7 @@ describe('DollsService', () => {
const ownerId = 'friend-1';
const requestingUserId = 'user-1';
// Mock friendship
jest
.spyOn(prismaService.friendship, 'findMany')
.mockResolvedValueOnce([{ friendId: ownerId } as any]);
(mockFriendsService.areFriends as jest.Mock).mockResolvedValueOnce(true);
await service.listByOwner(ownerId, requestingUserId);
@@ -134,10 +162,7 @@ describe('DollsService', () => {
const ownerId = 'stranger-1';
const requestingUserId = 'user-1';
// Mock empty friendship (default)
jest
.spyOn(prismaService.friendship, 'findMany')
.mockResolvedValueOnce([]);
(mockFriendsService.areFriends as jest.Mock).mockResolvedValueOnce(false);
await expect(
service.listByOwner(ownerId, requestingUserId),
@@ -163,7 +188,10 @@ describe('DollsService', () => {
});
it('should throw NotFoundException if doll not accessible', async () => {
jest.spyOn(prismaService.doll, 'findFirst').mockResolvedValueOnce(null);
jest
.spyOn(prismaService.doll, 'findFirst')
.mockResolvedValueOnce({ ...mockDoll, userId: 'user-2' });
(mockFriendsService.areFriends as jest.Mock).mockResolvedValueOnce(false);
await expect(service.findOne('doll-1', 'user-1')).rejects.toThrow(
NotFoundException,
@@ -179,7 +207,7 @@ describe('DollsService', () => {
expect(prismaService.doll.update).toHaveBeenCalled();
});
it('should throw ForbiddenException if not owner', async () => {
it('should throw NotFoundException if not owner and not a friend', async () => {
jest
.spyOn(prismaService.doll, 'findFirst')
.mockResolvedValueOnce({ ...mockDoll, userId: 'user-2' });
@@ -187,7 +215,7 @@ describe('DollsService', () => {
const updateDto = { name: 'Updated Doll' };
await expect(
service.update('doll-1', 'user-1', updateDto),
).rejects.toThrow(ForbiddenException);
).rejects.toThrow(NotFoundException);
});
});
@@ -203,13 +231,13 @@ describe('DollsService', () => {
});
});
it('should throw ForbiddenException if not owner', async () => {
it('should throw NotFoundException if not owner and not a friend', async () => {
jest
.spyOn(prismaService.doll, 'findFirst')
.mockResolvedValueOnce({ ...mockDoll, userId: 'user-2' });
await expect(service.remove('doll-1', 'user-1')).rejects.toThrow(
ForbiddenException,
NotFoundException,
);
});
});

View File

@@ -15,6 +15,16 @@ import {
DollUpdatedEvent,
DollDeletedEvent,
} from './events/doll.events';
import { CacheService } from '../common/cache/cache.service';
import { CacheTagsService } from '../common/cache/cache-tags.service';
import {
CACHE_NAMESPACE,
CACHE_TTL_SECONDS,
dollsListCacheKey,
dollsListOwnerTag,
dollsListViewerTag,
} from '../common/cache/cache-keys';
import { FriendsService } from '../friends/friends.service';
@Injectable()
export class DollsService {
@@ -23,16 +33,11 @@ export class DollsService {
constructor(
private readonly prisma: PrismaService,
private readonly eventEmitter: EventEmitter2,
private readonly cacheService: CacheService,
private readonly cacheTagsService: CacheTagsService,
private readonly friendsService: FriendsService,
) {}
async getFriendIds(userId: string): Promise<string[]> {
const friendships = await this.prisma.friendship.findMany({
where: { userId },
select: { friendId: true },
});
return friendships.map((f) => f.friendId);
}
async create(
requestingUserId: string,
createDollDto: CreateDollDto,
@@ -76,6 +81,48 @@ export class DollsService {
async listByOwner(
ownerId: string,
requestingUserId: string,
): Promise<Doll[]> {
const cacheKey = dollsListCacheKey(ownerId, requestingUserId);
const namespacedKey = this.cacheService.getNamespacedKey(
CACHE_NAMESPACE.DOLLS_LIST,
cacheKey,
);
const cached = await this.cacheService.get(namespacedKey);
if (cached) {
try {
return JSON.parse(cached) as Doll[];
} catch (error) {
this.cacheService.recordError('dolls list parse', namespacedKey, error);
}
}
const dolls = await this.listByOwnerFromDatabase(ownerId, requestingUserId);
await this.cacheService.set(
namespacedKey,
JSON.stringify(dolls),
CACHE_TTL_SECONDS.DOLLS_LIST,
);
await Promise.all([
this.cacheTagsService.rememberKeyForTag(
CACHE_NAMESPACE.DOLLS_LIST,
dollsListOwnerTag(ownerId),
cacheKey,
),
this.cacheTagsService.rememberKeyForTag(
CACHE_NAMESPACE.DOLLS_LIST,
dollsListViewerTag(requestingUserId),
cacheKey,
),
]);
return dolls;
}
private async listByOwnerFromDatabase(
ownerId: string,
requestingUserId: string,
): Promise<Doll[]> {
// If requesting own dolls, no need to check friendship
if (ownerId === requestingUserId) {
@@ -91,8 +138,11 @@ export class DollsService {
}
// If requesting someone else's dolls, check friendship
const friendIds = await this.getFriendIds(requestingUserId);
if (!friendIds.includes(ownerId)) {
const isFriend = await this.friendsService.areFriends(
requestingUserId,
ownerId,
);
if (!isFriend) {
throw new ForbiddenException('You are not friends with this user');
}
@@ -108,13 +158,9 @@ export class DollsService {
}
async findOne(id: string, requestingUserId: string): Promise<Doll> {
const friendIds = await this.getFriendIds(requestingUserId);
const accessibleUserIds = [requestingUserId, ...friendIds];
const doll = await this.prisma.doll.findFirst({
where: {
id,
userId: { in: accessibleUserIds },
deletedAt: null,
},
});
@@ -125,6 +171,18 @@ export class DollsService {
);
}
if (doll.userId !== requestingUserId) {
const isFriend = await this.friendsService.areFriends(
requestingUserId,
doll.userId,
);
if (!isFriend) {
throw new NotFoundException(
`Doll with ID ${id} not found or access denied`,
);
}
}
return doll;
}

View File

@@ -0,0 +1,74 @@
import { Injectable } from '@nestjs/common';
import { OnEvent } from '@nestjs/event-emitter';
import { CacheTagsService } from '../common/cache/cache-tags.service';
import {
CACHE_NAMESPACE,
dollsListViewerTag,
friendshipCheckUserTag,
friendsListDependsOnUserTag,
friendsListOwnerTag,
} from '../common/cache/cache-keys';
import { FriendEvents } from './events/friend.events';
import type {
FriendRequestAcceptedEvent,
UnfriendedEvent,
} from './events/friend.events';
@Injectable()
export class FriendsCacheInvalidationService {
constructor(private readonly cacheTagsService: CacheTagsService) {}
@OnEvent(FriendEvents.REQUEST_ACCEPTED)
async handleFriendAccepted(
payload: FriendRequestAcceptedEvent,
): Promise<void> {
const senderId = payload.friendRequest.senderId;
const receiverId = payload.friendRequest.receiverId;
await this.invalidateFriendAndDollViews(senderId, receiverId);
}
@OnEvent(FriendEvents.UNFRIENDED)
async handleUnfriended(payload: UnfriendedEvent): Promise<void> {
await this.invalidateFriendAndDollViews(payload.userId, payload.friendId);
}
private async invalidateFriendAndDollViews(
firstUserId: string,
secondUserId: string,
): Promise<void> {
await Promise.all([
this.cacheTagsService.invalidateTag(
CACHE_NAMESPACE.FRIENDS_LIST,
friendsListOwnerTag(firstUserId),
),
this.cacheTagsService.invalidateTag(
CACHE_NAMESPACE.FRIENDS_LIST,
friendsListOwnerTag(secondUserId),
),
this.cacheTagsService.invalidateTag(
CACHE_NAMESPACE.FRIENDS_LIST,
friendsListDependsOnUserTag(firstUserId),
),
this.cacheTagsService.invalidateTag(
CACHE_NAMESPACE.FRIENDS_LIST,
friendsListDependsOnUserTag(secondUserId),
),
this.cacheTagsService.invalidateTag(
CACHE_NAMESPACE.DOLLS_LIST,
dollsListViewerTag(firstUserId),
),
this.cacheTagsService.invalidateTag(
CACHE_NAMESPACE.DOLLS_LIST,
dollsListViewerTag(secondUserId),
),
this.cacheTagsService.invalidateTag(
CACHE_NAMESPACE.FRIENDSHIP_CHECK,
friendshipCheckUserTag(firstUserId),
),
this.cacheTagsService.invalidateTag(
CACHE_NAMESPACE.FRIENDSHIP_CHECK,
friendshipCheckUserTag(secondUserId),
),
]);
}
}

View File

@@ -1,5 +1,6 @@
import { Module, forwardRef } from '@nestjs/common';
import { FriendsController } from './friends.controller';
import { FriendsCacheInvalidationService } from './friends-cache-invalidation.service';
import { FriendsService } from './friends.service';
import { FriendsNotificationService } from './friends-notification.service';
import { DatabaseModule } from '../database/database.module';
@@ -15,7 +16,11 @@ import { WsModule } from '../ws/ws.module';
forwardRef(() => WsModule),
],
controllers: [FriendsController],
providers: [FriendsService, FriendsNotificationService],
providers: [
FriendsService,
FriendsNotificationService,
FriendsCacheInvalidationService,
],
exports: [FriendsService],
})
export class FriendsModule {}

View File

@@ -2,6 +2,8 @@ import { Test, TestingModule } from '@nestjs/testing';
import { EventEmitter2 } from '@nestjs/event-emitter';
import { FriendsService } from './friends.service';
import { PrismaService } from '../database/prisma.service';
import { CacheService } from '../common/cache/cache.service';
import { CacheTagsService } from '../common/cache/cache-tags.service';
import {
NotFoundException,
BadRequestException,
@@ -17,6 +19,8 @@ enum FriendRequestStatus {
describe('FriendsService', () => {
let service: FriendsService;
let eventEmitter: EventEmitter2;
let cacheService: CacheService;
let cacheTagsService: CacheTagsService;
const mockUser1 = {
id: 'user-1',
@@ -90,6 +94,21 @@ describe('FriendsService', () => {
emit: jest.fn(),
};
const mockCacheService = {
get: jest.fn().mockResolvedValue(null),
set: jest.fn().mockResolvedValue(true),
getNamespacedKey: jest
.fn()
.mockImplementation(
(namespace: string, key: string) => `friendolls:${namespace}:${key}`,
),
recordError: jest.fn(),
};
const mockCacheTagsService = {
rememberKeyForTag: jest.fn().mockResolvedValue(undefined),
};
beforeEach(async () => {
const module: TestingModule = await Test.createTestingModule({
providers: [
@@ -102,11 +121,21 @@ describe('FriendsService', () => {
provide: EventEmitter2,
useValue: mockEventEmitter,
},
{
provide: CacheService,
useValue: mockCacheService,
},
{
provide: CacheTagsService,
useValue: mockCacheTagsService,
},
],
}).compile();
service = module.get<FriendsService>(FriendsService);
eventEmitter = module.get<EventEmitter2>(EventEmitter2);
cacheService = module.get<CacheService>(CacheService);
cacheTagsService = module.get<CacheTagsService>(CacheTagsService);
jest.clearAllMocks();
});
@@ -420,6 +449,8 @@ describe('FriendsService', () => {
createdAt: 'desc',
},
});
expect(cacheService.set).toHaveBeenCalled();
expect(cacheTagsService.rememberKeyForTag).toHaveBeenCalled();
});
});
@@ -469,6 +500,12 @@ describe('FriendsService', () => {
const result = await service.areFriends('user-1', 'user-2');
expect(result).toBe(true);
expect(cacheService.set).toHaveBeenCalledWith(
expect.any(String),
'1',
expect.any(Number),
);
expect(cacheTagsService.rememberKeyForTag).toHaveBeenCalled();
});
it('should return false when users are not friends', async () => {
@@ -477,6 +514,11 @@ describe('FriendsService', () => {
const result = await service.areFriends('user-1', 'user-2');
expect(result).toBe(false);
expect(cacheService.set).toHaveBeenCalledWith(
expect.any(String),
'0',
expect.any(Number),
);
});
});
});

View File

@@ -15,6 +15,17 @@ import {
FriendRequestDeniedEvent,
UnfriendedEvent,
} from './events/friend.events';
import { CacheService } from '../common/cache/cache.service';
import { CacheTagsService } from '../common/cache/cache-tags.service';
import {
CACHE_NAMESPACE,
CACHE_TTL_SECONDS,
friendshipCheckCacheKey,
friendshipCheckUserTag,
friendsListCacheKey,
friendsListDependsOnUserTag,
friendsListOwnerTag,
} from '../common/cache/cache-keys';
export type FriendRequestWithRelations = FriendRequest & {
sender: User;
@@ -28,6 +39,8 @@ export class FriendsService {
constructor(
private readonly prisma: PrismaService,
private readonly eventEmitter: EventEmitter2,
private readonly cacheService: CacheService,
private readonly cacheTagsService: CacheTagsService,
) {}
async sendFriendRequest(
@@ -272,7 +285,28 @@ export class FriendsService {
}
async getFriends(userId: string) {
return this.prisma.friendship.findMany({
const cacheKey = friendsListCacheKey(userId);
const namespacedKey = this.cacheService.getNamespacedKey(
CACHE_NAMESPACE.FRIENDS_LIST,
cacheKey,
);
const cached = await this.cacheService.get(namespacedKey);
if (cached) {
try {
return JSON.parse(cached) as Awaited<
ReturnType<PrismaService['friendship']['findMany']>
>;
} catch (error) {
this.cacheService.recordError(
'friends list parse',
namespacedKey,
error,
);
}
}
const friendships = await this.prisma.friendship.findMany({
where: { userId },
include: {
friend: {
@@ -285,6 +319,29 @@ export class FriendsService {
createdAt: 'desc',
},
});
await this.cacheService.set(
namespacedKey,
JSON.stringify(friendships),
CACHE_TTL_SECONDS.FRIENDS_LIST,
);
const dependentFriendTags = friendships.map((friendship) =>
friendsListDependsOnUserTag(friendship.friendId),
);
const tags = [friendsListOwnerTag(userId), ...dependentFriendTags];
await Promise.all(
tags.map((tag) =>
this.cacheTagsService.rememberKeyForTag(
CACHE_NAMESPACE.FRIENDS_LIST,
tag,
cacheKey,
),
),
);
return friendships;
}
async unfriend(userId: string, friendId: string): Promise<void> {
@@ -323,6 +380,21 @@ export class FriendsService {
}
async areFriends(userId: string, friendId: string): Promise<boolean> {
const cacheKey = friendshipCheckCacheKey(userId, friendId);
const namespacedKey = this.cacheService.getNamespacedKey(
CACHE_NAMESPACE.FRIENDSHIP_CHECK,
cacheKey,
);
const cached = await this.cacheService.get(namespacedKey);
if (cached === '1') {
return true;
}
if (cached === '0') {
return false;
}
const friendship = await this.prisma.friendship.findFirst({
where: {
userId,
@@ -330,6 +402,27 @@ export class FriendsService {
},
});
return !!friendship;
const areFriends = !!friendship;
await this.cacheService.set(
namespacedKey,
areFriends ? '1' : '0',
CACHE_TTL_SECONDS.FRIENDSHIP_CHECK,
);
await Promise.all([
this.cacheTagsService.rememberKeyForTag(
CACHE_NAMESPACE.FRIENDSHIP_CHECK,
friendshipCheckUserTag(userId),
cacheKey,
),
this.cacheTagsService.rememberKeyForTag(
CACHE_NAMESPACE.FRIENDSHIP_CHECK,
friendshipCheckUserTag(friendId),
cacheKey,
),
]);
return areFriends;
}
}

View File

@@ -13,5 +13,6 @@ export type AuthenticatedSocket = BaseSocket<
friends?: Set<string>; // Set of friend user IDs
senderName?: string;
senderNameCachedAt?: number;
lastSeenAt?: number;
}
>;

View File

@@ -2,6 +2,8 @@ import { Doll } from '@prisma/client';
export const UserEvents = {
ACTIVE_DOLL_CHANGED: 'user.active-doll.changed',
SEARCH_INDEX_INVALIDATED: 'user.search-index.invalidated',
PROFILE_UPDATED: 'user.profile.updated',
} as const;
export interface UserActiveDollChangedEvent {
@@ -9,3 +11,11 @@ export interface UserActiveDollChangedEvent {
dollId: string | null;
doll: Doll | null;
}
export interface UserSearchIndexInvalidatedEvent {
userId?: string;
}
export interface UserProfileUpdatedEvent {
userId: string;
}

View File

@@ -0,0 +1,38 @@
import { Injectable } from '@nestjs/common';
import { OnEvent } from '@nestjs/event-emitter';
import { CacheTagsService } from '../common/cache/cache-tags.service';
import {
CACHE_NAMESPACE,
USERS_SEARCH_GLOBAL_TAG,
usersSearchUserTag,
} from '../common/cache/cache-keys';
import { UserEvents } from './events/user.events';
import type { UserSearchIndexInvalidatedEvent } from './events/user.events';
@Injectable()
export class UsersCacheInvalidationService {
constructor(private readonly cacheTagsService: CacheTagsService) {}
@OnEvent(UserEvents.SEARCH_INDEX_INVALIDATED)
async handleSearchIndexInvalidation(
payload: UserSearchIndexInvalidatedEvent,
): Promise<void> {
const tasks: Promise<void>[] = [
this.cacheTagsService.invalidateTag(
CACHE_NAMESPACE.USERS_SEARCH,
USERS_SEARCH_GLOBAL_TAG,
),
];
if (payload.userId) {
tasks.push(
this.cacheTagsService.invalidateTag(
CACHE_NAMESPACE.USERS_SEARCH,
usersSearchUserTag(payload.userId),
),
);
}
await Promise.all(tasks);
}
}

View File

@@ -1,5 +1,6 @@
import { Module, forwardRef } from '@nestjs/common';
import { UsersService } from './users.service';
import { UsersCacheInvalidationService } from './users-cache-invalidation.service';
import { UsersController } from './users.controller';
import { UsersNotificationService } from './users-notification.service';
import { AuthModule } from '../auth/auth.module';
@@ -16,7 +17,11 @@ import { WsModule } from '../ws/ws.module';
*/
@Module({
imports: [forwardRef(() => AuthModule), forwardRef(() => WsModule)],
providers: [UsersService, UsersNotificationService],
providers: [
UsersService,
UsersNotificationService,
UsersCacheInvalidationService,
],
controllers: [UsersController],
exports: [UsersService],
})

View File

@@ -5,9 +5,13 @@ import { NotFoundException, ForbiddenException } from '@nestjs/common';
import { User } from '@prisma/client';
import { UpdateUserDto } from './dto/update-user.dto';
import { EventEmitter2 } from '@nestjs/event-emitter';
import { CacheService } from '../common/cache/cache.service';
import { CacheTagsService } from '../common/cache/cache-tags.service';
describe('UsersService', () => {
let service: UsersService;
let cacheService: CacheService;
let cacheTagsService: CacheTagsService;
const mockUser: User & { passwordHash?: string | null } = {
id: '550e8400-e29b-41d4-a716-446655440000',
@@ -39,6 +43,21 @@ describe('UsersService', () => {
emit: jest.fn(),
};
const mockCacheService = {
get: jest.fn().mockResolvedValue(null),
set: jest.fn().mockResolvedValue(true),
getNamespacedKey: jest
.fn()
.mockImplementation(
(namespace: string, key: string) => `friendolls:${namespace}:${key}`,
),
recordError: jest.fn(),
};
const mockCacheTagsService = {
rememberKeyForTag: jest.fn().mockResolvedValue(undefined),
};
beforeEach(async () => {
const module: TestingModule = await Test.createTestingModule({
providers: [
@@ -51,10 +70,20 @@ describe('UsersService', () => {
provide: EventEmitter2,
useValue: mockEventEmitter,
},
{
provide: CacheService,
useValue: mockCacheService,
},
{
provide: CacheTagsService,
useValue: mockCacheTagsService,
},
],
}).compile();
service = module.get<UsersService>(UsersService);
cacheService = module.get<CacheService>(CacheService);
cacheTagsService = module.get<CacheTagsService>(CacheTagsService);
jest.clearAllMocks();
});
@@ -227,6 +256,8 @@ describe('UsersService', () => {
username: 'asc',
},
});
expect(cacheService.set).toHaveBeenCalled();
expect(cacheTagsService.rememberKeyForTag).toHaveBeenCalled();
});
it('should exclude specified user from results', async () => {

View File

@@ -10,6 +10,15 @@ import { User, Prisma } from '@prisma/client';
import type { UpdateUserDto } from './dto/update-user.dto';
import { UserEvents } from './events/user.events';
import { normalizeEmail } from '../auth/auth.utils';
import { CacheService } from '../common/cache/cache.service';
import { CacheTagsService } from '../common/cache/cache-tags.service';
import {
CACHE_NAMESPACE,
CACHE_TTL_SECONDS,
usersSearchUserTag,
usersSearchCacheKey,
USERS_SEARCH_GLOBAL_TAG,
} from '../common/cache/cache-keys';
export interface CreateLocalUserDto {
email: string;
@@ -30,6 +39,8 @@ export class UsersService {
constructor(
private readonly prisma: PrismaService,
private readonly eventEmitter: EventEmitter2,
private readonly cacheService: CacheService,
private readonly cacheTagsService: CacheTagsService,
) {}
// Legacy Keycloak user creation removed in favor of local auth.
@@ -92,6 +103,9 @@ export class UsersService {
data: updateData,
});
this.eventEmitter.emit(UserEvents.SEARCH_INDEX_INVALIDATED, { userId: id });
this.eventEmitter.emit(UserEvents.PROFILE_UPDATED, { userId: id });
this.logger.log(`User ${id} profile update requested`);
return updatedUser;
@@ -121,6 +135,9 @@ export class UsersService {
where: { id },
});
this.eventEmitter.emit(UserEvents.SEARCH_INDEX_INVALIDATED, { userId: id });
this.eventEmitter.emit(UserEvents.PROFILE_UPDATED, { userId: id });
this.logger.log(`User ${id} deleted their account`);
}
@@ -128,6 +145,25 @@ export class UsersService {
username?: string,
excludeUserId?: string,
): Promise<User[]> {
const cacheKey = usersSearchCacheKey(username, excludeUserId);
const namespacedKey = this.cacheService.getNamespacedKey(
CACHE_NAMESPACE.USERS_SEARCH,
cacheKey,
);
const cached = await this.cacheService.get(namespacedKey);
if (cached) {
try {
return JSON.parse(cached) as User[];
} catch (error) {
this.cacheService.recordError(
'users search parse',
namespacedKey,
error,
);
}
}
const where: Prisma.UserWhereInput = {};
if (username) {
@@ -151,6 +187,24 @@ export class UsersService {
},
});
await this.cacheService.set(
namespacedKey,
JSON.stringify(users),
CACHE_TTL_SECONDS.USERS_SEARCH,
);
await this.cacheTagsService.rememberKeyForTag(
CACHE_NAMESPACE.USERS_SEARCH,
USERS_SEARCH_GLOBAL_TAG,
cacheKey,
);
if (excludeUserId) {
await this.cacheTagsService.rememberKeyForTag(
CACHE_NAMESPACE.USERS_SEARCH,
usersSearchUserTag(excludeUserId),
cacheKey,
);
}
return users;
}
@@ -251,7 +305,7 @@ export class UsersService {
const now = new Date();
const roles: string[] = [];
return this.prisma.user.create({
const user = await this.prisma.user.create({
data: {
email: normalizeEmail(createDto.email),
name: createDto.name,
@@ -262,6 +316,13 @@ export class UsersService {
keycloakSub: null,
} as unknown as Prisma.UserUncheckedCreateInput,
});
this.eventEmitter.emit(UserEvents.SEARCH_INDEX_INVALIDATED, {
userId: user.id,
});
this.eventEmitter.emit(UserEvents.PROFILE_UPDATED, { userId: user.id });
return user;
}
async updatePasswordHash(
@@ -272,6 +333,8 @@ export class UsersService {
where: { id: userId },
data: { passwordHash } as unknown as Prisma.UserUpdateInput,
});
this.eventEmitter.emit(UserEvents.PROFILE_UPDATED, { userId });
}
async updateLastLogin(userId: string): Promise<void> {

View File

@@ -116,7 +116,9 @@ export class ConnectionHandler {
// 3. Register socket mapping (Redis Write)
await this.userSocketService.setSocket(userState.id, client.id);
await this.userSocketService.touchLastSeen(userState.id);
client.data.userId = userState.id;
client.data.lastSeenAt = Date.now();
client.data.activeDollId = userState.activeDollId || null;
client.data.friends = new Set(friends.map((f) => f.friendId));
@@ -149,7 +151,8 @@ export class ConnectionHandler {
// Check if this socket is still the active one for the user
const currentSocketId = await this.userSocketService.getSocket(userId);
if (currentSocketId === client.id) {
await this.userSocketService.removeSocket(userId);
await this.userSocketService.removeSocket(userId, client.id);
await this.userSocketService.touchLastSeen(userId);
// Note: throttling remove is done in gateway
// Notify friends that this user has disconnected
@@ -179,5 +182,7 @@ export class ConnectionHandler {
this.logger.log(
`Client id: ${client.id} disconnected (user: ${user?.userId || 'unknown'})`,
);
await this.userSocketService.removeSocketById(client.id);
}
}

View File

@@ -41,6 +41,7 @@ export class CursorHandler {
// Broadcast to online friends
const friends = client.data.friends;
if (friends) {
await this.broadcaster.touchPresence(client);
const payload = {
userId: currentUserId,
position: data,

View File

@@ -50,6 +50,8 @@ export class InteractionHandler {
client: AuthenticatedSocket,
data: SendInteractionDto,
) {
await this.wsNotificationService.maybeTouchPresence(client);
const user = client.data.user;
const currentUserId = Validator.validateInitialized(client);

View File

@@ -6,6 +6,7 @@ import { JwtVerificationService } from '../../auth/services/jwt-verification.ser
import { PrismaService } from '../../database/prisma.service';
import { UserSocketService } from './user-socket.service';
import { WsNotificationService } from './ws-notification.service';
import { ConfigService } from '@nestjs/config';
import { SendInteractionDto } from '../dto/send-interaction.dto';
import { WsException } from '@nestjs/websockets';
@@ -45,6 +46,7 @@ describe('StateGateway', () => {
let mockUserSocketService: Partial<UserSocketService>;
let mockRedisClient: { publish: jest.Mock };
let mockRedisSubscriber: { subscribe: jest.Mock; on: jest.Mock };
let mockConfigService: { get: jest.Mock };
let mockWsNotificationService: {
setIo: jest.Mock;
emitToUser: jest.Mock;
@@ -52,6 +54,8 @@ describe('StateGateway', () => {
emitToSocket: jest.Mock;
updateActiveDollCache: jest.Mock;
publishActiveDollUpdate: jest.Mock;
clearSenderNameCache: jest.Mock;
maybeTouchPresence: jest.Mock;
};
beforeEach(async () => {
@@ -92,9 +96,12 @@ describe('StateGateway', () => {
mockUserSocketService = {
setSocket: jest.fn().mockResolvedValue(undefined),
removeSocket: jest.fn().mockResolvedValue(undefined),
removeSocketById: jest.fn().mockResolvedValue(undefined),
touchLastSeen: jest.fn().mockResolvedValue(undefined),
getSocket: jest.fn().mockResolvedValue(null),
isUserOnline: jest.fn().mockResolvedValue(false),
getFriendsSockets: jest.fn().mockResolvedValue([]),
cleanupStalePresence: jest.fn().mockResolvedValue(0),
};
mockRedisClient = {
@@ -106,6 +113,10 @@ describe('StateGateway', () => {
on: jest.fn(),
};
mockConfigService = {
get: jest.fn().mockReturnValue(undefined),
};
mockWsNotificationService = {
setIo: jest.fn(),
emitToUser: jest.fn(),
@@ -113,6 +124,8 @@ describe('StateGateway', () => {
emitToSocket: jest.fn(),
updateActiveDollCache: jest.fn(),
publishActiveDollUpdate: jest.fn(),
clearSenderNameCache: jest.fn().mockResolvedValue(undefined),
maybeTouchPresence: jest.fn().mockResolvedValue(undefined),
};
const module: TestingModule = await Test.createTestingModule({
@@ -125,6 +138,7 @@ describe('StateGateway', () => {
{ provide: PrismaService, useValue: mockPrismaService },
{ provide: UserSocketService, useValue: mockUserSocketService },
{ provide: WsNotificationService, useValue: mockWsNotificationService },
{ provide: ConfigService, useValue: mockConfigService },
{ provide: 'REDIS_CLIENT', useValue: mockRedisClient },
{ provide: 'REDIS_SUBSCRIBER_CLIENT', useValue: mockRedisSubscriber },
],
@@ -161,9 +175,32 @@ describe('StateGateway', () => {
expect(mockRedisSubscriber.subscribe).toHaveBeenCalledWith(
'active-doll-update',
'friend-cache-update',
'user-profile-cache-invalidate',
expect.any(Function),
);
});
it('should route user profile cache invalidation messages', async () => {
gateway.afterInit();
const onCalls = (mockRedisSubscriber.on as jest.Mock).mock.calls;
const messageHandler = onCalls.find(
(call) => call[0] === 'message',
)?.[1] as ((channel: string, message: string) => void) | undefined;
expect(messageHandler).toBeDefined();
messageHandler?.(
'user-profile-cache-invalidate',
JSON.stringify({ userId: 'user-1' }),
);
await Promise.resolve();
expect(
mockWsNotificationService.clearSenderNameCache,
).toHaveBeenCalledWith('user-1');
});
});
describe('handleConnection', () => {
@@ -260,6 +297,9 @@ describe('StateGateway', () => {
'user-id',
'client1',
);
expect(mockUserSocketService.touchLastSeen).toHaveBeenCalledWith(
'user-id',
);
// 2. Fetch State (DB)
expect(mockPrismaService.user!.findUnique).toHaveBeenCalledWith({
@@ -359,6 +399,13 @@ describe('StateGateway', () => {
expect(mockUserSocketService.getSocket).toHaveBeenCalledWith('user-id');
expect(mockUserSocketService.removeSocket).toHaveBeenCalledWith(
'user-id',
'client1',
);
expect(mockUserSocketService.touchLastSeen).toHaveBeenCalledWith(
'user-id',
);
expect(mockUserSocketService.removeSocketById).toHaveBeenCalledWith(
'client1',
);
expect(mockWsNotificationService.emitToSocket).toHaveBeenCalledWith(
'friend-socket-id',

View File

@@ -29,6 +29,11 @@ import { InteractionHandler } from './interaction/handler';
import { RedisHandler } from './utils/redis-handler';
import { Broadcaster } from './utils/broadcasting';
import { Throttler } from './utils/throttling';
import { ConfigService } from '@nestjs/config';
import { parsePositiveInteger } from '../../common/config/env.utils';
const DEFAULT_PRESENCE_STALE_AGE_MS = 7 * 24 * 60 * 60 * 1000;
const DEFAULT_PRESENCE_CLEANUP_INTERVAL_MS = 5 * 60 * 1000;
@WebSocketGateway()
export class StateGateway
@@ -49,12 +54,16 @@ export class StateGateway
private readonly cursorHandler: CursorHandler;
private readonly statusHandler: StatusHandler;
private readonly interactionHandler: InteractionHandler;
private readonly presenceStaleAgeMs: number;
private readonly presenceCleanupIntervalMs: number;
private presenceCleanupTimer: NodeJS.Timeout | null = null;
constructor(
private readonly jwtVerificationService: JwtVerificationService,
private readonly prisma: PrismaService,
private readonly userSocketService: UserSocketService,
private readonly wsNotificationService: WsNotificationService,
private readonly configService: ConfigService,
@Inject(REDIS_CLIENT) private readonly redisClient: Redis | null,
@Inject(REDIS_SUBSCRIBER_CLIENT)
private readonly redisSubscriber: Redis | null,
@@ -78,6 +87,14 @@ export class StateGateway
this.userSocketService,
this.wsNotificationService,
);
this.presenceStaleAgeMs = parsePositiveInteger(
this.configService.get<string>('PRESENCE_STALE_AGE_MS'),
DEFAULT_PRESENCE_STALE_AGE_MS,
);
this.presenceCleanupIntervalMs = parsePositiveInteger(
this.configService.get<string>('PRESENCE_CLEANUP_INTERVAL_MS'),
DEFAULT_PRESENCE_CLEANUP_INTERVAL_MS,
);
// Setup Redis subscription for cross-instance communication
if (this.redisSubscriber) {
@@ -85,6 +102,7 @@ export class StateGateway
.subscribe(
REDIS_CHANNEL.ACTIVE_DOLL_UPDATE,
REDIS_CHANNEL.FRIEND_CACHE_UPDATE,
REDIS_CHANNEL.USER_PROFILE_CACHE_INVALIDATE,
(err) => {
if (err) {
this.logger.error(`Failed to subscribe to Redis channels`, err);
@@ -104,6 +122,9 @@ export class StateGateway
} else if (channel === REDIS_CHANNEL.FRIEND_CACHE_UPDATE) {
// eslint-disable-next-line @typescript-eslint/no-floating-promises
this.redisHandler.handleFriendCacheUpdateMessage(message);
} else if (channel === REDIS_CHANNEL.USER_PROFILE_CACHE_INVALIDATE) {
// eslint-disable-next-line @typescript-eslint/no-floating-promises
this.redisHandler.handleUserProfileCacheInvalidateMessage(message);
}
});
}
@@ -112,6 +133,11 @@ export class StateGateway
afterInit() {
this.logger.log('Initialized');
this.wsNotificationService.setIo(this.io);
this.presenceCleanupTimer = setInterval(() => {
void this.cleanupStalePresence();
}, this.presenceCleanupIntervalMs);
this.presenceCleanupTimer.unref();
}
handleConnection(client: AuthenticatedSocket) {
@@ -152,6 +178,7 @@ export class StateGateway
await this.statusHandler.handleClientReportUserStatus(client, data);
}
@SubscribeMessage(WS_EVENT.CLIENT_SEND_INTERACTION)
async handleSendInteraction(
client: AuthenticatedSocket,
data: SendInteractionDto,
@@ -163,5 +190,18 @@ export class StateGateway
if (this.redisSubscriber) {
this.redisSubscriber.removeAllListeners('message');
}
if (this.presenceCleanupTimer) {
clearInterval(this.presenceCleanupTimer);
this.presenceCleanupTimer = null;
}
}
private async cleanupStalePresence(): Promise<void> {
const cutoffMs = Date.now() - this.presenceStaleAgeMs;
const removed = await this.userSocketService.cleanupStalePresence(cutoffMs);
if (removed > 0) {
this.logger.debug(`Cleaned up ${removed} stale presence entries`);
}
}
}

View File

@@ -47,6 +47,7 @@ export class StatusHandler {
const friends = client.data.friends;
if (friends) {
try {
await this.broadcaster.touchPresence(client);
const payload = {
userId: currentUserId,
status: data,

View File

@@ -2,12 +2,76 @@ import { Injectable, Inject, Logger } from '@nestjs/common';
import { REDIS_CLIENT } from '../../database/redis.module';
import Redis from 'ioredis';
const SOCKET_KEY_PREFIX = 'socket:user:';
const SOCKET_REVERSE_KEY_PREFIX = 'socket:id:';
const LAST_SEEN_KEY_PREFIX = 'presence:last-seen:';
const PRESENCE_ZSET_KEY = 'presence:last-seen:zset';
const SET_SOCKET_MAPPING_SCRIPT = `
local userKey = KEYS[1]
local reverseKey = KEYS[2]
local userId = ARGV[1]
local socketId = ARGV[2]
local ttl = ARGV[3]
local reversePrefix = ARGV[4]
local previousSocketId = redis.call('GET', userKey)
redis.call('SET', userKey, socketId, 'EX', ttl)
redis.call('SET', reverseKey, userId, 'EX', ttl)
if previousSocketId and previousSocketId ~= socketId then
redis.call('DEL', reversePrefix .. previousSocketId)
end
return 1
`;
const REMOVE_SOCKET_MAPPING_SCRIPT = `
local userKey = KEYS[1]
local reversePrefix = ARGV[1]
local expectedSocketId = ARGV[2]
local currentSocketId = redis.call('GET', userKey)
if not currentSocketId then
return 0
end
if expectedSocketId ~= '' and currentSocketId ~= expectedSocketId then
return 0
end
redis.call('DEL', userKey)
redis.call('DEL', reversePrefix .. currentSocketId)
return 1
`;
const REMOVE_BY_SOCKET_ID_SCRIPT = `
local reverseKey = KEYS[1]
local userPrefix = ARGV[1]
local socketId = ARGV[2]
local userId = redis.call('GET', reverseKey)
if not userId then
return 0
end
local userKey = userPrefix .. userId
local currentSocketId = redis.call('GET', userKey)
redis.call('DEL', reverseKey)
if currentSocketId == socketId then
redis.call('DEL', userKey)
end
return 1
`;
@Injectable()
export class UserSocketService {
private readonly logger = new Logger(UserSocketService.name);
private localUserSocketMap: Map<string, string> = new Map();
private readonly PREFIX = 'socket:user:';
private readonly TTL = 86400; // 24 hours
private readonly LAST_SEEN_TTL_SECONDS = 604800; // 7 days
constructor(
@Inject(REDIS_CLIENT) private readonly redisClient: Redis | null,
@@ -16,11 +80,15 @@ export class UserSocketService {
async setSocket(userId: string, socketId: string): Promise<void> {
if (this.redisClient) {
try {
await this.redisClient.set(
`${this.PREFIX}${userId}`,
await this.redisClient.eval(
SET_SOCKET_MAPPING_SCRIPT,
2,
`${SOCKET_KEY_PREFIX}${userId}`,
`${SOCKET_REVERSE_KEY_PREFIX}${socketId}`,
userId,
socketId,
'EX',
this.TTL,
String(this.TTL),
SOCKET_REVERSE_KEY_PREFIX,
);
} catch (error) {
this.logger.error(
@@ -36,10 +104,16 @@ export class UserSocketService {
}
}
async removeSocket(userId: string): Promise<void> {
async removeSocket(userId: string, expectedSocketId?: string): Promise<void> {
if (this.redisClient) {
try {
await this.redisClient.del(`${this.PREFIX}${userId}`);
await this.redisClient.eval(
REMOVE_SOCKET_MAPPING_SCRIPT,
1,
`${SOCKET_KEY_PREFIX}${userId}`,
SOCKET_REVERSE_KEY_PREFIX,
expectedSocketId || '',
);
} catch (error) {
this.logger.error(
`Failed to remove socket for user ${userId} from Redis`,
@@ -47,13 +121,23 @@ export class UserSocketService {
);
}
}
if (!expectedSocketId) {
this.localUserSocketMap.delete(userId);
return;
}
const currentLocalSocketId = this.localUserSocketMap.get(userId);
if (currentLocalSocketId === expectedSocketId) {
this.localUserSocketMap.delete(userId);
}
}
async getSocket(userId: string): Promise<string | null> {
if (this.redisClient) {
try {
const socketId = await this.redisClient.get(`${this.PREFIX}${userId}`);
const socketId = await this.redisClient.get(
`${SOCKET_KEY_PREFIX}${userId}`,
);
return socketId;
} catch (error) {
this.logger.error(
@@ -82,7 +166,7 @@ export class UserSocketService {
try {
// Use pipeline for batch fetching
const pipeline = this.redisClient.pipeline();
friendIds.forEach((id) => pipeline.get(`${this.PREFIX}${id}`));
friendIds.forEach((id) => pipeline.get(`${SOCKET_KEY_PREFIX}${id}`));
const results = await pipeline.exec();
const sockets: { userId: string; socketId: string }[] = [];
@@ -115,4 +199,79 @@ export class UserSocketService {
}
return sockets;
}
async touchLastSeen(userId: string): Promise<void> {
const now = Date.now();
if (this.redisClient) {
try {
const key = `${LAST_SEEN_KEY_PREFIX}${userId}`;
await this.redisClient.set(
key,
String(now),
'EX',
this.LAST_SEEN_TTL_SECONDS,
);
await this.redisClient.zadd(PRESENCE_ZSET_KEY, now, userId);
return;
} catch (error) {
this.logger.warn(
`Failed to touch last-seen for user ${userId} in Redis`,
error as Error,
);
}
}
}
async removeSocketById(socketId: string): Promise<void> {
if (!this.redisClient) {
return;
}
try {
await this.redisClient.eval(
REMOVE_BY_SOCKET_ID_SCRIPT,
1,
`${SOCKET_REVERSE_KEY_PREFIX}${socketId}`,
SOCKET_KEY_PREFIX,
socketId,
);
} catch (error) {
this.logger.warn(
`Failed to remove socket mapping by socket id ${socketId}`,
error as Error,
);
}
}
async cleanupStalePresence(cutoffMs: number): Promise<number> {
if (!this.redisClient) {
return 0;
}
try {
const staleUserIds = await this.redisClient.zrangebyscore(
PRESENCE_ZSET_KEY,
'-inf',
cutoffMs,
);
if (staleUserIds.length === 0) {
return 0;
}
const pipeline = this.redisClient.pipeline();
staleUserIds.forEach((userId) => {
pipeline.del(`${LAST_SEEN_KEY_PREFIX}${userId}`);
});
pipeline.zremrangebyscore(PRESENCE_ZSET_KEY, '-inf', cutoffMs);
await pipeline.exec();
return staleUserIds.length;
} catch (error) {
this.logger.warn(
'Failed to cleanup stale presence entries',
error as Error,
);
return 0;
}
}
}

View File

@@ -1,5 +1,6 @@
import { UserSocketService } from '../user-socket.service';
import { WsNotificationService } from '../ws-notification.service';
import type { AuthenticatedSocket } from '../../../types/socket';
export class Broadcaster {
constructor(
@@ -7,6 +8,10 @@ export class Broadcaster {
private readonly wsNotificationService: WsNotificationService,
) {}
async touchPresence(client: AuthenticatedSocket) {
await this.wsNotificationService.maybeTouchPresence(client);
}
async broadcastToFriends(friends: Set<string>, event: string, payload: any) {
const friendIds = Array.from(friends);
const friendSockets =

View File

@@ -36,4 +36,20 @@ export class RedisHandler {
this.logger.error('Error handling friend cache update message', error);
}
}
async handleUserProfileCacheInvalidateMessage(
message: string,
): Promise<void> {
try {
const data = JSON.parse(message) as {
userId: string;
};
await this.wsNotificationService.clearSenderNameCache(data.userId);
} catch (error) {
this.logger.error(
'Error handling user profile cache invalidate message',
error,
);
}
}
}

View File

@@ -22,4 +22,5 @@ export const WS_EVENT = {
export const REDIS_CHANNEL = {
ACTIVE_DOLL_UPDATE: 'active-doll-update',
FRIEND_CACHE_UPDATE: 'friend-cache-update',
USER_PROFILE_CACHE_INVALIDATE: 'user-profile-cache-invalidate',
} as const;

View File

@@ -1,10 +1,14 @@
import { Injectable, Logger, Inject } from '@nestjs/common';
import { Inject, Injectable, Logger } from '@nestjs/common';
import { OnEvent } from '@nestjs/event-emitter';
import Redis from 'ioredis';
import { Server } from 'socket.io';
import { UserSocketService } from './user-socket.service';
import { UserEvents } from '../../users/events/user.events';
import type { AuthenticatedSocket } from '../../types/socket';
import { REDIS_CLIENT } from '../../database/redis.module';
import { REDIS_CHANNEL } from './ws-events';
import { UserSocketService } from './user-socket.service';
const PRESENCE_UPDATE_THROTTLE_MS = 15_000;
@Injectable()
export class WsNotificationService {
@@ -42,6 +46,11 @@ export class WsNotificationService {
this.io.to(socketId).emit(event, payload);
}
@OnEvent(UserEvents.PROFILE_UPDATED)
async handleUserProfileUpdated(payload: { userId: string }) {
await this.publishUserProfileCacheInvalidate(payload.userId);
}
async updateFriendsCache(
userId: string,
friendId: string,
@@ -126,4 +135,63 @@ export class WsNotificationService {
);
}
}
async publishUserProfileCacheInvalidate(userId: string) {
if (this.redisClient) {
try {
await this.redisClient.publish(
REDIS_CHANNEL.USER_PROFILE_CACHE_INVALIDATE,
JSON.stringify({ userId }),
);
return;
} catch (error) {
this.logger.warn(
'Redis publish failed for user profile cache invalidate; applying local update only',
error as Error,
);
}
}
await this.clearSenderNameCache(userId);
}
async clearSenderNameCache(userId: string) {
if (!this.io) {
return;
}
const socketId = await this.userSocketService.getSocket(userId);
if (!socketId) {
return;
}
const socket = this.io.sockets.sockets.get(socketId) as
| AuthenticatedSocket
| undefined;
if (!socket?.data) {
return;
}
socket.data.senderName = undefined;
socket.data.senderNameCachedAt = undefined;
}
async maybeTouchPresence(client: AuthenticatedSocket): Promise<void> {
const userId = client.data.userId;
if (!userId) {
return;
}
const now = Date.now();
const lastSeenAt = client.data.lastSeenAt;
if (
typeof lastSeenAt === 'number' &&
now - lastSeenAt < PRESENCE_UPDATE_THROTTLE_MS
) {
return;
}
client.data.lastSeenAt = now;
await this.userSocketService.touchLastSeen(userId);
}
}