resid pt 3: friendship checks & auth session reads

This commit is contained in:
2026-03-30 19:34:03 +08:00
parent d12d3e1ec7
commit ad09de2b11
11 changed files with 295 additions and 48 deletions

View File

@@ -18,6 +18,8 @@ CACHE_KEY_PREFIX=friendolls
CACHE_DEFAULT_TTL_SECONDS=60 CACHE_DEFAULT_TTL_SECONDS=60
CACHE_MAX_TTL_SECONDS=86400 CACHE_MAX_TTL_SECONDS=86400
CACHE_METRICS_LOG_INTERVAL_MS=60000 CACHE_METRICS_LOG_INTERVAL_MS=60000
# Max number of cache keys tracked per invalidation tag
CACHE_TAG_MAX_ENTRIES=5000
# JWT Configuration # JWT Configuration
JWT_SECRET=replace-with-strong-random-secret JWT_SECRET=replace-with-strong-random-secret

View File

@@ -83,6 +83,13 @@ function validateEnvironment(
throw new Error('REDIS_CONNECT_TIMEOUT_MS must be a positive number'); throw new Error('REDIS_CONNECT_TIMEOUT_MS must be a positive number');
} }
validateOptionalPositiveNumber(config, 'THROTTLE_TTL');
validateOptionalPositiveNumber(config, 'THROTTLE_LIMIT');
validateOptionalPositiveNumber(config, 'CACHE_DEFAULT_TTL_SECONDS');
validateOptionalPositiveNumber(config, 'CACHE_MAX_TTL_SECONDS');
validateOptionalPositiveNumber(config, 'CACHE_METRICS_LOG_INTERVAL_MS');
validateOptionalPositiveNumber(config, 'CACHE_TAG_MAX_ENTRIES');
validateOptionalProvider(config, 'GOOGLE'); validateOptionalProvider(config, 'GOOGLE');
validateOptionalProvider(config, 'DISCORD'); validateOptionalProvider(config, 'DISCORD');
@@ -109,6 +116,20 @@ function validateOptionalProvider(
} }
} }
function validateOptionalPositiveNumber(
config: Record<string, unknown>,
key: string,
): void {
const value = config[key];
if (value === undefined || value === null || value === '') {
return;
}
if (!Number.isFinite(Number(value)) || Number(value) <= 0) {
throw new Error(`${key} must be a positive number`);
}
}
/** /**
* Root Application Module * Root Application Module
* *

View File

@@ -37,6 +37,14 @@ import {
} from './auth.utils'; } from './auth.utils';
import type { SsoProvider } from './dto/sso-provider'; import type { SsoProvider } from './dto/sso-provider';
import { UserEvents } from '../users/events/user.events'; import { UserEvents } from '../users/events/user.events';
import { CacheService } from '../common/cache/cache.service';
import {
authSessionUserTag,
authSessionCacheKey,
CACHE_NAMESPACE,
CACHE_TTL_SECONDS,
} from '../common/cache/cache-keys';
import { CacheTagsService } from '../common/cache/cache-tags.service';
interface SsoStateClaims { interface SsoStateClaims {
provider: SsoProvider; provider: SsoProvider;
@@ -45,6 +53,28 @@ interface SsoStateClaims {
typ: 'sso_state'; typ: 'sso_state';
} }
interface AuthSessionWithUser {
id: string;
refresh_token_hash: string;
expires_at: Date;
revoked_at: Date | null;
provider: 'GOOGLE' | 'DISCORD' | null;
user_id: string;
email: string;
roles: string[];
}
interface CachedAuthSessionWithUser {
id: string;
refresh_token_hash: string;
expires_at: string;
revoked_at: string | null;
provider: 'GOOGLE' | 'DISCORD' | null;
user_id: string;
email: string;
roles: string[];
}
@Injectable() @Injectable()
export class AuthService { export class AuthService {
private readonly logger = new Logger(AuthService.name); private readonly logger = new Logger(AuthService.name);
@@ -59,6 +89,8 @@ export class AuthService {
private readonly prisma: PrismaService, private readonly prisma: PrismaService,
private readonly configService: ConfigService, private readonly configService: ConfigService,
private readonly eventEmitter: EventEmitter2, private readonly eventEmitter: EventEmitter2,
private readonly cacheService: CacheService,
private readonly cacheTagsService: CacheTagsService,
) { ) {
this.jwtSecret = this.configService.get<string>('JWT_SECRET') || ''; this.jwtSecret = this.configService.get<string>('JWT_SECRET') || '';
this.jwtIssuer = this.jwtIssuer =
@@ -162,7 +194,7 @@ export class AuthService {
} }
if (session.refresh_token_hash !== refreshTokenHash) { if (session.refresh_token_hash !== refreshTokenHash) {
await this.revokeSessionOnReplay(session.id); await this.revokeSessionOnReplay(session.id, session.user_id);
throw new UnauthorizedException('Invalid refresh token'); throw new UnauthorizedException('Invalid refresh token');
} }
@@ -174,7 +206,7 @@ export class AuthService {
); );
if (!updated) { if (!updated) {
await this.revokeSessionOnReplay(session.id); await this.revokeSessionOnReplay(session.id, session.user_id);
throw new UnauthorizedException('Invalid refresh token'); throw new UnauthorizedException('Invalid refresh token');
} }
@@ -560,28 +592,34 @@ export class AuthService {
return rows[0] ?? null; return rows[0] ?? null;
} }
private async getSessionWithUser(sessionId: string): Promise<{ private async getSessionWithUser(
id: string; sessionId: string,
refresh_token_hash: string; ): Promise<AuthSessionWithUser | null> {
expires_at: Date; const sessionCacheKey = this.getAuthSessionCacheKey(sessionId);
revoked_at: Date | null; const cachedSessionRaw = await this.cacheService.get(sessionCacheKey);
provider: 'GOOGLE' | 'DISCORD' | null;
user_id: string; if (cachedSessionRaw) {
email: string; try {
roles: string[]; const cachedSession = JSON.parse(
} | null> { cachedSessionRaw,
const rows = await this.prisma.$queryRaw< ) as CachedAuthSessionWithUser;
Array<{ return {
id: string; ...cachedSession,
refresh_token_hash: string; expires_at: new Date(cachedSession.expires_at),
expires_at: Date; revoked_at: cachedSession.revoked_at
revoked_at: Date | null; ? new Date(cachedSession.revoked_at)
provider: 'GOOGLE' | 'DISCORD' | null; : null,
user_id: string; };
email: string; } catch (error) {
roles: string[]; this.cacheService.recordError(
}> 'auth session parse',
>` sessionCacheKey,
error,
);
}
}
const rows = await this.prisma.$queryRaw<Array<AuthSessionWithUser>>`
SELECT s.id, s.refresh_token_hash, s.expires_at, s.revoked_at, s.provider, s.user_id, u.email, u.roles SELECT s.id, s.refresh_token_hash, s.expires_at, s.revoked_at, s.provider, s.user_id, u.email, u.roles
FROM auth_sessions AS s FROM auth_sessions AS s
INNER JOIN users AS u ON u.id = s.user_id INNER JOIN users AS u ON u.id = s.user_id
@@ -589,7 +627,29 @@ export class AuthService {
LIMIT 1 LIMIT 1
`; `;
return rows[0] ?? null; const session = rows[0] ?? null;
if (!session) {
return null;
}
const cachePayload: CachedAuthSessionWithUser = {
...session,
expires_at: session.expires_at.toISOString(),
revoked_at: session.revoked_at ? session.revoked_at.toISOString() : null,
};
await this.cacheService.set(
sessionCacheKey,
JSON.stringify(cachePayload),
CACHE_TTL_SECONDS.AUTH_SESSION,
);
await this.cacheTagsService.rememberKeyForTag(
CACHE_NAMESPACE.AUTH_SESSION,
authSessionUserTag(session.user_id),
authSessionCacheKey(session.id),
);
return session;
} }
private async rotateRefreshSession( private async rotateRefreshSession(
@@ -597,6 +657,8 @@ export class AuthService {
refreshTokenHash: string, refreshTokenHash: string,
nextRefreshToken: string, nextRefreshToken: string,
): Promise<boolean> { ): Promise<boolean> {
await this.cacheService.del(this.getAuthSessionCacheKey(sessionId));
const rows = await this.prisma.$queryRaw<Array<{ id: string }>>` const rows = await this.prisma.$queryRaw<Array<{ id: string }>>`
UPDATE auth_sessions UPDATE auth_sessions
SET refresh_token_hash = ${sha256(nextRefreshToken)}, SET refresh_token_hash = ${sha256(nextRefreshToken)},
@@ -610,6 +672,10 @@ export class AuthService {
RETURNING id RETURNING id
`; `;
if (rows.length === 1) {
await this.cacheService.del(this.getAuthSessionCacheKey(sessionId));
}
return rows.length === 1; return rows.length === 1;
} }
@@ -617,6 +683,8 @@ export class AuthService {
sessionId: string, sessionId: string,
refreshTokenHash: string, refreshTokenHash: string,
): Promise<boolean> { ): Promise<boolean> {
await this.cacheService.del(this.getAuthSessionCacheKey(sessionId));
const rows = await this.prisma.$queryRaw<Array<{ id: string }>>` const rows = await this.prisma.$queryRaw<Array<{ id: string }>>`
UPDATE auth_sessions UPDATE auth_sessions
SET revoked_at = NOW(), SET revoked_at = NOW(),
@@ -628,17 +696,41 @@ export class AuthService {
RETURNING id RETURNING id
`; `;
if (rows.length === 1) {
await this.cacheService.del(this.getAuthSessionCacheKey(sessionId));
}
return rows.length === 1; return rows.length === 1;
} }
private async revokeSessionOnReplay(sessionId: string): Promise<void> { private async revokeSessionOnReplay(
sessionId: string,
userId: string,
): Promise<void> {
await this.cacheService.del(this.getAuthSessionCacheKey(sessionId));
await this.revokeAllUserSessions(userId);
}
private async revokeAllUserSessions(userId: string): Promise<void> {
await this.prisma.$queryRaw<Array<{ id: string }>>` await this.prisma.$queryRaw<Array<{ id: string }>>`
UPDATE auth_sessions UPDATE auth_sessions
SET revoked_at = NOW(), SET revoked_at = NOW(),
updated_at = NOW() updated_at = NOW()
WHERE id = ${sessionId} WHERE user_id = ${userId}
AND revoked_at IS NULL AND revoked_at IS NULL
RETURNING id RETURNING id
`; `;
await this.cacheTagsService.invalidateTag(
CACHE_NAMESPACE.AUTH_SESSION,
authSessionUserTag(userId),
);
}
private getAuthSessionCacheKey(sessionId: string): string {
return this.cacheService.getNamespacedKey(
CACHE_NAMESPACE.AUTH_SESSION,
authSessionCacheKey(sessionId),
);
} }
} }

View File

@@ -20,6 +20,14 @@ const DEFAULT_REVOKED_RETENTION_DAYS = 7;
const CLEANUP_LOCK_KEY = 'lock:auth:cleanup'; const CLEANUP_LOCK_KEY = 'lock:auth:cleanup';
const CLEANUP_LOCK_TTL_MS = 55_000; const CLEANUP_LOCK_TTL_MS = 55_000;
const RELEASE_LOCK_SCRIPT = `
if redis.call("get", KEYS[1]) == ARGV[1] then
return redis.call("del", KEYS[1])
else
return 0
end
`;
@Injectable() @Injectable()
export class AuthCleanupService implements OnModuleInit, OnModuleDestroy { export class AuthCleanupService implements OnModuleInit, OnModuleDestroy {
private readonly logger = new Logger(AuthCleanupService.name); private readonly logger = new Logger(AuthCleanupService.name);
@@ -141,10 +149,12 @@ export class AuthCleanupService implements OnModuleInit, OnModuleDestroy {
} finally { } finally {
if (lockAcquired && this.redisClient) { if (lockAcquired && this.redisClient) {
try { try {
const currentLockValue = await this.redisClient.get(CLEANUP_LOCK_KEY); await this.redisClient.eval(
if (currentLockValue === lockToken) { RELEASE_LOCK_SCRIPT,
await this.redisClient.del(CLEANUP_LOCK_KEY); 1,
} CLEANUP_LOCK_KEY,
lockToken,
);
} catch (error) { } catch (error) {
this.logger.warn( this.logger.warn(
'Failed to release auth cleanup lock', 'Failed to release auth cleanup lock',

View File

@@ -4,6 +4,8 @@ export const CACHE_NAMESPACE = {
FRIENDS_LIST: 'friends-list', FRIENDS_LIST: 'friends-list',
DOLLS_LIST: 'dolls-list', DOLLS_LIST: 'dolls-list',
USERS_SEARCH: 'users-search', USERS_SEARCH: 'users-search',
FRIENDSHIP_CHECK: 'friendship-check',
AUTH_SESSION: 'auth-session',
} as const; } as const;
function normalizeKeyPart(value: string | undefined): string { function normalizeKeyPart(value: string | undefined): string {
@@ -18,6 +20,8 @@ export const CACHE_TTL_SECONDS = {
FRIENDS_LIST: 30, FRIENDS_LIST: 30,
DOLLS_LIST: 30, DOLLS_LIST: 30,
USERS_SEARCH: 20, USERS_SEARCH: 20,
FRIENDSHIP_CHECK: 120,
AUTH_SESSION: 30,
} as const; } as const;
export function friendsListCacheKey(userId: string): string { export function friendsListCacheKey(userId: string): string {
@@ -56,6 +60,25 @@ export function usersSearchCacheKey(
export const USERS_SEARCH_GLOBAL_TAG = 'global'; export const USERS_SEARCH_GLOBAL_TAG = 'global';
export function friendshipCheckCacheKey(
userId: string,
friendId: string,
): string {
return `${normalizeKeyPart(userId)}:${normalizeKeyPart(friendId)}`;
}
export function friendshipCheckUserTag(userId: string): string {
return `user:${normalizeKeyPart(userId)}`;
}
export function authSessionCacheKey(sessionId: string): string {
return normalizeKeyPart(sessionId);
}
export function authSessionUserTag(userId: string): string {
return `user:${normalizeKeyPart(userId)}`;
}
export function usersSearchUserTag(userId: string): string { export function usersSearchUserTag(userId: string): string {
return `user:${normalizeKeyPart(userId)}`; return `user:${normalizeKeyPart(userId)}`;
} }

View File

@@ -1,11 +1,24 @@
import { Injectable } from '@nestjs/common'; import { Injectable } from '@nestjs/common';
import { ConfigService } from '@nestjs/config';
import { CacheService } from './cache.service'; import { CacheService } from './cache.service';
import { parsePositiveInteger } from '../config/env.utils';
const CACHE_TAG_SET_TTL_SECONDS = 86_400; const CACHE_TAG_SET_TTL_SECONDS = 86_400;
const DEFAULT_CACHE_TAG_MAX_ENTRIES = 5_000;
@Injectable() @Injectable()
export class CacheTagsService { export class CacheTagsService {
constructor(private readonly cacheService: CacheService) {} private readonly cacheTagMaxEntries: number;
constructor(
private readonly cacheService: CacheService,
private readonly configService: ConfigService,
) {
this.cacheTagMaxEntries = parsePositiveInteger(
this.configService.get<string>('CACHE_TAG_MAX_ENTRIES'),
DEFAULT_CACHE_TAG_MAX_ENTRIES,
);
}
async rememberKeyForTag( async rememberKeyForTag(
namespace: string, namespace: string,
@@ -28,6 +41,11 @@ export class CacheTagsService {
redisClient.sadd(tagSetKey, keyWithNamespace), redisClient.sadd(tagSetKey, keyWithNamespace),
redisClient.expire(tagSetKey, CACHE_TAG_SET_TTL_SECONDS), redisClient.expire(tagSetKey, CACHE_TAG_SET_TTL_SECONDS),
]); ]);
const size = await redisClient.scard(tagSetKey);
if (size > this.cacheTagMaxEntries) {
await this.trimTagSet(tagSetKey, size - this.cacheTagMaxEntries);
}
} catch (error) { } catch (error) {
this.cacheService.recordError('tag remember', tagSetKey, error); this.cacheService.recordError('tag remember', tagSetKey, error);
} }
@@ -63,4 +81,26 @@ export class CacheTagsService {
`${namespace}:${tag}`, `${namespace}:${tag}`,
); );
} }
private async trimTagSet(
tagSetKey: string,
countToDrop: number,
): Promise<void> {
const redisClient = this.cacheService.getRedisClient();
if (!redisClient || countToDrop <= 0) {
return;
}
try {
const sample = await redisClient.srandmember(tagSetKey, countToDrop);
const members = Array.isArray(sample) ? sample : [sample].filter(Boolean);
if (members.length === 0) {
return;
}
await redisClient.srem(tagSetKey, ...members);
} catch (error) {
this.cacheService.recordError('tag trim', tagSetKey, error);
}
}
} }

View File

@@ -1,3 +1,4 @@
export { CacheModule } from './cache.module'; export { CacheModule } from './cache.module';
export { CacheService } from './cache.service'; export { CacheService } from './cache.service';
export { CacheTagsService } from './cache-tags.service';
export { RedisThrottlerStorage } from './redis-throttler.storage'; export { RedisThrottlerStorage } from './redis-throttler.storage';

View File

@@ -6,9 +6,15 @@ import { DollsNotificationService } from './dolls-notification.service';
import { DatabaseModule } from '../database/database.module'; import { DatabaseModule } from '../database/database.module';
import { AuthModule } from '../auth/auth.module'; import { AuthModule } from '../auth/auth.module';
import { WsModule } from '../ws/ws.module'; import { WsModule } from '../ws/ws.module';
import { FriendsModule } from '../friends/friends.module';
@Module({ @Module({
imports: [DatabaseModule, AuthModule, forwardRef(() => WsModule)], imports: [
DatabaseModule,
AuthModule,
FriendsModule,
forwardRef(() => WsModule),
],
controllers: [DollsController], controllers: [DollsController],
providers: [ providers: [
DollsService, DollsService,

View File

@@ -24,6 +24,7 @@ import {
dollsListOwnerTag, dollsListOwnerTag,
dollsListViewerTag, dollsListViewerTag,
} from '../common/cache/cache-keys'; } from '../common/cache/cache-keys';
import { FriendsService } from '../friends/friends.service';
@Injectable() @Injectable()
export class DollsService { export class DollsService {
@@ -34,16 +35,9 @@ export class DollsService {
private readonly eventEmitter: EventEmitter2, private readonly eventEmitter: EventEmitter2,
private readonly cacheService: CacheService, private readonly cacheService: CacheService,
private readonly cacheTagsService: CacheTagsService, private readonly cacheTagsService: CacheTagsService,
private readonly friendsService: FriendsService,
) {} ) {}
async getFriendIds(userId: string): Promise<string[]> {
const friendships = await this.prisma.friendship.findMany({
where: { userId },
select: { friendId: true },
});
return friendships.map((f) => f.friendId);
}
async create( async create(
requestingUserId: string, requestingUserId: string,
createDollDto: CreateDollDto, createDollDto: CreateDollDto,
@@ -144,8 +138,11 @@ export class DollsService {
} }
// If requesting someone else's dolls, check friendship // If requesting someone else's dolls, check friendship
const friendIds = await this.getFriendIds(requestingUserId); const isFriend = await this.friendsService.areFriends(
if (!friendIds.includes(ownerId)) { requestingUserId,
ownerId,
);
if (!isFriend) {
throw new ForbiddenException('You are not friends with this user'); throw new ForbiddenException('You are not friends with this user');
} }
@@ -161,13 +158,9 @@ export class DollsService {
} }
async findOne(id: string, requestingUserId: string): Promise<Doll> { async findOne(id: string, requestingUserId: string): Promise<Doll> {
const friendIds = await this.getFriendIds(requestingUserId);
const accessibleUserIds = [requestingUserId, ...friendIds];
const doll = await this.prisma.doll.findFirst({ const doll = await this.prisma.doll.findFirst({
where: { where: {
id, id,
userId: { in: accessibleUserIds },
deletedAt: null, deletedAt: null,
}, },
}); });
@@ -178,6 +171,18 @@ export class DollsService {
); );
} }
if (doll.userId !== requestingUserId) {
const isFriend = await this.friendsService.areFriends(
requestingUserId,
doll.userId,
);
if (!isFriend) {
throw new NotFoundException(
`Doll with ID ${id} not found or access denied`,
);
}
}
return doll; return doll;
} }

View File

@@ -4,6 +4,7 @@ import { CacheTagsService } from '../common/cache/cache-tags.service';
import { import {
CACHE_NAMESPACE, CACHE_NAMESPACE,
dollsListViewerTag, dollsListViewerTag,
friendshipCheckUserTag,
friendsListDependsOnUserTag, friendsListDependsOnUserTag,
friendsListOwnerTag, friendsListOwnerTag,
} from '../common/cache/cache-keys'; } from '../common/cache/cache-keys';
@@ -60,6 +61,14 @@ export class FriendsCacheInvalidationService {
CACHE_NAMESPACE.DOLLS_LIST, CACHE_NAMESPACE.DOLLS_LIST,
dollsListViewerTag(secondUserId), dollsListViewerTag(secondUserId),
), ),
this.cacheTagsService.invalidateTag(
CACHE_NAMESPACE.FRIENDSHIP_CHECK,
friendshipCheckUserTag(firstUserId),
),
this.cacheTagsService.invalidateTag(
CACHE_NAMESPACE.FRIENDSHIP_CHECK,
friendshipCheckUserTag(secondUserId),
),
]); ]);
} }
} }

View File

@@ -20,6 +20,8 @@ import { CacheTagsService } from '../common/cache/cache-tags.service';
import { import {
CACHE_NAMESPACE, CACHE_NAMESPACE,
CACHE_TTL_SECONDS, CACHE_TTL_SECONDS,
friendshipCheckCacheKey,
friendshipCheckUserTag,
friendsListCacheKey, friendsListCacheKey,
friendsListDependsOnUserTag, friendsListDependsOnUserTag,
friendsListOwnerTag, friendsListOwnerTag,
@@ -378,6 +380,21 @@ export class FriendsService {
} }
async areFriends(userId: string, friendId: string): Promise<boolean> { async areFriends(userId: string, friendId: string): Promise<boolean> {
const cacheKey = friendshipCheckCacheKey(userId, friendId);
const namespacedKey = this.cacheService.getNamespacedKey(
CACHE_NAMESPACE.FRIENDSHIP_CHECK,
cacheKey,
);
const cached = await this.cacheService.get(namespacedKey);
if (cached === '1') {
return true;
}
if (cached === '0') {
return false;
}
const friendship = await this.prisma.friendship.findFirst({ const friendship = await this.prisma.friendship.findFirst({
where: { where: {
userId, userId,
@@ -385,6 +402,27 @@ export class FriendsService {
}, },
}); });
return !!friendship; const areFriends = !!friendship;
await this.cacheService.set(
namespacedKey,
areFriends ? '1' : '0',
CACHE_TTL_SECONDS.FRIENDSHIP_CHECK,
);
await Promise.all([
this.cacheTagsService.rememberKeyForTag(
CACHE_NAMESPACE.FRIENDSHIP_CHECK,
friendshipCheckUserTag(userId),
cacheKey,
),
this.cacheTagsService.rememberKeyForTag(
CACHE_NAMESPACE.FRIENDSHIP_CHECK,
friendshipCheckUserTag(friendId),
cacheKey,
),
]);
return areFriends;
} }
} }