Compare commits

..

2 Commits

Author SHA1 Message Date
db747c4f7a hopefully docker fix 2026-03-29 17:45:17 +08:00
c88bb5d2c4 production hardening 2026-03-29 01:51:42 +08:00
42 changed files with 158 additions and 1818 deletions

12
.dockerignore Normal file
View File

@@ -0,0 +1,12 @@
.git
.github
.node_modules
dist
coverage
*.log
.env*
test
*.spec.ts
*.e2e-spec.ts
README.md
AGENTS.md

View File

@@ -12,17 +12,6 @@ REDIS_PORT=6379
REDIS_REQUIRED=false REDIS_REQUIRED=false
REDIS_CONNECT_TIMEOUT_MS=5000 REDIS_CONNECT_TIMEOUT_MS=5000
REDIS_STARTUP_RETRIES=10 REDIS_STARTUP_RETRIES=10
# Stale presence cleanup threshold and interval
PRESENCE_STALE_AGE_MS=604800000
PRESENCE_CLEANUP_INTERVAL_MS=300000
# Cache
CACHE_KEY_PREFIX=friendolls
CACHE_DEFAULT_TTL_SECONDS=60
CACHE_MAX_TTL_SECONDS=86400
CACHE_METRICS_LOG_INTERVAL_MS=60000
# Max number of cache keys tracked per invalidation tag
CACHE_TAG_MAX_ENTRIES=5000
# JWT Configuration # JWT Configuration
JWT_SECRET=replace-with-strong-random-secret JWT_SECRET=replace-with-strong-random-secret
@@ -35,10 +24,6 @@ AUTH_CLEANUP_ENABLED=true
AUTH_CLEANUP_INTERVAL_MS=900000 AUTH_CLEANUP_INTERVAL_MS=900000
AUTH_SESSION_REVOKED_RETENTION_DAYS=7 AUTH_SESSION_REVOKED_RETENTION_DAYS=7
# Rate limiting
THROTTLE_TTL=1000
THROTTLE_LIMIT=5
# Google OAuth # Google OAuth
GOOGLE_CLIENT_ID="replace-with-google-client-id" GOOGLE_CLIENT_ID="replace-with-google-client-id"
GOOGLE_CLIENT_SECRET="replace-with-google-client-secret" GOOGLE_CLIENT_SECRET="replace-with-google-client-secret"

View File

@@ -8,9 +8,9 @@ Backend server for Friendolls.
## Commands ## Commands
- **Lint/Check for errors**: `pnpm lint` - **Error Checks**: `pnpm check`
- **Format**: `pnpm format` - **Format/Lint**: `pnpm format`, `pnpm lint`
- **Test**: `pnpm test` - **Test**: `pnpm test` (Unit), `pnpm test:e2e` (E2E)
- **Single Test**: `pnpm test -- -t "test name"` or `pnpm test -- src/path/to/file.spec.ts` - **Single Test**: `pnpm test -- -t "test name"` or `pnpm test -- src/path/to/file.spec.ts`
- **Database**: `npx prisma generate`, `npx prisma migrate dev` - **Database**: `npx prisma generate`, `npx prisma migrate dev`
@@ -29,4 +29,4 @@ Backend server for Friendolls.
## Note ## Note
Do not run the project yourself. Run lints and tests to detect issues after each final changes. Do not run the project yourself. Run error checks and lints to detect issues.

View File

@@ -8,9 +8,11 @@ RUN pnpm build
FROM node:20-alpine FROM node:20-alpine
WORKDIR /app WORKDIR /app
COPY --from=builder /app/dist ./dist RUN npm i -g pnpm
COPY --from=builder /app/node_modules ./node_modules COPY --from=builder /app/package.json ./package.json
COPY --from=builder /app/pnpm-lock.yaml ./pnpm-lock.yaml
COPY --from=builder /app/prisma ./prisma COPY --from=builder /app/prisma ./prisma
COPY --from=builder /app/prisma.config.ts ./prisma.config.ts COPY --from=builder /app/prisma.config.ts ./prisma.config.ts
COPY --from=builder /app/package.json ./package.json COPY --from=builder /app/dist ./dist
RUN pnpm install --prod --frozen-lockfile
CMD ["node", "dist/src/main.js"] CMD ["node", "dist/src/main.js"]

View File

@@ -47,9 +47,9 @@
"dotenv": "^17.2.3", "dotenv": "^17.2.3",
"ioredis": "^5.8.2", "ioredis": "^5.8.2",
"jsonwebtoken": "^9.0.2", "jsonwebtoken": "^9.0.2",
"helmet": "^8.1.0",
"passport": "^0.7.0", "passport": "^0.7.0",
"passport-discord": "^0.1.4", "passport-discord": "^0.1.4",
"helmet": "^8.1.0",
"passport-google-oauth20": "^2.0.0", "passport-google-oauth20": "^2.0.0",
"passport-jwt": "^4.0.1", "passport-jwt": "^4.0.1",
"pg": "^8.16.3", "pg": "^8.16.3",

View File

@@ -3,6 +3,7 @@
generator client { generator client {
provider = "prisma-client-js" provider = "prisma-client-js"
output = "../node_modules/.prisma/client"
} }
datasource db { datasource db {

View File

@@ -5,7 +5,6 @@ import { EventEmitterModule } from '@nestjs/event-emitter';
import { ThrottlerGuard, ThrottlerModule } from '@nestjs/throttler'; import { ThrottlerGuard, ThrottlerModule } from '@nestjs/throttler';
import { AppController } from './app.controller'; import { AppController } from './app.controller';
import { AppService } from './app.service'; import { AppService } from './app.service';
import { CacheModule, RedisThrottlerStorage } from './common/cache';
import { UsersModule } from './users/users.module'; import { UsersModule } from './users/users.module';
import { AuthModule } from './auth/auth.module'; import { AuthModule } from './auth/auth.module';
import { DatabaseModule } from './database/database.module'; import { DatabaseModule } from './database/database.module';
@@ -13,10 +12,7 @@ import { RedisModule } from './database/redis.module';
import { WsModule } from './ws/ws.module'; import { WsModule } from './ws/ws.module';
import { FriendsModule } from './friends/friends.module'; import { FriendsModule } from './friends/friends.module';
import { DollsModule } from './dolls/dolls.module'; import { DollsModule } from './dolls/dolls.module';
import { import { parseRedisRequired } from './common/config/env.utils';
parsePositiveInteger,
parseRedisRequired,
} from './common/config/env.utils';
/** /**
* Validates required environment variables. * Validates required environment variables.
@@ -45,7 +41,7 @@ function validateEnvironment(
} }
// Validate PORT if provided // Validate PORT if provided
if (config.PORT !== undefined && !Number.isFinite(Number(config.PORT))) { if (config.PORT && isNaN(Number(config.PORT))) {
throw new Error('PORT must be a valid number'); throw new Error('PORT must be a valid number');
} }
@@ -83,15 +79,6 @@ function validateEnvironment(
throw new Error('REDIS_CONNECT_TIMEOUT_MS must be a positive number'); throw new Error('REDIS_CONNECT_TIMEOUT_MS must be a positive number');
} }
validateOptionalPositiveNumber(config, 'THROTTLE_TTL');
validateOptionalPositiveNumber(config, 'THROTTLE_LIMIT');
validateOptionalPositiveNumber(config, 'CACHE_DEFAULT_TTL_SECONDS');
validateOptionalPositiveNumber(config, 'CACHE_MAX_TTL_SECONDS');
validateOptionalPositiveNumber(config, 'CACHE_METRICS_LOG_INTERVAL_MS');
validateOptionalPositiveNumber(config, 'CACHE_TAG_MAX_ENTRIES');
validateOptionalPositiveNumber(config, 'PRESENCE_STALE_AGE_MS');
validateOptionalPositiveNumber(config, 'PRESENCE_CLEANUP_INTERVAL_MS');
validateOptionalProvider(config, 'GOOGLE'); validateOptionalProvider(config, 'GOOGLE');
validateOptionalProvider(config, 'DISCORD'); validateOptionalProvider(config, 'DISCORD');
@@ -118,20 +105,6 @@ function validateOptionalProvider(
} }
} }
function validateOptionalPositiveNumber(
config: Record<string, unknown>,
key: string,
): void {
const value = config[key];
if (value === undefined || value === null || value === '') {
return;
}
if (!Number.isFinite(Number(value)) || Number(value) <= 0) {
throw new Error(`${key} must be a positive number`);
}
}
/** /**
* Root Application Module * Root Application Module
* *
@@ -144,33 +117,15 @@ function validateOptionalPositiveNumber(
envFilePath: '.env', envFilePath: '.env',
validate: validateEnvironment, validate: validateEnvironment,
}), }),
CacheModule,
ThrottlerModule.forRootAsync({ ThrottlerModule.forRootAsync({
imports: [ConfigModule, CacheModule], imports: [ConfigModule],
inject: [ConfigService, RedisThrottlerStorage], inject: [ConfigService],
useFactory: ( useFactory: (config: ConfigService) => [
config: ConfigService,
redisThrottlerStorage: RedisThrottlerStorage,
) => {
const ttl = parsePositiveInteger(
config.get<string>('THROTTLE_TTL'),
1000,
);
const limit = parsePositiveInteger(
config.get<string>('THROTTLE_LIMIT'),
5,
);
return {
storage: redisThrottlerStorage,
throttlers: [
{ {
ttl, ttl: config.get('THROTTLE_TTL', 1000),
limit, limit: config.get('THROTTLE_LIMIT', 5),
}, },
], ],
};
},
}), }),
EventEmitterModule.forRoot(), EventEmitterModule.forRoot(),
DatabaseModule, DatabaseModule,

View File

@@ -4,11 +4,8 @@ import {
UnauthorizedException, UnauthorizedException,
} from '@nestjs/common'; } from '@nestjs/common';
import { ConfigService } from '@nestjs/config'; import { ConfigService } from '@nestjs/config';
import { EventEmitter2 } from '@nestjs/event-emitter';
import { Test, TestingModule } from '@nestjs/testing'; import { Test, TestingModule } from '@nestjs/testing';
import { decode, sign } from 'jsonwebtoken'; import { decode, sign } from 'jsonwebtoken';
import { CacheService } from '../common/cache/cache.service';
import { CacheTagsService } from '../common/cache/cache-tags.service';
import { PrismaService } from '../database/prisma.service'; import { PrismaService } from '../database/prisma.service';
import { AuthService } from './auth.service'; import { AuthService } from './auth.service';
import { sha256 } from './auth.utils'; import { sha256 } from './auth.utils';
@@ -61,27 +58,6 @@ describe('AuthService', () => {
$transaction: jest.fn(), $transaction: jest.fn(),
}; };
const mockEventEmitter = {
emit: jest.fn(),
};
const mockCacheService = {
get: jest.fn().mockResolvedValue(null),
set: jest.fn().mockResolvedValue(true),
del: jest.fn().mockResolvedValue(true),
getNamespacedKey: jest
.fn()
.mockImplementation(
(namespace: string, key: string) => `friendolls:${namespace}:${key}`,
),
recordError: jest.fn(),
};
const mockCacheTagsService = {
rememberKeyForTag: jest.fn().mockResolvedValue(undefined),
invalidateTag: jest.fn().mockResolvedValue(undefined),
};
const socialProfile: SocialAuthProfile = { const socialProfile: SocialAuthProfile = {
provider: 'google', provider: 'google',
providerSubject: 'google-user-123', providerSubject: 'google-user-123',
@@ -118,9 +94,6 @@ describe('AuthService', () => {
AuthService, AuthService,
{ provide: PrismaService, useValue: mockPrismaService }, { provide: PrismaService, useValue: mockPrismaService },
{ provide: ConfigService, useValue: mockConfigService }, { provide: ConfigService, useValue: mockConfigService },
{ provide: EventEmitter2, useValue: mockEventEmitter },
{ provide: CacheService, useValue: mockCacheService },
{ provide: CacheTagsService, useValue: mockCacheTagsService },
], ],
}).compile(); }).compile();
@@ -162,9 +135,6 @@ describe('AuthService', () => {
const localService = new AuthService( const localService = new AuthService(
mockPrismaService as unknown as PrismaService, mockPrismaService as unknown as PrismaService,
mockConfigService as unknown as ConfigService, mockConfigService as unknown as ConfigService,
mockEventEmitter as unknown as EventEmitter2,
mockCacheService as unknown as CacheService,
mockCacheTagsService as unknown as CacheTagsService,
); );
expect(() => expect(() =>

View File

@@ -13,7 +13,6 @@ import {
verify, verify,
} from 'jsonwebtoken'; } from 'jsonwebtoken';
import { PrismaService } from '../database/prisma.service'; import { PrismaService } from '../database/prisma.service';
import { EventEmitter2 } from '@nestjs/event-emitter';
import type { SocialAuthProfile } from './types/social-auth-profile'; import type { SocialAuthProfile } from './types/social-auth-profile';
import type { import type {
AuthTokens, AuthTokens,
@@ -36,15 +35,6 @@ import {
usernameFromEmail, usernameFromEmail,
} from './auth.utils'; } from './auth.utils';
import type { SsoProvider } from './dto/sso-provider'; import type { SsoProvider } from './dto/sso-provider';
import { UserEvents } from '../users/events/user.events';
import { CacheService } from '../common/cache/cache.service';
import {
authSessionUserTag,
authSessionCacheKey,
CACHE_NAMESPACE,
CACHE_TTL_SECONDS,
} from '../common/cache/cache-keys';
import { CacheTagsService } from '../common/cache/cache-tags.service';
interface SsoStateClaims { interface SsoStateClaims {
provider: SsoProvider; provider: SsoProvider;
@@ -53,28 +43,6 @@ interface SsoStateClaims {
typ: 'sso_state'; typ: 'sso_state';
} }
interface AuthSessionWithUser {
id: string;
refresh_token_hash: string;
expires_at: Date;
revoked_at: Date | null;
provider: 'GOOGLE' | 'DISCORD' | null;
user_id: string;
email: string;
roles: string[];
}
interface CachedAuthSessionWithUser {
id: string;
refresh_token_hash: string;
expires_at: string;
revoked_at: string | null;
provider: 'GOOGLE' | 'DISCORD' | null;
user_id: string;
email: string;
roles: string[];
}
@Injectable() @Injectable()
export class AuthService { export class AuthService {
private readonly logger = new Logger(AuthService.name); private readonly logger = new Logger(AuthService.name);
@@ -88,9 +56,6 @@ export class AuthService {
constructor( constructor(
private readonly prisma: PrismaService, private readonly prisma: PrismaService,
private readonly configService: ConfigService, private readonly configService: ConfigService,
private readonly eventEmitter: EventEmitter2,
private readonly cacheService: CacheService,
private readonly cacheTagsService: CacheTagsService,
) { ) {
this.jwtSecret = this.configService.get<string>('JWT_SECRET') || ''; this.jwtSecret = this.configService.get<string>('JWT_SECRET') || '';
this.jwtIssuer = this.jwtIssuer =
@@ -194,7 +159,7 @@ export class AuthService {
} }
if (session.refresh_token_hash !== refreshTokenHash) { if (session.refresh_token_hash !== refreshTokenHash) {
await this.revokeSessionOnReplay(session.id, session.user_id); await this.revokeSessionOnReplay(session.id);
throw new UnauthorizedException('Invalid refresh token'); throw new UnauthorizedException('Invalid refresh token');
} }
@@ -206,7 +171,7 @@ export class AuthService {
); );
if (!updated) { if (!updated) {
await this.revokeSessionOnReplay(session.id, session.user_id); await this.revokeSessionOnReplay(session.id);
throw new UnauthorizedException('Invalid refresh token'); throw new UnauthorizedException('Invalid refresh token');
} }
@@ -289,11 +254,6 @@ export class AuthService {
}, },
}); });
this.eventEmitter.emit(UserEvents.SEARCH_INDEX_INVALIDATED, {
userId: user.id,
});
this.eventEmitter.emit(UserEvents.PROFILE_UPDATED, { userId: user.id });
return user; return user;
} }
@@ -313,7 +273,7 @@ export class AuthService {
); );
} }
const user = await this.prisma.$transaction(async (tx) => { return this.prisma.$transaction(async (tx) => {
let user = await tx.user.findUnique({ let user = await tx.user.findUnique({
where: { email }, where: { email },
}); });
@@ -351,13 +311,6 @@ export class AuthService {
return user; return user;
}); });
this.eventEmitter.emit(UserEvents.SEARCH_INDEX_INVALIDATED, {
userId: user.id,
});
this.eventEmitter.emit(UserEvents.PROFILE_UPDATED, { userId: user.id });
return user;
} }
private async resolveUsername( private async resolveUsername(
@@ -594,34 +547,28 @@ export class AuthService {
return rows[0] ?? null; return rows[0] ?? null;
} }
private async getSessionWithUser( private async getSessionWithUser(sessionId: string): Promise<{
sessionId: string, id: string;
): Promise<AuthSessionWithUser | null> { refresh_token_hash: string;
const sessionCacheKey = this.getAuthSessionCacheKey(sessionId); expires_at: Date;
const cachedSessionRaw = await this.cacheService.get(sessionCacheKey); revoked_at: Date | null;
provider: 'GOOGLE' | 'DISCORD' | null;
if (cachedSessionRaw) { user_id: string;
try { email: string;
const cachedSession = JSON.parse( roles: string[];
cachedSessionRaw, } | null> {
) as CachedAuthSessionWithUser; const rows = await this.prisma.$queryRaw<
return { Array<{
...cachedSession, id: string;
expires_at: new Date(cachedSession.expires_at), refresh_token_hash: string;
revoked_at: cachedSession.revoked_at expires_at: Date;
? new Date(cachedSession.revoked_at) revoked_at: Date | null;
: null, provider: 'GOOGLE' | 'DISCORD' | null;
}; user_id: string;
} catch (error) { email: string;
this.cacheService.recordError( roles: string[];
'auth session parse', }>
sessionCacheKey, >`
error,
);
}
}
const rows = await this.prisma.$queryRaw<Array<AuthSessionWithUser>>`
SELECT s.id, s.refresh_token_hash, s.expires_at, s.revoked_at, s.provider, s.user_id, u.email, u.roles SELECT s.id, s.refresh_token_hash, s.expires_at, s.revoked_at, s.provider, s.user_id, u.email, u.roles
FROM auth_sessions AS s FROM auth_sessions AS s
INNER JOIN users AS u ON u.id = s.user_id INNER JOIN users AS u ON u.id = s.user_id
@@ -629,29 +576,7 @@ export class AuthService {
LIMIT 1 LIMIT 1
`; `;
const session = rows[0] ?? null; return rows[0] ?? null;
if (!session) {
return null;
}
const cachePayload: CachedAuthSessionWithUser = {
...session,
expires_at: session.expires_at.toISOString(),
revoked_at: session.revoked_at ? session.revoked_at.toISOString() : null,
};
await this.cacheService.set(
sessionCacheKey,
JSON.stringify(cachePayload),
CACHE_TTL_SECONDS.AUTH_SESSION,
);
await this.cacheTagsService.rememberKeyForTag(
CACHE_NAMESPACE.AUTH_SESSION,
authSessionUserTag(session.user_id),
authSessionCacheKey(session.id),
);
return session;
} }
private async rotateRefreshSession( private async rotateRefreshSession(
@@ -659,8 +584,6 @@ export class AuthService {
refreshTokenHash: string, refreshTokenHash: string,
nextRefreshToken: string, nextRefreshToken: string,
): Promise<boolean> { ): Promise<boolean> {
await this.cacheService.del(this.getAuthSessionCacheKey(sessionId));
const rows = await this.prisma.$queryRaw<Array<{ id: string }>>` const rows = await this.prisma.$queryRaw<Array<{ id: string }>>`
UPDATE auth_sessions UPDATE auth_sessions
SET refresh_token_hash = ${sha256(nextRefreshToken)}, SET refresh_token_hash = ${sha256(nextRefreshToken)},
@@ -674,10 +597,6 @@ export class AuthService {
RETURNING id RETURNING id
`; `;
if (rows.length === 1) {
await this.cacheService.del(this.getAuthSessionCacheKey(sessionId));
}
return rows.length === 1; return rows.length === 1;
} }
@@ -685,8 +604,6 @@ export class AuthService {
sessionId: string, sessionId: string,
refreshTokenHash: string, refreshTokenHash: string,
): Promise<boolean> { ): Promise<boolean> {
await this.cacheService.del(this.getAuthSessionCacheKey(sessionId));
const rows = await this.prisma.$queryRaw<Array<{ id: string }>>` const rows = await this.prisma.$queryRaw<Array<{ id: string }>>`
UPDATE auth_sessions UPDATE auth_sessions
SET revoked_at = NOW(), SET revoked_at = NOW(),
@@ -698,41 +615,17 @@ export class AuthService {
RETURNING id RETURNING id
`; `;
if (rows.length === 1) {
await this.cacheService.del(this.getAuthSessionCacheKey(sessionId));
}
return rows.length === 1; return rows.length === 1;
} }
private async revokeSessionOnReplay( private async revokeSessionOnReplay(sessionId: string): Promise<void> {
sessionId: string,
userId: string,
): Promise<void> {
await this.cacheService.del(this.getAuthSessionCacheKey(sessionId));
await this.revokeAllUserSessions(userId);
}
private async revokeAllUserSessions(userId: string): Promise<void> {
await this.prisma.$queryRaw<Array<{ id: string }>>` await this.prisma.$queryRaw<Array<{ id: string }>>`
UPDATE auth_sessions UPDATE auth_sessions
SET revoked_at = NOW(), SET revoked_at = NOW(),
updated_at = NOW() updated_at = NOW()
WHERE user_id = ${userId} WHERE id = ${sessionId}
AND revoked_at IS NULL AND revoked_at IS NULL
RETURNING id RETURNING id
`; `;
await this.cacheTagsService.invalidateTag(
CACHE_NAMESPACE.AUTH_SESSION,
authSessionUserTag(userId),
);
}
private getAuthSessionCacheKey(sessionId: string): string {
return this.cacheService.getNamespacedKey(
CACHE_NAMESPACE.AUTH_SESSION,
authSessionCacheKey(sessionId),
);
} }
} }

View File

@@ -20,14 +20,6 @@ const DEFAULT_REVOKED_RETENTION_DAYS = 7;
const CLEANUP_LOCK_KEY = 'lock:auth:cleanup'; const CLEANUP_LOCK_KEY = 'lock:auth:cleanup';
const CLEANUP_LOCK_TTL_MS = 55_000; const CLEANUP_LOCK_TTL_MS = 55_000;
const RELEASE_LOCK_SCRIPT = `
if redis.call("get", KEYS[1]) == ARGV[1] then
return redis.call("del", KEYS[1])
else
return 0
end
`;
@Injectable() @Injectable()
export class AuthCleanupService implements OnModuleInit, OnModuleDestroy { export class AuthCleanupService implements OnModuleInit, OnModuleDestroy {
private readonly logger = new Logger(AuthCleanupService.name); private readonly logger = new Logger(AuthCleanupService.name);
@@ -149,12 +141,10 @@ export class AuthCleanupService implements OnModuleInit, OnModuleDestroy {
} finally { } finally {
if (lockAcquired && this.redisClient) { if (lockAcquired && this.redisClient) {
try { try {
await this.redisClient.eval( const currentLockValue = await this.redisClient.get(CLEANUP_LOCK_KEY);
RELEASE_LOCK_SCRIPT, if (currentLockValue === lockToken) {
1, await this.redisClient.del(CLEANUP_LOCK_KEY);
CLEANUP_LOCK_KEY, }
lockToken,
);
} catch (error) { } catch (error) {
this.logger.warn( this.logger.warn(
'Failed to release auth cleanup lock', 'Failed to release auth cleanup lock',

View File

@@ -1,84 +0,0 @@
const EMPTY_VALUE_TOKEN = '_';
export const CACHE_NAMESPACE = {
FRIENDS_LIST: 'friends-list',
DOLLS_LIST: 'dolls-list',
USERS_SEARCH: 'users-search',
FRIENDSHIP_CHECK: 'friendship-check',
AUTH_SESSION: 'auth-session',
} as const;
function normalizeKeyPart(value: string | undefined): string {
if (!value) {
return EMPTY_VALUE_TOKEN;
}
return encodeURIComponent(value);
}
export const CACHE_TTL_SECONDS = {
FRIENDS_LIST: 30,
DOLLS_LIST: 30,
USERS_SEARCH: 20,
FRIENDSHIP_CHECK: 120,
AUTH_SESSION: 30,
} as const;
export function friendsListCacheKey(userId: string): string {
return normalizeKeyPart(userId);
}
export function friendsListOwnerTag(userId: string): string {
return `owner:${normalizeKeyPart(userId)}`;
}
export function friendsListDependsOnUserTag(userId: string): string {
return `depends-on:${normalizeKeyPart(userId)}`;
}
export function dollsListCacheKey(
ownerId: string,
requesterId: string,
): string {
return `${normalizeKeyPart(ownerId)}:${normalizeKeyPart(requesterId)}`;
}
export function dollsListOwnerTag(ownerId: string): string {
return `owner:${normalizeKeyPart(ownerId)}`;
}
export function dollsListViewerTag(viewerId: string): string {
return `viewer:${normalizeKeyPart(viewerId)}`;
}
export function usersSearchCacheKey(
username: string | undefined,
excludeUserId: string | undefined,
): string {
return `${normalizeKeyPart(username?.trim().toLowerCase())}:${normalizeKeyPart(excludeUserId)}`;
}
export const USERS_SEARCH_GLOBAL_TAG = 'global';
export function friendshipCheckCacheKey(
userId: string,
friendId: string,
): string {
return `${normalizeKeyPart(userId)}:${normalizeKeyPart(friendId)}`;
}
export function friendshipCheckUserTag(userId: string): string {
return `user:${normalizeKeyPart(userId)}`;
}
export function authSessionCacheKey(sessionId: string): string {
return normalizeKeyPart(sessionId);
}
export function authSessionUserTag(userId: string): string {
return `user:${normalizeKeyPart(userId)}`;
}
export function usersSearchUserTag(userId: string): string {
return `user:${normalizeKeyPart(userId)}`;
}

View File

@@ -1,106 +0,0 @@
import { Injectable } from '@nestjs/common';
import { ConfigService } from '@nestjs/config';
import { CacheService } from './cache.service';
import { parsePositiveInteger } from '../config/env.utils';
const CACHE_TAG_SET_TTL_SECONDS = 86_400;
const DEFAULT_CACHE_TAG_MAX_ENTRIES = 5_000;
@Injectable()
export class CacheTagsService {
private readonly cacheTagMaxEntries: number;
constructor(
private readonly cacheService: CacheService,
private readonly configService: ConfigService,
) {
this.cacheTagMaxEntries = parsePositiveInteger(
this.configService.get<string>('CACHE_TAG_MAX_ENTRIES'),
DEFAULT_CACHE_TAG_MAX_ENTRIES,
);
}
async rememberKeyForTag(
namespace: string,
tag: string,
cacheKey: string,
): Promise<void> {
const redisClient = this.cacheService.getRedisClient();
if (!redisClient) {
return;
}
const tagSetKey = this.getTagSetKey(namespace, tag);
const keyWithNamespace = this.cacheService.getNamespacedKey(
namespace,
cacheKey,
);
try {
await Promise.all([
redisClient.sadd(tagSetKey, keyWithNamespace),
redisClient.expire(tagSetKey, CACHE_TAG_SET_TTL_SECONDS),
]);
const size = await redisClient.scard(tagSetKey);
if (size > this.cacheTagMaxEntries) {
await this.trimTagSet(tagSetKey, size - this.cacheTagMaxEntries);
}
} catch (error) {
this.cacheService.recordError('tag remember', tagSetKey, error);
}
}
async invalidateTag(namespace: string, tag: string): Promise<void> {
const redisClient = this.cacheService.getRedisClient();
if (!redisClient) {
return;
}
const tagSetKey = this.getTagSetKey(namespace, tag);
try {
const keys = await redisClient.smembers(tagSetKey);
if (keys.length === 0) {
await redisClient.del(tagSetKey);
return;
}
const pipeline = redisClient.pipeline();
keys.forEach((key) => pipeline.del(key));
pipeline.del(tagSetKey);
await pipeline.exec();
} catch (error) {
this.cacheService.recordError('tag invalidate', tagSetKey, error);
}
}
private getTagSetKey(namespace: string, tag: string): string {
return this.cacheService.getNamespacedKey(
'cache-tag',
`${namespace}:${tag}`,
);
}
private async trimTagSet(
tagSetKey: string,
countToDrop: number,
): Promise<void> {
const redisClient = this.cacheService.getRedisClient();
if (!redisClient || countToDrop <= 0) {
return;
}
try {
const sample = await redisClient.srandmember(tagSetKey, countToDrop);
const members = Array.isArray(sample) ? sample : [sample].filter(Boolean);
if (members.length === 0) {
return;
}
await redisClient.srem(tagSetKey, ...members);
} catch (error) {
this.cacheService.recordError('tag trim', tagSetKey, error);
}
}
}

View File

@@ -1,13 +0,0 @@
import { Global, Module } from '@nestjs/common';
import { RedisModule } from '../../database/redis.module';
import { CacheTagsService } from './cache-tags.service';
import { CacheService } from './cache.service';
import { RedisThrottlerStorage } from './redis-throttler.storage';
@Global()
@Module({
imports: [RedisModule],
providers: [CacheService, CacheTagsService, RedisThrottlerStorage],
exports: [CacheService, CacheTagsService, RedisThrottlerStorage],
})
export class CacheModule {}

View File

@@ -1,185 +0,0 @@
import { Inject, Injectable, Logger, OnModuleDestroy } from '@nestjs/common';
import { ConfigService } from '@nestjs/config';
import Redis from 'ioredis';
import { REDIS_CLIENT } from '../../database/redis.module';
import { parsePositiveInteger } from '../config/env.utils';
const DEFAULT_CACHE_KEY_PREFIX = 'friendolls';
const DEFAULT_CACHE_TTL_SECONDS = 60;
const DEFAULT_CACHE_MAX_TTL_SECONDS = 86_400;
const DEFAULT_CACHE_METRICS_LOG_INTERVAL_MS = 60_000;
@Injectable()
export class CacheService implements OnModuleDestroy {
private readonly logger = new Logger(CacheService.name);
private readonly keyPrefix: string;
private readonly defaultTtlSeconds: number;
private readonly maxTtlSeconds: number;
private readonly metricsLogIntervalMs: number;
private readonly metrics = {
getHits: 0,
getMisses: 0,
sets: 0,
deletes: 0,
errors: 0,
unavailable: 0,
};
private metricsTimer: NodeJS.Timeout | null = null;
constructor(
private readonly configService: ConfigService,
@Inject(REDIS_CLIENT) private readonly redisClient: Redis | null,
) {
this.keyPrefix =
this.configService.get<string>('CACHE_KEY_PREFIX') ||
DEFAULT_CACHE_KEY_PREFIX;
this.defaultTtlSeconds = parsePositiveInteger(
this.configService.get<string>('CACHE_DEFAULT_TTL_SECONDS'),
DEFAULT_CACHE_TTL_SECONDS,
);
this.maxTtlSeconds = parsePositiveInteger(
this.configService.get<string>('CACHE_MAX_TTL_SECONDS'),
DEFAULT_CACHE_MAX_TTL_SECONDS,
);
this.metricsLogIntervalMs = parsePositiveInteger(
this.configService.get<string>('CACHE_METRICS_LOG_INTERVAL_MS'),
DEFAULT_CACHE_METRICS_LOG_INTERVAL_MS,
);
if (this.metricsLogIntervalMs > 0) {
this.metricsTimer = setInterval(() => {
this.flushMetrics();
}, this.metricsLogIntervalMs);
this.metricsTimer.unref();
}
}
onModuleDestroy(): void {
if (this.metricsTimer) {
clearInterval(this.metricsTimer);
this.metricsTimer = null;
}
this.flushMetrics();
}
getNamespacedKey(namespace: string, key: string): string {
return `${this.keyPrefix}:${namespace}:${key}`;
}
resolveTtlSeconds(ttlSeconds?: number): number {
if (ttlSeconds === undefined) {
return this.defaultTtlSeconds;
}
if (!Number.isFinite(ttlSeconds) || ttlSeconds <= 0) {
return this.defaultTtlSeconds;
}
return Math.min(Math.floor(ttlSeconds), this.maxTtlSeconds);
}
async get(key: string): Promise<string | null> {
if (!this.redisClient) {
this.metrics.unavailable += 1;
return null;
}
try {
const value = await this.redisClient.get(key);
if (value === null) {
this.metrics.getMisses += 1;
} else {
this.metrics.getHits += 1;
}
return value;
} catch (error) {
this.metrics.errors += 1;
this.logger.warn(`Cache get failed for key ${key}`, error as Error);
return null;
}
}
async set(key: string, value: string, ttlSeconds?: number): Promise<boolean> {
if (!this.redisClient) {
this.metrics.unavailable += 1;
return false;
}
try {
const ttl = this.resolveTtlSeconds(ttlSeconds);
await this.redisClient.set(key, value, 'EX', ttl);
this.metrics.sets += 1;
return true;
} catch (error) {
this.metrics.errors += 1;
this.logger.warn(`Cache set failed for key ${key}`, error as Error);
return false;
}
}
async del(key: string): Promise<boolean> {
if (!this.redisClient) {
this.metrics.unavailable += 1;
return false;
}
try {
await this.redisClient.del(key);
this.metrics.deletes += 1;
return true;
} catch (error) {
this.metrics.errors += 1;
this.logger.warn(`Cache delete failed for key ${key}`, error as Error);
return false;
}
}
getRedisClient(): Redis | null {
return this.redisClient;
}
recordError(operation: string, key: string, error: unknown): void {
this.metrics.errors += 1;
this.logger.warn(
`Cache ${operation} failed for key ${key}`,
error as Error,
);
}
recordUnavailable(): void {
this.metrics.unavailable += 1;
}
private flushMetrics(): void {
const totalReads = this.metrics.getHits + this.metrics.getMisses;
if (
totalReads === 0 &&
this.metrics.sets === 0 &&
this.metrics.deletes === 0 &&
this.metrics.errors === 0 &&
this.metrics.unavailable === 0
) {
return;
}
const hitRate =
totalReads === 0
? '0.00'
: ((this.metrics.getHits / totalReads) * 100).toFixed(2);
this.logger.log(
`metrics reads=${totalReads} hits=${this.metrics.getHits} misses=${this.metrics.getMisses} hitRate=${hitRate}% sets=${this.metrics.sets} deletes=${this.metrics.deletes} errors=${this.metrics.errors} unavailable=${this.metrics.unavailable}`,
);
this.metrics.getHits = 0;
this.metrics.getMisses = 0;
this.metrics.sets = 0;
this.metrics.deletes = 0;
this.metrics.errors = 0;
this.metrics.unavailable = 0;
}
}

View File

@@ -1,4 +0,0 @@
export { CacheModule } from './cache.module';
export { CacheService } from './cache.service';
export { CacheTagsService } from './cache-tags.service';
export { RedisThrottlerStorage } from './redis-throttler.storage';

View File

@@ -1,239 +0,0 @@
import { Injectable } from '@nestjs/common';
import { ThrottlerStorage } from '@nestjs/throttler';
import { CacheService } from './cache.service';
interface RedisThrottlerStorageRecord {
totalHits: number;
timeToExpire: number;
isBlocked: boolean;
timeToBlockExpire: number;
}
@Injectable()
export class RedisThrottlerStorage implements ThrottlerStorage {
private static readonly IN_MEMORY_CLEANUP_INTERVAL = 500;
private readonly inMemoryStorage = new Map<
string,
{ totalHits: number; expiresAt: number; blockExpiresAt: number }
>();
private inMemoryOperationCount = 0;
constructor(private readonly cacheService: CacheService) {}
async increment(
key: string,
ttl: number,
limit: number,
blockDuration: number,
throttlerName: string,
): Promise<RedisThrottlerStorageRecord> {
const safeLimit = Math.max(0, Math.floor(limit));
const ttlMilliseconds = this.normalizeDurationMs(ttl);
const blockDurationMilliseconds = this.normalizeDurationMs(blockDuration);
const counterKey = this.cacheService.getNamespacedKey(
'throttle:counter',
`${throttlerName}:${key}`,
);
const blockKey = this.cacheService.getNamespacedKey(
'throttle:block',
`${throttlerName}:${key}`,
);
const redisClient = this.cacheService.getRedisClient();
if (!redisClient) {
this.cacheService.recordUnavailable();
return this.incrementInMemory(
counterKey,
ttlMilliseconds,
safeLimit,
blockDurationMilliseconds,
);
}
try {
const initialized = await redisClient.set(
counterKey,
'1',
'PX',
ttlMilliseconds,
'NX',
);
if (initialized === 'OK') {
const existingBlockTtlRemainingMs = await redisClient.pttl(blockKey);
return {
totalHits: 1,
timeToExpire: Math.ceil(ttlMilliseconds / 1000),
isBlocked: existingBlockTtlRemainingMs > 0,
timeToBlockExpire:
existingBlockTtlRemainingMs > 0
? this.toSecondsFromPttl(
existingBlockTtlRemainingMs,
blockDurationMilliseconds,
)
: 0,
};
}
const [
existingBlockTtlRemainingMs,
ttlRemainingBeforeHitMs,
currentCount,
] = await Promise.all([
redisClient.pttl(blockKey),
redisClient.pttl(counterKey),
redisClient.get(counterKey),
]);
if (existingBlockTtlRemainingMs > 0) {
const totalHits = Number(currentCount ?? '0');
return {
totalHits: Number.isFinite(totalHits) ? totalHits : 0,
timeToExpire: this.toSecondsFromPttl(
ttlRemainingBeforeHitMs,
ttlMilliseconds,
),
isBlocked: true,
timeToBlockExpire: this.toSecondsFromPttl(
existingBlockTtlRemainingMs,
blockDurationMilliseconds,
),
};
}
const count = await redisClient.incr(counterKey);
if (count === 1) {
await redisClient.pexpire(counterKey, ttlMilliseconds);
}
const [ttlRemainingMs, blockTtlRemainingMs] = await Promise.all([
redisClient.pttl(counterKey),
redisClient.pttl(blockKey),
]);
let isBlocked = blockTtlRemainingMs > 0;
if (!isBlocked && safeLimit > 0 && count > safeLimit) {
await redisClient.set(blockKey, '1', 'PX', blockDurationMilliseconds);
isBlocked = true;
}
const refreshedBlockTtlRemainingMs = isBlocked
? await redisClient.pttl(blockKey)
: -1;
return {
totalHits: count,
timeToExpire: this.toSecondsFromPttl(ttlRemainingMs, ttlMilliseconds),
isBlocked,
timeToBlockExpire: isBlocked
? this.toSecondsFromPttl(
refreshedBlockTtlRemainingMs,
blockDurationMilliseconds,
)
: 0,
};
} catch (error) {
this.cacheService.recordError('throttler increment', counterKey, error);
return this.incrementInMemory(
counterKey,
ttlMilliseconds,
safeLimit,
blockDurationMilliseconds,
);
}
}
private incrementInMemory(
key: string,
ttlMilliseconds: number,
limit: number,
blockDurationMilliseconds: number,
): RedisThrottlerStorageRecord {
const now = Date.now();
this.inMemoryOperationCount += 1;
if (
this.inMemoryOperationCount %
RedisThrottlerStorage.IN_MEMORY_CLEANUP_INTERVAL ===
0
) {
this.cleanupExpiredInMemory(now);
}
const existing = this.inMemoryStorage.get(key);
if (existing && existing.blockExpiresAt > now) {
return {
totalHits: existing.totalHits,
timeToExpire: Math.max(1, Math.ceil((existing.expiresAt - now) / 1000)),
isBlocked: true,
timeToBlockExpire: Math.max(
1,
Math.ceil((existing.blockExpiresAt - now) / 1000),
),
};
}
let totalHits = 1;
let expiresAt = now + ttlMilliseconds;
let blockExpiresAt = 0;
if (existing && existing.expiresAt > now) {
totalHits = existing.totalHits + 1;
expiresAt = existing.expiresAt;
}
if (blockExpiresAt <= now) {
blockExpiresAt = 0;
}
let isBlocked = blockExpiresAt > now;
if (!isBlocked && limit > 0 && totalHits > limit) {
blockExpiresAt = now + blockDurationMilliseconds;
isBlocked = true;
}
this.inMemoryStorage.set(key, {
totalHits,
expiresAt,
blockExpiresAt,
});
return {
totalHits,
timeToExpire: Math.max(1, Math.ceil((expiresAt - now) / 1000)),
isBlocked,
timeToBlockExpire: isBlocked
? Math.max(1, Math.ceil((blockExpiresAt - now) / 1000))
: 0,
};
}
private normalizeDurationMs(value: number): number {
if (!Number.isFinite(value) || value <= 0) {
return 1000;
}
return Math.max(1, Math.floor(value));
}
private toSecondsFromPttl(pttlMs: number, fallbackMs: number): number {
if (pttlMs > 0) {
return Math.max(1, Math.ceil(pttlMs / 1000));
}
return Math.max(1, Math.ceil(fallbackMs / 1000));
}
private cleanupExpiredInMemory(now: number): void {
for (const [mapKey, value] of this.inMemoryStorage) {
const counterExpired = value.expiresAt <= now;
const blockExpired = value.blockExpiresAt <= now;
if (counterExpired && blockExpired) {
this.inMemoryStorage.delete(mapKey);
}
}
}
}

View File

@@ -108,6 +108,10 @@ class RedisLifecycleService implements OnModuleDestroy {
}, },
}); });
client.on('error', (err) => {
logger.error('Redis connection error', err);
});
client.on('connect', () => { client.on('connect', () => {
logger.log(`Connected to Redis at ${host}:${port}`); logger.log(`Connected to Redis at ${host}:${port}`);
}); });

View File

@@ -1,37 +0,0 @@
import { Injectable } from '@nestjs/common';
import { OnEvent } from '@nestjs/event-emitter';
import { CacheTagsService } from '../common/cache/cache-tags.service';
import { CACHE_NAMESPACE, dollsListOwnerTag } from '../common/cache/cache-keys';
import { DollEvents } from './events/doll.events';
import {
type DollCreatedEvent,
type DollDeletedEvent,
type DollUpdatedEvent,
} from './events/doll.events';
@Injectable()
export class DollsCacheInvalidationService {
constructor(private readonly cacheTagsService: CacheTagsService) {}
@OnEvent(DollEvents.DOLL_CREATED)
async handleDollCreated(payload: DollCreatedEvent): Promise<void> {
await this.invalidateOwnerLists(payload.userId);
}
@OnEvent(DollEvents.DOLL_UPDATED)
async handleDollUpdated(payload: DollUpdatedEvent): Promise<void> {
await this.invalidateOwnerLists(payload.userId);
}
@OnEvent(DollEvents.DOLL_DELETED)
async handleDollDeleted(payload: DollDeletedEvent): Promise<void> {
await this.invalidateOwnerLists(payload.userId);
}
private async invalidateOwnerLists(ownerId: string): Promise<void> {
await this.cacheTagsService.invalidateTag(
CACHE_NAMESPACE.DOLLS_LIST,
dollsListOwnerTag(ownerId),
);
}
}

View File

@@ -1,26 +1,15 @@
import { Module, forwardRef } from '@nestjs/common'; import { Module, forwardRef } from '@nestjs/common';
import { DollsService } from './dolls.service'; import { DollsService } from './dolls.service';
import { DollsController } from './dolls.controller'; import { DollsController } from './dolls.controller';
import { DollsCacheInvalidationService } from './dolls-cache-invalidation.service';
import { DollsNotificationService } from './dolls-notification.service'; import { DollsNotificationService } from './dolls-notification.service';
import { DatabaseModule } from '../database/database.module'; import { DatabaseModule } from '../database/database.module';
import { AuthModule } from '../auth/auth.module'; import { AuthModule } from '../auth/auth.module';
import { WsModule } from '../ws/ws.module'; import { WsModule } from '../ws/ws.module';
import { FriendsModule } from '../friends/friends.module';
@Module({ @Module({
imports: [ imports: [DatabaseModule, AuthModule, forwardRef(() => WsModule)],
DatabaseModule,
AuthModule,
FriendsModule,
forwardRef(() => WsModule),
],
controllers: [DollsController], controllers: [DollsController],
providers: [ providers: [DollsService, DollsNotificationService],
DollsService,
DollsNotificationService,
DollsCacheInvalidationService,
],
exports: [DollsService], exports: [DollsService],
}) })
export class DollsModule {} export class DollsModule {}

View File

@@ -4,9 +4,6 @@ import { DollsService } from './dolls.service';
import { PrismaService } from '../database/prisma.service'; import { PrismaService } from '../database/prisma.service';
import { NotFoundException, ForbiddenException } from '@nestjs/common'; import { NotFoundException, ForbiddenException } from '@nestjs/common';
import { Doll } from '@prisma/client'; import { Doll } from '@prisma/client';
import { CacheService } from '../common/cache/cache.service';
import { CacheTagsService } from '../common/cache/cache-tags.service';
import { FriendsService } from '../friends/friends.service';
describe('DollsService', () => { describe('DollsService', () => {
let service: DollsService; let service: DollsService;
@@ -34,6 +31,9 @@ describe('DollsService', () => {
findFirst: jest.fn().mockResolvedValue(mockDoll), findFirst: jest.fn().mockResolvedValue(mockDoll),
update: jest.fn().mockResolvedValue(mockDoll), update: jest.fn().mockResolvedValue(mockDoll),
}, },
friendship: {
findMany: jest.fn().mockResolvedValue([]),
},
$transaction: jest.fn((callback) => { $transaction: jest.fn((callback) => {
// eslint-disable-next-line @typescript-eslint/no-unsafe-return // eslint-disable-next-line @typescript-eslint/no-unsafe-return
return callback(mockPrismaService); return callback(mockPrismaService);
@@ -47,25 +47,6 @@ describe('DollsService', () => {
emit: jest.fn(), emit: jest.fn(),
}; };
const mockCacheService = {
get: jest.fn().mockResolvedValue(null),
set: jest.fn().mockResolvedValue(true),
getNamespacedKey: jest
.fn()
.mockImplementation(
(namespace: string, key: string) => `friendolls:${namespace}:${key}`,
),
recordError: jest.fn(),
};
const mockCacheTagsService = {
rememberKeyForTag: jest.fn().mockResolvedValue(undefined),
};
const mockFriendsService = {
areFriends: jest.fn().mockResolvedValue(false),
};
beforeEach(async () => { beforeEach(async () => {
const module: TestingModule = await Test.createTestingModule({ const module: TestingModule = await Test.createTestingModule({
providers: [ providers: [
@@ -78,18 +59,6 @@ describe('DollsService', () => {
provide: EventEmitter2, provide: EventEmitter2,
useValue: mockEventEmitter, useValue: mockEventEmitter,
}, },
{
provide: CacheService,
useValue: mockCacheService,
},
{
provide: CacheTagsService,
useValue: mockCacheTagsService,
},
{
provide: FriendsService,
useValue: mockFriendsService,
},
], ],
}).compile(); }).compile();
@@ -143,7 +112,10 @@ describe('DollsService', () => {
const ownerId = 'friend-1'; const ownerId = 'friend-1';
const requestingUserId = 'user-1'; const requestingUserId = 'user-1';
(mockFriendsService.areFriends as jest.Mock).mockResolvedValueOnce(true); // Mock friendship
jest
.spyOn(prismaService.friendship, 'findMany')
.mockResolvedValueOnce([{ friendId: ownerId } as any]);
await service.listByOwner(ownerId, requestingUserId); await service.listByOwner(ownerId, requestingUserId);
@@ -162,7 +134,10 @@ describe('DollsService', () => {
const ownerId = 'stranger-1'; const ownerId = 'stranger-1';
const requestingUserId = 'user-1'; const requestingUserId = 'user-1';
(mockFriendsService.areFriends as jest.Mock).mockResolvedValueOnce(false); // Mock empty friendship (default)
jest
.spyOn(prismaService.friendship, 'findMany')
.mockResolvedValueOnce([]);
await expect( await expect(
service.listByOwner(ownerId, requestingUserId), service.listByOwner(ownerId, requestingUserId),
@@ -188,10 +163,7 @@ describe('DollsService', () => {
}); });
it('should throw NotFoundException if doll not accessible', async () => { it('should throw NotFoundException if doll not accessible', async () => {
jest jest.spyOn(prismaService.doll, 'findFirst').mockResolvedValueOnce(null);
.spyOn(prismaService.doll, 'findFirst')
.mockResolvedValueOnce({ ...mockDoll, userId: 'user-2' });
(mockFriendsService.areFriends as jest.Mock).mockResolvedValueOnce(false);
await expect(service.findOne('doll-1', 'user-1')).rejects.toThrow( await expect(service.findOne('doll-1', 'user-1')).rejects.toThrow(
NotFoundException, NotFoundException,
@@ -207,7 +179,7 @@ describe('DollsService', () => {
expect(prismaService.doll.update).toHaveBeenCalled(); expect(prismaService.doll.update).toHaveBeenCalled();
}); });
it('should throw NotFoundException if not owner and not a friend', async () => { it('should throw ForbiddenException if not owner', async () => {
jest jest
.spyOn(prismaService.doll, 'findFirst') .spyOn(prismaService.doll, 'findFirst')
.mockResolvedValueOnce({ ...mockDoll, userId: 'user-2' }); .mockResolvedValueOnce({ ...mockDoll, userId: 'user-2' });
@@ -215,7 +187,7 @@ describe('DollsService', () => {
const updateDto = { name: 'Updated Doll' }; const updateDto = { name: 'Updated Doll' };
await expect( await expect(
service.update('doll-1', 'user-1', updateDto), service.update('doll-1', 'user-1', updateDto),
).rejects.toThrow(NotFoundException); ).rejects.toThrow(ForbiddenException);
}); });
}); });
@@ -231,13 +203,13 @@ describe('DollsService', () => {
}); });
}); });
it('should throw NotFoundException if not owner and not a friend', async () => { it('should throw ForbiddenException if not owner', async () => {
jest jest
.spyOn(prismaService.doll, 'findFirst') .spyOn(prismaService.doll, 'findFirst')
.mockResolvedValueOnce({ ...mockDoll, userId: 'user-2' }); .mockResolvedValueOnce({ ...mockDoll, userId: 'user-2' });
await expect(service.remove('doll-1', 'user-1')).rejects.toThrow( await expect(service.remove('doll-1', 'user-1')).rejects.toThrow(
NotFoundException, ForbiddenException,
); );
}); });
}); });

View File

@@ -15,16 +15,6 @@ import {
DollUpdatedEvent, DollUpdatedEvent,
DollDeletedEvent, DollDeletedEvent,
} from './events/doll.events'; } from './events/doll.events';
import { CacheService } from '../common/cache/cache.service';
import { CacheTagsService } from '../common/cache/cache-tags.service';
import {
CACHE_NAMESPACE,
CACHE_TTL_SECONDS,
dollsListCacheKey,
dollsListOwnerTag,
dollsListViewerTag,
} from '../common/cache/cache-keys';
import { FriendsService } from '../friends/friends.service';
@Injectable() @Injectable()
export class DollsService { export class DollsService {
@@ -33,11 +23,16 @@ export class DollsService {
constructor( constructor(
private readonly prisma: PrismaService, private readonly prisma: PrismaService,
private readonly eventEmitter: EventEmitter2, private readonly eventEmitter: EventEmitter2,
private readonly cacheService: CacheService,
private readonly cacheTagsService: CacheTagsService,
private readonly friendsService: FriendsService,
) {} ) {}
async getFriendIds(userId: string): Promise<string[]> {
const friendships = await this.prisma.friendship.findMany({
where: { userId },
select: { friendId: true },
});
return friendships.map((f) => f.friendId);
}
async create( async create(
requestingUserId: string, requestingUserId: string,
createDollDto: CreateDollDto, createDollDto: CreateDollDto,
@@ -81,48 +76,6 @@ export class DollsService {
async listByOwner( async listByOwner(
ownerId: string, ownerId: string,
requestingUserId: string, requestingUserId: string,
): Promise<Doll[]> {
const cacheKey = dollsListCacheKey(ownerId, requestingUserId);
const namespacedKey = this.cacheService.getNamespacedKey(
CACHE_NAMESPACE.DOLLS_LIST,
cacheKey,
);
const cached = await this.cacheService.get(namespacedKey);
if (cached) {
try {
return JSON.parse(cached) as Doll[];
} catch (error) {
this.cacheService.recordError('dolls list parse', namespacedKey, error);
}
}
const dolls = await this.listByOwnerFromDatabase(ownerId, requestingUserId);
await this.cacheService.set(
namespacedKey,
JSON.stringify(dolls),
CACHE_TTL_SECONDS.DOLLS_LIST,
);
await Promise.all([
this.cacheTagsService.rememberKeyForTag(
CACHE_NAMESPACE.DOLLS_LIST,
dollsListOwnerTag(ownerId),
cacheKey,
),
this.cacheTagsService.rememberKeyForTag(
CACHE_NAMESPACE.DOLLS_LIST,
dollsListViewerTag(requestingUserId),
cacheKey,
),
]);
return dolls;
}
private async listByOwnerFromDatabase(
ownerId: string,
requestingUserId: string,
): Promise<Doll[]> { ): Promise<Doll[]> {
// If requesting own dolls, no need to check friendship // If requesting own dolls, no need to check friendship
if (ownerId === requestingUserId) { if (ownerId === requestingUserId) {
@@ -138,11 +91,8 @@ export class DollsService {
} }
// If requesting someone else's dolls, check friendship // If requesting someone else's dolls, check friendship
const isFriend = await this.friendsService.areFriends( const friendIds = await this.getFriendIds(requestingUserId);
requestingUserId, if (!friendIds.includes(ownerId)) {
ownerId,
);
if (!isFriend) {
throw new ForbiddenException('You are not friends with this user'); throw new ForbiddenException('You are not friends with this user');
} }
@@ -158,9 +108,13 @@ export class DollsService {
} }
async findOne(id: string, requestingUserId: string): Promise<Doll> { async findOne(id: string, requestingUserId: string): Promise<Doll> {
const friendIds = await this.getFriendIds(requestingUserId);
const accessibleUserIds = [requestingUserId, ...friendIds];
const doll = await this.prisma.doll.findFirst({ const doll = await this.prisma.doll.findFirst({
where: { where: {
id, id,
userId: { in: accessibleUserIds },
deletedAt: null, deletedAt: null,
}, },
}); });
@@ -171,18 +125,6 @@ export class DollsService {
); );
} }
if (doll.userId !== requestingUserId) {
const isFriend = await this.friendsService.areFriends(
requestingUserId,
doll.userId,
);
if (!isFriend) {
throw new NotFoundException(
`Doll with ID ${id} not found or access denied`,
);
}
}
return doll; return doll;
} }

View File

@@ -1,74 +0,0 @@
import { Injectable } from '@nestjs/common';
import { OnEvent } from '@nestjs/event-emitter';
import { CacheTagsService } from '../common/cache/cache-tags.service';
import {
CACHE_NAMESPACE,
dollsListViewerTag,
friendshipCheckUserTag,
friendsListDependsOnUserTag,
friendsListOwnerTag,
} from '../common/cache/cache-keys';
import { FriendEvents } from './events/friend.events';
import type {
FriendRequestAcceptedEvent,
UnfriendedEvent,
} from './events/friend.events';
@Injectable()
export class FriendsCacheInvalidationService {
constructor(private readonly cacheTagsService: CacheTagsService) {}
@OnEvent(FriendEvents.REQUEST_ACCEPTED)
async handleFriendAccepted(
payload: FriendRequestAcceptedEvent,
): Promise<void> {
const senderId = payload.friendRequest.senderId;
const receiverId = payload.friendRequest.receiverId;
await this.invalidateFriendAndDollViews(senderId, receiverId);
}
@OnEvent(FriendEvents.UNFRIENDED)
async handleUnfriended(payload: UnfriendedEvent): Promise<void> {
await this.invalidateFriendAndDollViews(payload.userId, payload.friendId);
}
private async invalidateFriendAndDollViews(
firstUserId: string,
secondUserId: string,
): Promise<void> {
await Promise.all([
this.cacheTagsService.invalidateTag(
CACHE_NAMESPACE.FRIENDS_LIST,
friendsListOwnerTag(firstUserId),
),
this.cacheTagsService.invalidateTag(
CACHE_NAMESPACE.FRIENDS_LIST,
friendsListOwnerTag(secondUserId),
),
this.cacheTagsService.invalidateTag(
CACHE_NAMESPACE.FRIENDS_LIST,
friendsListDependsOnUserTag(firstUserId),
),
this.cacheTagsService.invalidateTag(
CACHE_NAMESPACE.FRIENDS_LIST,
friendsListDependsOnUserTag(secondUserId),
),
this.cacheTagsService.invalidateTag(
CACHE_NAMESPACE.DOLLS_LIST,
dollsListViewerTag(firstUserId),
),
this.cacheTagsService.invalidateTag(
CACHE_NAMESPACE.DOLLS_LIST,
dollsListViewerTag(secondUserId),
),
this.cacheTagsService.invalidateTag(
CACHE_NAMESPACE.FRIENDSHIP_CHECK,
friendshipCheckUserTag(firstUserId),
),
this.cacheTagsService.invalidateTag(
CACHE_NAMESPACE.FRIENDSHIP_CHECK,
friendshipCheckUserTag(secondUserId),
),
]);
}
}

View File

@@ -1,6 +1,5 @@
import { Module, forwardRef } from '@nestjs/common'; import { Module, forwardRef } from '@nestjs/common';
import { FriendsController } from './friends.controller'; import { FriendsController } from './friends.controller';
import { FriendsCacheInvalidationService } from './friends-cache-invalidation.service';
import { FriendsService } from './friends.service'; import { FriendsService } from './friends.service';
import { FriendsNotificationService } from './friends-notification.service'; import { FriendsNotificationService } from './friends-notification.service';
import { DatabaseModule } from '../database/database.module'; import { DatabaseModule } from '../database/database.module';
@@ -16,11 +15,7 @@ import { WsModule } from '../ws/ws.module';
forwardRef(() => WsModule), forwardRef(() => WsModule),
], ],
controllers: [FriendsController], controllers: [FriendsController],
providers: [ providers: [FriendsService, FriendsNotificationService],
FriendsService,
FriendsNotificationService,
FriendsCacheInvalidationService,
],
exports: [FriendsService], exports: [FriendsService],
}) })
export class FriendsModule {} export class FriendsModule {}

View File

@@ -2,8 +2,6 @@ import { Test, TestingModule } from '@nestjs/testing';
import { EventEmitter2 } from '@nestjs/event-emitter'; import { EventEmitter2 } from '@nestjs/event-emitter';
import { FriendsService } from './friends.service'; import { FriendsService } from './friends.service';
import { PrismaService } from '../database/prisma.service'; import { PrismaService } from '../database/prisma.service';
import { CacheService } from '../common/cache/cache.service';
import { CacheTagsService } from '../common/cache/cache-tags.service';
import { import {
NotFoundException, NotFoundException,
BadRequestException, BadRequestException,
@@ -19,8 +17,6 @@ enum FriendRequestStatus {
describe('FriendsService', () => { describe('FriendsService', () => {
let service: FriendsService; let service: FriendsService;
let eventEmitter: EventEmitter2; let eventEmitter: EventEmitter2;
let cacheService: CacheService;
let cacheTagsService: CacheTagsService;
const mockUser1 = { const mockUser1 = {
id: 'user-1', id: 'user-1',
@@ -94,21 +90,6 @@ describe('FriendsService', () => {
emit: jest.fn(), emit: jest.fn(),
}; };
const mockCacheService = {
get: jest.fn().mockResolvedValue(null),
set: jest.fn().mockResolvedValue(true),
getNamespacedKey: jest
.fn()
.mockImplementation(
(namespace: string, key: string) => `friendolls:${namespace}:${key}`,
),
recordError: jest.fn(),
};
const mockCacheTagsService = {
rememberKeyForTag: jest.fn().mockResolvedValue(undefined),
};
beforeEach(async () => { beforeEach(async () => {
const module: TestingModule = await Test.createTestingModule({ const module: TestingModule = await Test.createTestingModule({
providers: [ providers: [
@@ -121,21 +102,11 @@ describe('FriendsService', () => {
provide: EventEmitter2, provide: EventEmitter2,
useValue: mockEventEmitter, useValue: mockEventEmitter,
}, },
{
provide: CacheService,
useValue: mockCacheService,
},
{
provide: CacheTagsService,
useValue: mockCacheTagsService,
},
], ],
}).compile(); }).compile();
service = module.get<FriendsService>(FriendsService); service = module.get<FriendsService>(FriendsService);
eventEmitter = module.get<EventEmitter2>(EventEmitter2); eventEmitter = module.get<EventEmitter2>(EventEmitter2);
cacheService = module.get<CacheService>(CacheService);
cacheTagsService = module.get<CacheTagsService>(CacheTagsService);
jest.clearAllMocks(); jest.clearAllMocks();
}); });
@@ -449,8 +420,6 @@ describe('FriendsService', () => {
createdAt: 'desc', createdAt: 'desc',
}, },
}); });
expect(cacheService.set).toHaveBeenCalled();
expect(cacheTagsService.rememberKeyForTag).toHaveBeenCalled();
}); });
}); });
@@ -500,12 +469,6 @@ describe('FriendsService', () => {
const result = await service.areFriends('user-1', 'user-2'); const result = await service.areFriends('user-1', 'user-2');
expect(result).toBe(true); expect(result).toBe(true);
expect(cacheService.set).toHaveBeenCalledWith(
expect.any(String),
'1',
expect.any(Number),
);
expect(cacheTagsService.rememberKeyForTag).toHaveBeenCalled();
}); });
it('should return false when users are not friends', async () => { it('should return false when users are not friends', async () => {
@@ -514,11 +477,6 @@ describe('FriendsService', () => {
const result = await service.areFriends('user-1', 'user-2'); const result = await service.areFriends('user-1', 'user-2');
expect(result).toBe(false); expect(result).toBe(false);
expect(cacheService.set).toHaveBeenCalledWith(
expect.any(String),
'0',
expect.any(Number),
);
}); });
}); });
}); });

View File

@@ -15,17 +15,6 @@ import {
FriendRequestDeniedEvent, FriendRequestDeniedEvent,
UnfriendedEvent, UnfriendedEvent,
} from './events/friend.events'; } from './events/friend.events';
import { CacheService } from '../common/cache/cache.service';
import { CacheTagsService } from '../common/cache/cache-tags.service';
import {
CACHE_NAMESPACE,
CACHE_TTL_SECONDS,
friendshipCheckCacheKey,
friendshipCheckUserTag,
friendsListCacheKey,
friendsListDependsOnUserTag,
friendsListOwnerTag,
} from '../common/cache/cache-keys';
export type FriendRequestWithRelations = FriendRequest & { export type FriendRequestWithRelations = FriendRequest & {
sender: User; sender: User;
@@ -39,8 +28,6 @@ export class FriendsService {
constructor( constructor(
private readonly prisma: PrismaService, private readonly prisma: PrismaService,
private readonly eventEmitter: EventEmitter2, private readonly eventEmitter: EventEmitter2,
private readonly cacheService: CacheService,
private readonly cacheTagsService: CacheTagsService,
) {} ) {}
async sendFriendRequest( async sendFriendRequest(
@@ -285,28 +272,7 @@ export class FriendsService {
} }
async getFriends(userId: string) { async getFriends(userId: string) {
const cacheKey = friendsListCacheKey(userId); return this.prisma.friendship.findMany({
const namespacedKey = this.cacheService.getNamespacedKey(
CACHE_NAMESPACE.FRIENDS_LIST,
cacheKey,
);
const cached = await this.cacheService.get(namespacedKey);
if (cached) {
try {
return JSON.parse(cached) as Awaited<
ReturnType<PrismaService['friendship']['findMany']>
>;
} catch (error) {
this.cacheService.recordError(
'friends list parse',
namespacedKey,
error,
);
}
}
const friendships = await this.prisma.friendship.findMany({
where: { userId }, where: { userId },
include: { include: {
friend: { friend: {
@@ -319,29 +285,6 @@ export class FriendsService {
createdAt: 'desc', createdAt: 'desc',
}, },
}); });
await this.cacheService.set(
namespacedKey,
JSON.stringify(friendships),
CACHE_TTL_SECONDS.FRIENDS_LIST,
);
const dependentFriendTags = friendships.map((friendship) =>
friendsListDependsOnUserTag(friendship.friendId),
);
const tags = [friendsListOwnerTag(userId), ...dependentFriendTags];
await Promise.all(
tags.map((tag) =>
this.cacheTagsService.rememberKeyForTag(
CACHE_NAMESPACE.FRIENDS_LIST,
tag,
cacheKey,
),
),
);
return friendships;
} }
async unfriend(userId: string, friendId: string): Promise<void> { async unfriend(userId: string, friendId: string): Promise<void> {
@@ -380,21 +323,6 @@ export class FriendsService {
} }
async areFriends(userId: string, friendId: string): Promise<boolean> { async areFriends(userId: string, friendId: string): Promise<boolean> {
const cacheKey = friendshipCheckCacheKey(userId, friendId);
const namespacedKey = this.cacheService.getNamespacedKey(
CACHE_NAMESPACE.FRIENDSHIP_CHECK,
cacheKey,
);
const cached = await this.cacheService.get(namespacedKey);
if (cached === '1') {
return true;
}
if (cached === '0') {
return false;
}
const friendship = await this.prisma.friendship.findFirst({ const friendship = await this.prisma.friendship.findFirst({
where: { where: {
userId, userId,
@@ -402,27 +330,6 @@ export class FriendsService {
}, },
}); });
const areFriends = !!friendship; return !!friendship;
await this.cacheService.set(
namespacedKey,
areFriends ? '1' : '0',
CACHE_TTL_SECONDS.FRIENDSHIP_CHECK,
);
await Promise.all([
this.cacheTagsService.rememberKeyForTag(
CACHE_NAMESPACE.FRIENDSHIP_CHECK,
friendshipCheckUserTag(userId),
cacheKey,
),
this.cacheTagsService.rememberKeyForTag(
CACHE_NAMESPACE.FRIENDSHIP_CHECK,
friendshipCheckUserTag(friendId),
cacheKey,
),
]);
return areFriends;
} }
} }

View File

@@ -9,10 +9,9 @@ export type AuthenticatedSocket = BaseSocket<
{ {
user?: AuthenticatedUser; user?: AuthenticatedUser;
userId?: string; userId?: string;
activeDollId?: string | null;
friends?: Set<string>; // Set of friend user IDs
senderName?: string; senderName?: string;
senderNameCachedAt?: number; senderNameCachedAt?: number;
lastSeenAt?: number; activeDollId?: string | null;
friends?: Set<string>; // Set of friend user IDs
} }
>; >;

View File

@@ -2,8 +2,6 @@ import { Doll } from '@prisma/client';
export const UserEvents = { export const UserEvents = {
ACTIVE_DOLL_CHANGED: 'user.active-doll.changed', ACTIVE_DOLL_CHANGED: 'user.active-doll.changed',
SEARCH_INDEX_INVALIDATED: 'user.search-index.invalidated',
PROFILE_UPDATED: 'user.profile.updated',
} as const; } as const;
export interface UserActiveDollChangedEvent { export interface UserActiveDollChangedEvent {
@@ -11,11 +9,3 @@ export interface UserActiveDollChangedEvent {
dollId: string | null; dollId: string | null;
doll: Doll | null; doll: Doll | null;
} }
export interface UserSearchIndexInvalidatedEvent {
userId?: string;
}
export interface UserProfileUpdatedEvent {
userId: string;
}

View File

@@ -1,38 +0,0 @@
import { Injectable } from '@nestjs/common';
import { OnEvent } from '@nestjs/event-emitter';
import { CacheTagsService } from '../common/cache/cache-tags.service';
import {
CACHE_NAMESPACE,
USERS_SEARCH_GLOBAL_TAG,
usersSearchUserTag,
} from '../common/cache/cache-keys';
import { UserEvents } from './events/user.events';
import type { UserSearchIndexInvalidatedEvent } from './events/user.events';
@Injectable()
export class UsersCacheInvalidationService {
constructor(private readonly cacheTagsService: CacheTagsService) {}
@OnEvent(UserEvents.SEARCH_INDEX_INVALIDATED)
async handleSearchIndexInvalidation(
payload: UserSearchIndexInvalidatedEvent,
): Promise<void> {
const tasks: Promise<void>[] = [
this.cacheTagsService.invalidateTag(
CACHE_NAMESPACE.USERS_SEARCH,
USERS_SEARCH_GLOBAL_TAG,
),
];
if (payload.userId) {
tasks.push(
this.cacheTagsService.invalidateTag(
CACHE_NAMESPACE.USERS_SEARCH,
usersSearchUserTag(payload.userId),
),
);
}
await Promise.all(tasks);
}
}

View File

@@ -1,6 +1,5 @@
import { Module, forwardRef } from '@nestjs/common'; import { Module, forwardRef } from '@nestjs/common';
import { UsersService } from './users.service'; import { UsersService } from './users.service';
import { UsersCacheInvalidationService } from './users-cache-invalidation.service';
import { UsersController } from './users.controller'; import { UsersController } from './users.controller';
import { UsersNotificationService } from './users-notification.service'; import { UsersNotificationService } from './users-notification.service';
import { AuthModule } from '../auth/auth.module'; import { AuthModule } from '../auth/auth.module';
@@ -17,11 +16,7 @@ import { WsModule } from '../ws/ws.module';
*/ */
@Module({ @Module({
imports: [forwardRef(() => AuthModule), forwardRef(() => WsModule)], imports: [forwardRef(() => AuthModule), forwardRef(() => WsModule)],
providers: [ providers: [UsersService, UsersNotificationService],
UsersService,
UsersNotificationService,
UsersCacheInvalidationService,
],
controllers: [UsersController], controllers: [UsersController],
exports: [UsersService], exports: [UsersService],
}) })

View File

@@ -5,13 +5,9 @@ import { NotFoundException, ForbiddenException } from '@nestjs/common';
import { User } from '@prisma/client'; import { User } from '@prisma/client';
import { UpdateUserDto } from './dto/update-user.dto'; import { UpdateUserDto } from './dto/update-user.dto';
import { EventEmitter2 } from '@nestjs/event-emitter'; import { EventEmitter2 } from '@nestjs/event-emitter';
import { CacheService } from '../common/cache/cache.service';
import { CacheTagsService } from '../common/cache/cache-tags.service';
describe('UsersService', () => { describe('UsersService', () => {
let service: UsersService; let service: UsersService;
let cacheService: CacheService;
let cacheTagsService: CacheTagsService;
const mockUser: User & { passwordHash?: string | null } = { const mockUser: User & { passwordHash?: string | null } = {
id: '550e8400-e29b-41d4-a716-446655440000', id: '550e8400-e29b-41d4-a716-446655440000',
@@ -43,21 +39,6 @@ describe('UsersService', () => {
emit: jest.fn(), emit: jest.fn(),
}; };
const mockCacheService = {
get: jest.fn().mockResolvedValue(null),
set: jest.fn().mockResolvedValue(true),
getNamespacedKey: jest
.fn()
.mockImplementation(
(namespace: string, key: string) => `friendolls:${namespace}:${key}`,
),
recordError: jest.fn(),
};
const mockCacheTagsService = {
rememberKeyForTag: jest.fn().mockResolvedValue(undefined),
};
beforeEach(async () => { beforeEach(async () => {
const module: TestingModule = await Test.createTestingModule({ const module: TestingModule = await Test.createTestingModule({
providers: [ providers: [
@@ -70,20 +51,10 @@ describe('UsersService', () => {
provide: EventEmitter2, provide: EventEmitter2,
useValue: mockEventEmitter, useValue: mockEventEmitter,
}, },
{
provide: CacheService,
useValue: mockCacheService,
},
{
provide: CacheTagsService,
useValue: mockCacheTagsService,
},
], ],
}).compile(); }).compile();
service = module.get<UsersService>(UsersService); service = module.get<UsersService>(UsersService);
cacheService = module.get<CacheService>(CacheService);
cacheTagsService = module.get<CacheTagsService>(CacheTagsService);
jest.clearAllMocks(); jest.clearAllMocks();
}); });
@@ -256,8 +227,6 @@ describe('UsersService', () => {
username: 'asc', username: 'asc',
}, },
}); });
expect(cacheService.set).toHaveBeenCalled();
expect(cacheTagsService.rememberKeyForTag).toHaveBeenCalled();
}); });
it('should exclude specified user from results', async () => { it('should exclude specified user from results', async () => {

View File

@@ -10,15 +10,6 @@ import { User, Prisma } from '@prisma/client';
import type { UpdateUserDto } from './dto/update-user.dto'; import type { UpdateUserDto } from './dto/update-user.dto';
import { UserEvents } from './events/user.events'; import { UserEvents } from './events/user.events';
import { normalizeEmail } from '../auth/auth.utils'; import { normalizeEmail } from '../auth/auth.utils';
import { CacheService } from '../common/cache/cache.service';
import { CacheTagsService } from '../common/cache/cache-tags.service';
import {
CACHE_NAMESPACE,
CACHE_TTL_SECONDS,
usersSearchUserTag,
usersSearchCacheKey,
USERS_SEARCH_GLOBAL_TAG,
} from '../common/cache/cache-keys';
export interface CreateLocalUserDto { export interface CreateLocalUserDto {
email: string; email: string;
@@ -39,8 +30,6 @@ export class UsersService {
constructor( constructor(
private readonly prisma: PrismaService, private readonly prisma: PrismaService,
private readonly eventEmitter: EventEmitter2, private readonly eventEmitter: EventEmitter2,
private readonly cacheService: CacheService,
private readonly cacheTagsService: CacheTagsService,
) {} ) {}
// Legacy Keycloak user creation removed in favor of local auth. // Legacy Keycloak user creation removed in favor of local auth.
@@ -103,9 +92,6 @@ export class UsersService {
data: updateData, data: updateData,
}); });
this.eventEmitter.emit(UserEvents.SEARCH_INDEX_INVALIDATED, { userId: id });
this.eventEmitter.emit(UserEvents.PROFILE_UPDATED, { userId: id });
this.logger.log(`User ${id} profile update requested`); this.logger.log(`User ${id} profile update requested`);
return updatedUser; return updatedUser;
@@ -135,9 +121,6 @@ export class UsersService {
where: { id }, where: { id },
}); });
this.eventEmitter.emit(UserEvents.SEARCH_INDEX_INVALIDATED, { userId: id });
this.eventEmitter.emit(UserEvents.PROFILE_UPDATED, { userId: id });
this.logger.log(`User ${id} deleted their account`); this.logger.log(`User ${id} deleted their account`);
} }
@@ -145,25 +128,6 @@ export class UsersService {
username?: string, username?: string,
excludeUserId?: string, excludeUserId?: string,
): Promise<User[]> { ): Promise<User[]> {
const cacheKey = usersSearchCacheKey(username, excludeUserId);
const namespacedKey = this.cacheService.getNamespacedKey(
CACHE_NAMESPACE.USERS_SEARCH,
cacheKey,
);
const cached = await this.cacheService.get(namespacedKey);
if (cached) {
try {
return JSON.parse(cached) as User[];
} catch (error) {
this.cacheService.recordError(
'users search parse',
namespacedKey,
error,
);
}
}
const where: Prisma.UserWhereInput = {}; const where: Prisma.UserWhereInput = {};
if (username) { if (username) {
@@ -187,24 +151,6 @@ export class UsersService {
}, },
}); });
await this.cacheService.set(
namespacedKey,
JSON.stringify(users),
CACHE_TTL_SECONDS.USERS_SEARCH,
);
await this.cacheTagsService.rememberKeyForTag(
CACHE_NAMESPACE.USERS_SEARCH,
USERS_SEARCH_GLOBAL_TAG,
cacheKey,
);
if (excludeUserId) {
await this.cacheTagsService.rememberKeyForTag(
CACHE_NAMESPACE.USERS_SEARCH,
usersSearchUserTag(excludeUserId),
cacheKey,
);
}
return users; return users;
} }
@@ -305,7 +251,7 @@ export class UsersService {
const now = new Date(); const now = new Date();
const roles: string[] = []; const roles: string[] = [];
const user = await this.prisma.user.create({ return this.prisma.user.create({
data: { data: {
email: normalizeEmail(createDto.email), email: normalizeEmail(createDto.email),
name: createDto.name, name: createDto.name,
@@ -316,13 +262,6 @@ export class UsersService {
keycloakSub: null, keycloakSub: null,
} as unknown as Prisma.UserUncheckedCreateInput, } as unknown as Prisma.UserUncheckedCreateInput,
}); });
this.eventEmitter.emit(UserEvents.SEARCH_INDEX_INVALIDATED, {
userId: user.id,
});
this.eventEmitter.emit(UserEvents.PROFILE_UPDATED, { userId: user.id });
return user;
} }
async updatePasswordHash( async updatePasswordHash(
@@ -333,8 +272,6 @@ export class UsersService {
where: { id: userId }, where: { id: userId },
data: { passwordHash } as unknown as Prisma.UserUpdateInput, data: { passwordHash } as unknown as Prisma.UserUpdateInput,
}); });
this.eventEmitter.emit(UserEvents.PROFILE_UPDATED, { userId });
} }
async updateLastLogin(userId: string): Promise<void> { async updateLastLogin(userId: string): Promise<void> {

View File

@@ -41,6 +41,7 @@ export class ConnectionHandler {
// Initialize defaults // Initialize defaults
client.data.activeDollId = null; client.data.activeDollId = null;
client.data.friends = new Set(); client.data.friends = new Set();
client.data.senderName = undefined;
// userId is not set yet, it will be set in handleClientInitialize // userId is not set yet, it will be set in handleClientInitialize
this.logger.log(`WebSocket authenticated (Pending Init): ${payload.sub}`); this.logger.log(`WebSocket authenticated (Pending Init): ${payload.sub}`);
@@ -116,9 +117,7 @@ export class ConnectionHandler {
// 3. Register socket mapping (Redis Write) // 3. Register socket mapping (Redis Write)
await this.userSocketService.setSocket(userState.id, client.id); await this.userSocketService.setSocket(userState.id, client.id);
await this.userSocketService.touchLastSeen(userState.id);
client.data.userId = userState.id; client.data.userId = userState.id;
client.data.lastSeenAt = Date.now();
client.data.activeDollId = userState.activeDollId || null; client.data.activeDollId = userState.activeDollId || null;
client.data.friends = new Set(friends.map((f) => f.friendId)); client.data.friends = new Set(friends.map((f) => f.friendId));
@@ -151,8 +150,7 @@ export class ConnectionHandler {
// Check if this socket is still the active one for the user // Check if this socket is still the active one for the user
const currentSocketId = await this.userSocketService.getSocket(userId); const currentSocketId = await this.userSocketService.getSocket(userId);
if (currentSocketId === client.id) { if (currentSocketId === client.id) {
await this.userSocketService.removeSocket(userId, client.id); await this.userSocketService.removeSocket(userId);
await this.userSocketService.touchLastSeen(userId);
// Note: throttling remove is done in gateway // Note: throttling remove is done in gateway
// Notify friends that this user has disconnected // Notify friends that this user has disconnected
@@ -182,7 +180,5 @@ export class ConnectionHandler {
this.logger.log( this.logger.log(
`Client id: ${client.id} disconnected (user: ${user?.userId || 'unknown'})`, `Client id: ${client.id} disconnected (user: ${user?.userId || 'unknown'})`,
); );
await this.userSocketService.removeSocketById(client.id);
} }
} }

View File

@@ -41,7 +41,6 @@ export class CursorHandler {
// Broadcast to online friends // Broadcast to online friends
const friends = client.data.friends; const friends = client.data.friends;
if (friends) { if (friends) {
await this.broadcaster.touchPresence(client);
const payload = { const payload = {
userId: currentUserId, userId: currentUserId,
position: data, position: data,

View File

@@ -1,9 +1,9 @@
import { Logger } from '@nestjs/common'; import { Logger } from '@nestjs/common';
import { WsException } from '@nestjs/websockets'; import { WsException } from '@nestjs/websockets';
import type { AuthenticatedSocket } from '../../../types/socket'; import type { AuthenticatedSocket } from '../../../types/socket';
import { PrismaService } from '../../../database/prisma.service';
import { SendInteractionDto } from '../../dto/send-interaction.dto'; import { SendInteractionDto } from '../../dto/send-interaction.dto';
import { InteractionPayloadDto } from '../../dto/interaction-payload.dto'; import { InteractionPayloadDto } from '../../dto/interaction-payload.dto';
import { PrismaService } from '../../../database/prisma.service';
import { UserSocketService } from '../user-socket.service'; import { UserSocketService } from '../user-socket.service';
import { WsNotificationService } from '../ws-notification.service'; import { WsNotificationService } from '../ws-notification.service';
import { WS_EVENT } from '../ws-events'; import { WS_EVENT } from '../ws-events';
@@ -50,8 +50,6 @@ export class InteractionHandler {
client: AuthenticatedSocket, client: AuthenticatedSocket,
data: SendInteractionDto, data: SendInteractionDto,
) { ) {
await this.wsNotificationService.maybeTouchPresence(client);
const user = client.data.user; const user = client.data.user;
const currentUserId = Validator.validateInitialized(client); const currentUserId = Validator.validateInitialized(client);

View File

@@ -6,7 +6,6 @@ import { JwtVerificationService } from '../../auth/services/jwt-verification.ser
import { PrismaService } from '../../database/prisma.service'; import { PrismaService } from '../../database/prisma.service';
import { UserSocketService } from './user-socket.service'; import { UserSocketService } from './user-socket.service';
import { WsNotificationService } from './ws-notification.service'; import { WsNotificationService } from './ws-notification.service';
import { ConfigService } from '@nestjs/config';
import { SendInteractionDto } from '../dto/send-interaction.dto'; import { SendInteractionDto } from '../dto/send-interaction.dto';
import { WsException } from '@nestjs/websockets'; import { WsException } from '@nestjs/websockets';
@@ -23,8 +22,6 @@ type MockSocket = {
userId?: string; userId?: string;
activeDollId?: string | null; activeDollId?: string | null;
friends?: Set<string>; friends?: Set<string>;
senderName?: string;
senderNameCachedAt?: number;
}; };
handshake?: any; handshake?: any;
disconnect?: jest.Mock; disconnect?: jest.Mock;
@@ -46,7 +43,6 @@ describe('StateGateway', () => {
let mockUserSocketService: Partial<UserSocketService>; let mockUserSocketService: Partial<UserSocketService>;
let mockRedisClient: { publish: jest.Mock }; let mockRedisClient: { publish: jest.Mock };
let mockRedisSubscriber: { subscribe: jest.Mock; on: jest.Mock }; let mockRedisSubscriber: { subscribe: jest.Mock; on: jest.Mock };
let mockConfigService: { get: jest.Mock };
let mockWsNotificationService: { let mockWsNotificationService: {
setIo: jest.Mock; setIo: jest.Mock;
emitToUser: jest.Mock; emitToUser: jest.Mock;
@@ -54,8 +50,6 @@ describe('StateGateway', () => {
emitToSocket: jest.Mock; emitToSocket: jest.Mock;
updateActiveDollCache: jest.Mock; updateActiveDollCache: jest.Mock;
publishActiveDollUpdate: jest.Mock; publishActiveDollUpdate: jest.Mock;
clearSenderNameCache: jest.Mock;
maybeTouchPresence: jest.Mock;
}; };
beforeEach(async () => { beforeEach(async () => {
@@ -96,12 +90,9 @@ describe('StateGateway', () => {
mockUserSocketService = { mockUserSocketService = {
setSocket: jest.fn().mockResolvedValue(undefined), setSocket: jest.fn().mockResolvedValue(undefined),
removeSocket: jest.fn().mockResolvedValue(undefined), removeSocket: jest.fn().mockResolvedValue(undefined),
removeSocketById: jest.fn().mockResolvedValue(undefined),
touchLastSeen: jest.fn().mockResolvedValue(undefined),
getSocket: jest.fn().mockResolvedValue(null), getSocket: jest.fn().mockResolvedValue(null),
isUserOnline: jest.fn().mockResolvedValue(false), isUserOnline: jest.fn().mockResolvedValue(false),
getFriendsSockets: jest.fn().mockResolvedValue([]), getFriendsSockets: jest.fn().mockResolvedValue([]),
cleanupStalePresence: jest.fn().mockResolvedValue(0),
}; };
mockRedisClient = { mockRedisClient = {
@@ -113,10 +104,6 @@ describe('StateGateway', () => {
on: jest.fn(), on: jest.fn(),
}; };
mockConfigService = {
get: jest.fn().mockReturnValue(undefined),
};
mockWsNotificationService = { mockWsNotificationService = {
setIo: jest.fn(), setIo: jest.fn(),
emitToUser: jest.fn(), emitToUser: jest.fn(),
@@ -124,8 +111,6 @@ describe('StateGateway', () => {
emitToSocket: jest.fn(), emitToSocket: jest.fn(),
updateActiveDollCache: jest.fn(), updateActiveDollCache: jest.fn(),
publishActiveDollUpdate: jest.fn(), publishActiveDollUpdate: jest.fn(),
clearSenderNameCache: jest.fn().mockResolvedValue(undefined),
maybeTouchPresence: jest.fn().mockResolvedValue(undefined),
}; };
const module: TestingModule = await Test.createTestingModule({ const module: TestingModule = await Test.createTestingModule({
@@ -138,7 +123,6 @@ describe('StateGateway', () => {
{ provide: PrismaService, useValue: mockPrismaService }, { provide: PrismaService, useValue: mockPrismaService },
{ provide: UserSocketService, useValue: mockUserSocketService }, { provide: UserSocketService, useValue: mockUserSocketService },
{ provide: WsNotificationService, useValue: mockWsNotificationService }, { provide: WsNotificationService, useValue: mockWsNotificationService },
{ provide: ConfigService, useValue: mockConfigService },
{ provide: 'REDIS_CLIENT', useValue: mockRedisClient }, { provide: 'REDIS_CLIENT', useValue: mockRedisClient },
{ provide: 'REDIS_SUBSCRIBER_CLIENT', useValue: mockRedisSubscriber }, { provide: 'REDIS_SUBSCRIBER_CLIENT', useValue: mockRedisSubscriber },
], ],
@@ -175,32 +159,9 @@ describe('StateGateway', () => {
expect(mockRedisSubscriber.subscribe).toHaveBeenCalledWith( expect(mockRedisSubscriber.subscribe).toHaveBeenCalledWith(
'active-doll-update', 'active-doll-update',
'friend-cache-update', 'friend-cache-update',
'user-profile-cache-invalidate',
expect.any(Function), expect.any(Function),
); );
}); });
it('should route user profile cache invalidation messages', async () => {
gateway.afterInit();
const onCalls = (mockRedisSubscriber.on as jest.Mock).mock.calls;
const messageHandler = onCalls.find(
(call) => call[0] === 'message',
)?.[1] as ((channel: string, message: string) => void) | undefined;
expect(messageHandler).toBeDefined();
messageHandler?.(
'user-profile-cache-invalidate',
JSON.stringify({ userId: 'user-1' }),
);
await Promise.resolve();
expect(
mockWsNotificationService.clearSenderNameCache,
).toHaveBeenCalledWith('user-1');
});
}); });
describe('handleConnection', () => { describe('handleConnection', () => {
@@ -297,9 +258,6 @@ describe('StateGateway', () => {
'user-id', 'user-id',
'client1', 'client1',
); );
expect(mockUserSocketService.touchLastSeen).toHaveBeenCalledWith(
'user-id',
);
// 2. Fetch State (DB) // 2. Fetch State (DB)
expect(mockPrismaService.user!.findUnique).toHaveBeenCalledWith({ expect(mockPrismaService.user!.findUnique).toHaveBeenCalledWith({
@@ -399,13 +357,6 @@ describe('StateGateway', () => {
expect(mockUserSocketService.getSocket).toHaveBeenCalledWith('user-id'); expect(mockUserSocketService.getSocket).toHaveBeenCalledWith('user-id');
expect(mockUserSocketService.removeSocket).toHaveBeenCalledWith( expect(mockUserSocketService.removeSocket).toHaveBeenCalledWith(
'user-id', 'user-id',
'client1',
);
expect(mockUserSocketService.touchLastSeen).toHaveBeenCalledWith(
'user-id',
);
expect(mockUserSocketService.removeSocketById).toHaveBeenCalledWith(
'client1',
); );
expect(mockWsNotificationService.emitToSocket).toHaveBeenCalledWith( expect(mockWsNotificationService.emitToSocket).toHaveBeenCalledWith(
'friend-socket-id', 'friend-socket-id',

View File

@@ -29,11 +29,6 @@ import { InteractionHandler } from './interaction/handler';
import { RedisHandler } from './utils/redis-handler'; import { RedisHandler } from './utils/redis-handler';
import { Broadcaster } from './utils/broadcasting'; import { Broadcaster } from './utils/broadcasting';
import { Throttler } from './utils/throttling'; import { Throttler } from './utils/throttling';
import { ConfigService } from '@nestjs/config';
import { parsePositiveInteger } from '../../common/config/env.utils';
const DEFAULT_PRESENCE_STALE_AGE_MS = 7 * 24 * 60 * 60 * 1000;
const DEFAULT_PRESENCE_CLEANUP_INTERVAL_MS = 5 * 60 * 1000;
@WebSocketGateway() @WebSocketGateway()
export class StateGateway export class StateGateway
@@ -54,16 +49,12 @@ export class StateGateway
private readonly cursorHandler: CursorHandler; private readonly cursorHandler: CursorHandler;
private readonly statusHandler: StatusHandler; private readonly statusHandler: StatusHandler;
private readonly interactionHandler: InteractionHandler; private readonly interactionHandler: InteractionHandler;
private readonly presenceStaleAgeMs: number;
private readonly presenceCleanupIntervalMs: number;
private presenceCleanupTimer: NodeJS.Timeout | null = null;
constructor( constructor(
private readonly jwtVerificationService: JwtVerificationService, private readonly jwtVerificationService: JwtVerificationService,
private readonly prisma: PrismaService, private readonly prisma: PrismaService,
private readonly userSocketService: UserSocketService, private readonly userSocketService: UserSocketService,
private readonly wsNotificationService: WsNotificationService, private readonly wsNotificationService: WsNotificationService,
private readonly configService: ConfigService,
@Inject(REDIS_CLIENT) private readonly redisClient: Redis | null, @Inject(REDIS_CLIENT) private readonly redisClient: Redis | null,
@Inject(REDIS_SUBSCRIBER_CLIENT) @Inject(REDIS_SUBSCRIBER_CLIENT)
private readonly redisSubscriber: Redis | null, private readonly redisSubscriber: Redis | null,
@@ -87,14 +78,6 @@ export class StateGateway
this.userSocketService, this.userSocketService,
this.wsNotificationService, this.wsNotificationService,
); );
this.presenceStaleAgeMs = parsePositiveInteger(
this.configService.get<string>('PRESENCE_STALE_AGE_MS'),
DEFAULT_PRESENCE_STALE_AGE_MS,
);
this.presenceCleanupIntervalMs = parsePositiveInteger(
this.configService.get<string>('PRESENCE_CLEANUP_INTERVAL_MS'),
DEFAULT_PRESENCE_CLEANUP_INTERVAL_MS,
);
// Setup Redis subscription for cross-instance communication // Setup Redis subscription for cross-instance communication
if (this.redisSubscriber) { if (this.redisSubscriber) {
@@ -102,7 +85,6 @@ export class StateGateway
.subscribe( .subscribe(
REDIS_CHANNEL.ACTIVE_DOLL_UPDATE, REDIS_CHANNEL.ACTIVE_DOLL_UPDATE,
REDIS_CHANNEL.FRIEND_CACHE_UPDATE, REDIS_CHANNEL.FRIEND_CACHE_UPDATE,
REDIS_CHANNEL.USER_PROFILE_CACHE_INVALIDATE,
(err) => { (err) => {
if (err) { if (err) {
this.logger.error(`Failed to subscribe to Redis channels`, err); this.logger.error(`Failed to subscribe to Redis channels`, err);
@@ -122,9 +104,6 @@ export class StateGateway
} else if (channel === REDIS_CHANNEL.FRIEND_CACHE_UPDATE) { } else if (channel === REDIS_CHANNEL.FRIEND_CACHE_UPDATE) {
// eslint-disable-next-line @typescript-eslint/no-floating-promises // eslint-disable-next-line @typescript-eslint/no-floating-promises
this.redisHandler.handleFriendCacheUpdateMessage(message); this.redisHandler.handleFriendCacheUpdateMessage(message);
} else if (channel === REDIS_CHANNEL.USER_PROFILE_CACHE_INVALIDATE) {
// eslint-disable-next-line @typescript-eslint/no-floating-promises
this.redisHandler.handleUserProfileCacheInvalidateMessage(message);
} }
}); });
} }
@@ -133,11 +112,6 @@ export class StateGateway
afterInit() { afterInit() {
this.logger.log('Initialized'); this.logger.log('Initialized');
this.wsNotificationService.setIo(this.io); this.wsNotificationService.setIo(this.io);
this.presenceCleanupTimer = setInterval(() => {
void this.cleanupStalePresence();
}, this.presenceCleanupIntervalMs);
this.presenceCleanupTimer.unref();
} }
handleConnection(client: AuthenticatedSocket) { handleConnection(client: AuthenticatedSocket) {
@@ -158,6 +132,12 @@ export class StateGateway
} }
} }
onModuleDestroy() {
if (this.redisSubscriber) {
this.redisSubscriber.removeAllListeners('message');
}
}
async isUserOnline(userId: string): Promise<boolean> { async isUserOnline(userId: string): Promise<boolean> {
return this.userSocketService.isUserOnline(userId); return this.userSocketService.isUserOnline(userId);
} }
@@ -185,23 +165,4 @@ export class StateGateway
) { ) {
await this.interactionHandler.handleSendInteraction(client, data); await this.interactionHandler.handleSendInteraction(client, data);
} }
onModuleDestroy() {
if (this.redisSubscriber) {
this.redisSubscriber.removeAllListeners('message');
}
if (this.presenceCleanupTimer) {
clearInterval(this.presenceCleanupTimer);
this.presenceCleanupTimer = null;
}
}
private async cleanupStalePresence(): Promise<void> {
const cutoffMs = Date.now() - this.presenceStaleAgeMs;
const removed = await this.userSocketService.cleanupStalePresence(cutoffMs);
if (removed > 0) {
this.logger.debug(`Cleaned up ${removed} stale presence entries`);
}
}
} }

View File

@@ -47,7 +47,6 @@ export class StatusHandler {
const friends = client.data.friends; const friends = client.data.friends;
if (friends) { if (friends) {
try { try {
await this.broadcaster.touchPresence(client);
const payload = { const payload = {
userId: currentUserId, userId: currentUserId,
status: data, status: data,

View File

@@ -2,76 +2,12 @@ import { Injectable, Inject, Logger } from '@nestjs/common';
import { REDIS_CLIENT } from '../../database/redis.module'; import { REDIS_CLIENT } from '../../database/redis.module';
import Redis from 'ioredis'; import Redis from 'ioredis';
const SOCKET_KEY_PREFIX = 'socket:user:';
const SOCKET_REVERSE_KEY_PREFIX = 'socket:id:';
const LAST_SEEN_KEY_PREFIX = 'presence:last-seen:';
const PRESENCE_ZSET_KEY = 'presence:last-seen:zset';
const SET_SOCKET_MAPPING_SCRIPT = `
local userKey = KEYS[1]
local reverseKey = KEYS[2]
local userId = ARGV[1]
local socketId = ARGV[2]
local ttl = ARGV[3]
local reversePrefix = ARGV[4]
local previousSocketId = redis.call('GET', userKey)
redis.call('SET', userKey, socketId, 'EX', ttl)
redis.call('SET', reverseKey, userId, 'EX', ttl)
if previousSocketId and previousSocketId ~= socketId then
redis.call('DEL', reversePrefix .. previousSocketId)
end
return 1
`;
const REMOVE_SOCKET_MAPPING_SCRIPT = `
local userKey = KEYS[1]
local reversePrefix = ARGV[1]
local expectedSocketId = ARGV[2]
local currentSocketId = redis.call('GET', userKey)
if not currentSocketId then
return 0
end
if expectedSocketId ~= '' and currentSocketId ~= expectedSocketId then
return 0
end
redis.call('DEL', userKey)
redis.call('DEL', reversePrefix .. currentSocketId)
return 1
`;
const REMOVE_BY_SOCKET_ID_SCRIPT = `
local reverseKey = KEYS[1]
local userPrefix = ARGV[1]
local socketId = ARGV[2]
local userId = redis.call('GET', reverseKey)
if not userId then
return 0
end
local userKey = userPrefix .. userId
local currentSocketId = redis.call('GET', userKey)
redis.call('DEL', reverseKey)
if currentSocketId == socketId then
redis.call('DEL', userKey)
end
return 1
`;
@Injectable() @Injectable()
export class UserSocketService { export class UserSocketService {
private readonly logger = new Logger(UserSocketService.name); private readonly logger = new Logger(UserSocketService.name);
private localUserSocketMap: Map<string, string> = new Map(); private localUserSocketMap: Map<string, string> = new Map();
private readonly PREFIX = 'socket:user:';
private readonly TTL = 86400; // 24 hours private readonly TTL = 86400; // 24 hours
private readonly LAST_SEEN_TTL_SECONDS = 604800; // 7 days
constructor( constructor(
@Inject(REDIS_CLIENT) private readonly redisClient: Redis | null, @Inject(REDIS_CLIENT) private readonly redisClient: Redis | null,
@@ -80,15 +16,11 @@ export class UserSocketService {
async setSocket(userId: string, socketId: string): Promise<void> { async setSocket(userId: string, socketId: string): Promise<void> {
if (this.redisClient) { if (this.redisClient) {
try { try {
await this.redisClient.eval( await this.redisClient.set(
SET_SOCKET_MAPPING_SCRIPT, `${this.PREFIX}${userId}`,
2,
`${SOCKET_KEY_PREFIX}${userId}`,
`${SOCKET_REVERSE_KEY_PREFIX}${socketId}`,
userId,
socketId, socketId,
String(this.TTL), 'EX',
SOCKET_REVERSE_KEY_PREFIX, this.TTL,
); );
} catch (error) { } catch (error) {
this.logger.error( this.logger.error(
@@ -104,16 +36,10 @@ export class UserSocketService {
} }
} }
async removeSocket(userId: string, expectedSocketId?: string): Promise<void> { async removeSocket(userId: string): Promise<void> {
if (this.redisClient) { if (this.redisClient) {
try { try {
await this.redisClient.eval( await this.redisClient.del(`${this.PREFIX}${userId}`);
REMOVE_SOCKET_MAPPING_SCRIPT,
1,
`${SOCKET_KEY_PREFIX}${userId}`,
SOCKET_REVERSE_KEY_PREFIX,
expectedSocketId || '',
);
} catch (error) { } catch (error) {
this.logger.error( this.logger.error(
`Failed to remove socket for user ${userId} from Redis`, `Failed to remove socket for user ${userId} from Redis`,
@@ -121,23 +47,13 @@ export class UserSocketService {
); );
} }
} }
if (!expectedSocketId) {
this.localUserSocketMap.delete(userId); this.localUserSocketMap.delete(userId);
return;
}
const currentLocalSocketId = this.localUserSocketMap.get(userId);
if (currentLocalSocketId === expectedSocketId) {
this.localUserSocketMap.delete(userId);
}
} }
async getSocket(userId: string): Promise<string | null> { async getSocket(userId: string): Promise<string | null> {
if (this.redisClient) { if (this.redisClient) {
try { try {
const socketId = await this.redisClient.get( const socketId = await this.redisClient.get(`${this.PREFIX}${userId}`);
`${SOCKET_KEY_PREFIX}${userId}`,
);
return socketId; return socketId;
} catch (error) { } catch (error) {
this.logger.error( this.logger.error(
@@ -166,7 +82,7 @@ export class UserSocketService {
try { try {
// Use pipeline for batch fetching // Use pipeline for batch fetching
const pipeline = this.redisClient.pipeline(); const pipeline = this.redisClient.pipeline();
friendIds.forEach((id) => pipeline.get(`${SOCKET_KEY_PREFIX}${id}`)); friendIds.forEach((id) => pipeline.get(`${this.PREFIX}${id}`));
const results = await pipeline.exec(); const results = await pipeline.exec();
const sockets: { userId: string; socketId: string }[] = []; const sockets: { userId: string; socketId: string }[] = [];
@@ -199,79 +115,4 @@ export class UserSocketService {
} }
return sockets; return sockets;
} }
async touchLastSeen(userId: string): Promise<void> {
const now = Date.now();
if (this.redisClient) {
try {
const key = `${LAST_SEEN_KEY_PREFIX}${userId}`;
await this.redisClient.set(
key,
String(now),
'EX',
this.LAST_SEEN_TTL_SECONDS,
);
await this.redisClient.zadd(PRESENCE_ZSET_KEY, now, userId);
return;
} catch (error) {
this.logger.warn(
`Failed to touch last-seen for user ${userId} in Redis`,
error as Error,
);
}
}
}
async removeSocketById(socketId: string): Promise<void> {
if (!this.redisClient) {
return;
}
try {
await this.redisClient.eval(
REMOVE_BY_SOCKET_ID_SCRIPT,
1,
`${SOCKET_REVERSE_KEY_PREFIX}${socketId}`,
SOCKET_KEY_PREFIX,
socketId,
);
} catch (error) {
this.logger.warn(
`Failed to remove socket mapping by socket id ${socketId}`,
error as Error,
);
}
}
async cleanupStalePresence(cutoffMs: number): Promise<number> {
if (!this.redisClient) {
return 0;
}
try {
const staleUserIds = await this.redisClient.zrangebyscore(
PRESENCE_ZSET_KEY,
'-inf',
cutoffMs,
);
if (staleUserIds.length === 0) {
return 0;
}
const pipeline = this.redisClient.pipeline();
staleUserIds.forEach((userId) => {
pipeline.del(`${LAST_SEEN_KEY_PREFIX}${userId}`);
});
pipeline.zremrangebyscore(PRESENCE_ZSET_KEY, '-inf', cutoffMs);
await pipeline.exec();
return staleUserIds.length;
} catch (error) {
this.logger.warn(
'Failed to cleanup stale presence entries',
error as Error,
);
return 0;
}
}
} }

View File

@@ -1,6 +1,5 @@
import { UserSocketService } from '../user-socket.service'; import { UserSocketService } from '../user-socket.service';
import { WsNotificationService } from '../ws-notification.service'; import { WsNotificationService } from '../ws-notification.service';
import type { AuthenticatedSocket } from '../../../types/socket';
export class Broadcaster { export class Broadcaster {
constructor( constructor(
@@ -8,10 +7,6 @@ export class Broadcaster {
private readonly wsNotificationService: WsNotificationService, private readonly wsNotificationService: WsNotificationService,
) {} ) {}
async touchPresence(client: AuthenticatedSocket) {
await this.wsNotificationService.maybeTouchPresence(client);
}
async broadcastToFriends(friends: Set<string>, event: string, payload: any) { async broadcastToFriends(friends: Set<string>, event: string, payload: any) {
const friendIds = Array.from(friends); const friendIds = Array.from(friends);
const friendSockets = const friendSockets =

View File

@@ -36,20 +36,4 @@ export class RedisHandler {
this.logger.error('Error handling friend cache update message', error); this.logger.error('Error handling friend cache update message', error);
} }
} }
async handleUserProfileCacheInvalidateMessage(
message: string,
): Promise<void> {
try {
const data = JSON.parse(message) as {
userId: string;
};
await this.wsNotificationService.clearSenderNameCache(data.userId);
} catch (error) {
this.logger.error(
'Error handling user profile cache invalidate message',
error,
);
}
}
} }

View File

@@ -22,5 +22,4 @@ export const WS_EVENT = {
export const REDIS_CHANNEL = { export const REDIS_CHANNEL = {
ACTIVE_DOLL_UPDATE: 'active-doll-update', ACTIVE_DOLL_UPDATE: 'active-doll-update',
FRIEND_CACHE_UPDATE: 'friend-cache-update', FRIEND_CACHE_UPDATE: 'friend-cache-update',
USER_PROFILE_CACHE_INVALIDATE: 'user-profile-cache-invalidate',
} as const; } as const;

View File

@@ -1,14 +1,10 @@
import { Inject, Injectable, Logger } from '@nestjs/common'; import { Injectable, Logger, Inject } from '@nestjs/common';
import { OnEvent } from '@nestjs/event-emitter';
import Redis from 'ioredis'; import Redis from 'ioredis';
import { Server } from 'socket.io'; import { Server } from 'socket.io';
import { UserEvents } from '../../users/events/user.events'; import { UserSocketService } from './user-socket.service';
import type { AuthenticatedSocket } from '../../types/socket'; import type { AuthenticatedSocket } from '../../types/socket';
import { REDIS_CLIENT } from '../../database/redis.module'; import { REDIS_CLIENT } from '../../database/redis.module';
import { REDIS_CHANNEL } from './ws-events'; import { REDIS_CHANNEL } from './ws-events';
import { UserSocketService } from './user-socket.service';
const PRESENCE_UPDATE_THROTTLE_MS = 15_000;
@Injectable() @Injectable()
export class WsNotificationService { export class WsNotificationService {
@@ -46,11 +42,6 @@ export class WsNotificationService {
this.io.to(socketId).emit(event, payload); this.io.to(socketId).emit(event, payload);
} }
@OnEvent(UserEvents.PROFILE_UPDATED)
async handleUserProfileUpdated(payload: { userId: string }) {
await this.publishUserProfileCacheInvalidate(payload.userId);
}
async updateFriendsCache( async updateFriendsCache(
userId: string, userId: string,
friendId: string, friendId: string,
@@ -135,63 +126,4 @@ export class WsNotificationService {
); );
} }
} }
async publishUserProfileCacheInvalidate(userId: string) {
if (this.redisClient) {
try {
await this.redisClient.publish(
REDIS_CHANNEL.USER_PROFILE_CACHE_INVALIDATE,
JSON.stringify({ userId }),
);
return;
} catch (error) {
this.logger.warn(
'Redis publish failed for user profile cache invalidate; applying local update only',
error as Error,
);
}
}
await this.clearSenderNameCache(userId);
}
async clearSenderNameCache(userId: string) {
if (!this.io) {
return;
}
const socketId = await this.userSocketService.getSocket(userId);
if (!socketId) {
return;
}
const socket = this.io.sockets.sockets.get(socketId) as
| AuthenticatedSocket
| undefined;
if (!socket?.data) {
return;
}
socket.data.senderName = undefined;
socket.data.senderNameCachedAt = undefined;
}
async maybeTouchPresence(client: AuthenticatedSocket): Promise<void> {
const userId = client.data.userId;
if (!userId) {
return;
}
const now = Date.now();
const lastSeenAt = client.data.lastSeenAt;
if (
typeof lastSeenAt === 'number' &&
now - lastSeenAt < PRESENCE_UPDATE_THROTTLE_MS
) {
return;
}
client.data.lastSeenAt = now;
await this.userSocketService.touchLastSeen(userId);
}
} }