Compare commits

...

24 Commits

Author SHA1 Message Date
4c1670c5af fix(ws): restore interaction event subscription in gateway 2026-03-30 03:30:39 +08:00
7efd7a4deb reverted prisma cusotm output path 2026-03-30 02:04:42 +08:00
4f9bb6adb7 updated pnpm lock file 2026-03-29 19:54:20 +08:00
c2a6783f26 test: update tests for WS perf and auth changes 2026-03-29 19:32:01 +08:00
a13e8d1c35 chore: bump version to 0.1.0 2026-03-29 19:30:21 +08:00
f5f1c8ac42 fix(ws): make notification publish fail open with local fallback 2026-03-29 19:29:34 +08:00
4464328c0a perf(ws): reduce user lookups and cache sender metadata 2026-03-29 19:29:02 +08:00
8e3f1b5bd4 fix(db): close Prisma side pool on shutdown 2026-03-29 19:28:25 +08:00
fd2043ba7e feat(auth): add scheduled auth artifact cleanup 2026-03-29 19:27:38 +08:00
765d4507c9 feat(security): production bootstrap hardening 2026-03-29 19:26:59 +08:00
6793460d31 feat(ws): harden Redis socket adapter lifecycle 2026-03-29 18:50:32 +08:00
4dfefadc9e feat(redis): harden Redis module startup and shutdown behavior 2026-03-29 18:48:25 +08:00
114d6ff2f5 refactor(config): add env parsing helpers and tighten startup validation 2026-03-29 18:47:50 +08:00
3ce15d9762 enforce 50 char content limit for message interaction 2026-03-20 22:50:39 +08:00
5e3001b9bf attempt to fix version number display in prod 2026-03-20 03:07:25 +08:00
16a32e82d6 5req/s universal rate limit 2026-03-19 21:53:53 +08:00
96135493a6 restored prisma config ts file & added copy in dockerfile 2026-03-19 14:01:34 +08:00
85b7d0ee6f rename prisma config from ts to js to fix docker compose issues 2026-03-19 03:18:14 +08:00
0d0cd6d41b add copying of prisma files to dockerfile 2026-03-19 03:08:42 +08:00
1b9daa9e1f add dummy database url so dockerbuild passes 2026-03-18 23:24:45 +08:00
3b6d38692f added prima generate to dockerfile 2026-03-18 23:16:02 +08:00
Adam C
39972af899 Add GitHub Actions workflow for server deployment 2026-03-18 23:05:37 +08:00
6cc102cfc1 dockerfile 2026-03-18 14:13:02 +08:00
32746756d4 auto-populate missing username field with email local-part 2026-03-17 19:27:41 +08:00
31 changed files with 1035 additions and 210 deletions

View File

@@ -9,6 +9,9 @@ DATABASE_URL="postgresql://postgres:postgres@localhost:5432/friendolls_dev?schem
# Redis
REDIS_HOST=localhost
REDIS_PORT=6379
REDIS_REQUIRED=false
REDIS_CONNECT_TIMEOUT_MS=5000
REDIS_STARTUP_RETRIES=10
# JWT Configuration
JWT_SECRET=replace-with-strong-random-secret
@@ -16,6 +19,11 @@ JWT_ISSUER=friendolls
JWT_AUDIENCE=friendolls-api
JWT_EXPIRES_IN_SECONDS=3600
# Auth cleanup
AUTH_CLEANUP_ENABLED=true
AUTH_CLEANUP_INTERVAL_MS=900000
AUTH_SESSION_REVOKED_RETENTION_DAYS=7
# Google OAuth
GOOGLE_CLIENT_ID="replace-with-google-client-id"
GOOGLE_CLIENT_SECRET="replace-with-google-client-secret"

28
.github/workflows/release.yml vendored Normal file
View File

@@ -0,0 +1,28 @@
name: Deploy Server
on:
push:
branches: [release]
jobs:
build-and-push:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- name: Log in to ghcr.io
uses: docker/login-action@v3
with:
registry: ghcr.io
username: ${{ github.actor }}
password: ${{ secrets.GITHUB_TOKEN }}
- name: Build and push
uses: docker/build-push-action@v5
with:
push: true
tags: ghcr.io/wind-explorer/friendolls-server:latest
- name: Trigger homelab deploy
run: |
curl -X POST https://wh.adamcv.com/hooks/friendolls-server-redeploy \
-H "Authorization: Bearer ${{ secrets.DEPLOY_WEBHOOK_SECRET }}"

16
Dockerfile Normal file
View File

@@ -0,0 +1,16 @@
FROM node:20-alpine AS builder
WORKDIR /app
COPY package.json pnpm-lock.yaml ./
RUN npm i -g pnpm && pnpm install --frozen-lockfile
COPY . .
RUN DATABASE_URL="postgresql://dummy:dummy@localhost:5432/dummy" pnpm prisma:generate
RUN pnpm build
FROM node:20-alpine
WORKDIR /app
COPY --from=builder /app/dist ./dist
COPY --from=builder /app/node_modules ./node_modules
COPY --from=builder /app/prisma ./prisma
COPY --from=builder /app/prisma.config.ts ./prisma.config.ts
COPY --from=builder /app/package.json ./package.json
CMD ["node", "dist/src/main.js"]

View File

@@ -1,6 +1,6 @@
{
"name": "friendolls-server",
"version": "0.0.1",
"version": "0.1.0",
"description": "",
"author": "",
"private": true,
@@ -49,6 +49,7 @@
"jsonwebtoken": "^9.0.2",
"passport": "^0.7.0",
"passport-discord": "^0.1.4",
"helmet": "^8.1.0",
"passport-google-oauth20": "^2.0.0",
"passport-jwt": "^4.0.1",
"pg": "^8.16.3",

9
pnpm-lock.yaml generated
View File

@@ -62,6 +62,9 @@ importers:
dotenv:
specifier: ^17.2.3
version: 17.2.3
helmet:
specifier: ^8.1.0
version: 8.1.0
ioredis:
specifier: ^5.8.2
version: 5.8.2
@@ -2298,6 +2301,10 @@ packages:
resolution: {integrity: sha512-0hJU9SCPvmMzIBdZFqNPXWa6dqh7WdH0cII9y+CyS8rG3nL48Bclra9HmKhVVUHyPWNH5Y7xDwAB7bfgSjkUMQ==}
engines: {node: '>= 0.4'}
helmet@8.1.0:
resolution: {integrity: sha512-jOiHyAZsmnr8LqoPGmCjYAaiuWwjAPLgY8ZX2XrmHawt99/u1y6RgrZMTeoPfpUbV96HOalYgz1qzkRbw54Pmg==}
engines: {node: '>=18.0.0'}
hono@4.7.10:
resolution: {integrity: sha512-QkACju9MiN59CKSY5JsGZCYmPZkA6sIW6OFCUp7qDjZu6S6KHtJHhAc9Uy9mV9F8PJ1/HQ3ybZF2yjCa/73fvQ==}
engines: {node: '>=16.9.0'}
@@ -6227,6 +6234,8 @@ snapshots:
dependencies:
function-bind: 1.1.2
helmet@8.1.0: {}
hono@4.7.10: {}
html-escaper@2.0.2: {}

View File

@@ -1,14 +1,14 @@
// This file was generated by Prisma and assumes you have installed the following:
// npm install --save-dev prisma dotenv
import "dotenv/config";
import { defineConfig, env } from "prisma/config";
import 'dotenv/config';
import { defineConfig, env } from 'prisma/config';
export default defineConfig({
schema: "prisma/schema.prisma",
schema: 'prisma/schema.prisma',
migrations: {
path: "prisma/migrations",
path: 'prisma/migrations',
},
datasource: {
url: env("DATABASE_URL"),
url: env('DATABASE_URL'),
},
});

View File

@@ -0,0 +1,2 @@
-- This migration was generated locally but superseded before it was applied.
-- It remains as a no-op to preserve Prisma migration history.

View File

@@ -0,0 +1,27 @@
UPDATE users
SET username = lower(split_part(email, '@', 1))
WHERE username IS NULL OR btrim(username) = '';
WITH ranked_users AS (
SELECT id,
username,
row_number() OVER (PARTITION BY username ORDER BY created_at, id) AS rn
FROM users
),
deduplicated AS (
SELECT id,
CASE
WHEN rn = 1 THEN username
ELSE left(username, greatest(1, 24 - char_length(rn::text))) || rn::text
END AS next_username
FROM ranked_users
)
UPDATE users u
SET username = d.next_username
FROM deduplicated d
WHERE u.id = d.id;
ALTER TABLE users
ALTER COLUMN username SET NOT NULL;
CREATE UNIQUE INDEX users_username_key ON users (username);

View File

@@ -1,8 +1,19 @@
DO $$
BEGIN
IF EXISTS (
SELECT 1
FROM (
SELECT LOWER(TRIM("email")) AS normalized_email
FROM "users"
GROUP BY LOWER(TRIM("email"))
HAVING COUNT(*) > 1
) duplicates
) THEN
RAISE EXCEPTION
'Cannot normalize user emails: duplicate values would conflict after lowercasing/trimming';
END IF;
END $$;
UPDATE "users"
SET "email" = LOWER(TRIM("email"))
WHERE "email" <> LOWER(TRIM("email"));
ALTER TABLE "users"
DROP CONSTRAINT IF EXISTS "users_email_key";
CREATE UNIQUE INDEX "users_email_key" ON "users"(LOWER("email"));

View File

@@ -24,8 +24,8 @@ model User {
/// User's email address
email String @unique
/// User's preferred username from Keycloak
username String?
/// User's preferred app handle used for discovery
username String @unique
/// URL to user's profile picture
picture String?
@@ -49,10 +49,10 @@ model User {
activeDollId String? @map("active_doll_id")
activeDoll Doll? @relation("ActiveDoll", fields: [activeDollId], references: [id])
sentFriendRequests FriendRequest[] @relation("SentFriendRequests")
receivedFriendRequests FriendRequest[] @relation("ReceivedFriendRequests")
userFriendships Friendship[] @relation("UserFriendships")
friendFriendships Friendship[] @relation("FriendFriendships")
sentFriendRequests FriendRequest[] @relation("SentFriendRequests")
receivedFriendRequests FriendRequest[] @relation("ReceivedFriendRequests")
userFriendships Friendship[] @relation("UserFriendships")
friendFriendships Friendship[] @relation("FriendFriendships")
dolls Doll[]
authIdentities AuthIdentity[]
authSessions AuthSession[]
@@ -62,17 +62,17 @@ model User {
}
model AuthIdentity {
id String @id @default(uuid())
provider AuthProvider
providerSubject String @map("provider_subject")
providerEmail String? @map("provider_email")
providerName String? @map("provider_name")
providerUsername String? @map("provider_username")
providerPicture String? @map("provider_picture")
emailVerified Boolean @default(false) @map("email_verified")
createdAt DateTime @default(now()) @map("created_at")
updatedAt DateTime @updatedAt @map("updated_at")
userId String @map("user_id")
id String @id @default(uuid())
provider AuthProvider
providerSubject String @map("provider_subject")
providerEmail String? @map("provider_email")
providerName String? @map("provider_name")
providerUsername String? @map("provider_username")
providerPicture String? @map("provider_picture")
emailVerified Boolean @default(false) @map("email_verified")
createdAt DateTime @default(now()) @map("created_at")
updatedAt DateTime @updatedAt @map("updated_at")
userId String @map("user_id")
user User @relation(fields: [userId], references: [id], onDelete: Cascade)
@@ -82,14 +82,14 @@ model AuthIdentity {
}
model AuthSession {
id String @id @default(uuid())
id String @id @default(uuid())
provider AuthProvider?
refreshTokenHash String @unique @map("refresh_token_hash")
expiresAt DateTime @map("expires_at")
revokedAt DateTime? @map("revoked_at")
createdAt DateTime @default(now()) @map("created_at")
updatedAt DateTime @updatedAt @map("updated_at")
userId String @map("user_id")
refreshTokenHash String @unique @map("refresh_token_hash")
expiresAt DateTime @map("expires_at")
revokedAt DateTime? @map("revoked_at")
createdAt DateTime @default(now()) @map("created_at")
updatedAt DateTime @updatedAt @map("updated_at")
userId String @map("user_id")
user User @relation(fields: [userId], references: [id], onDelete: Cascade)
@@ -98,13 +98,13 @@ model AuthSession {
}
model AuthExchangeCode {
id String @id @default(uuid())
provider AuthProvider
codeHash String @unique @map("code_hash")
expiresAt DateTime @map("expires_at")
consumedAt DateTime? @map("consumed_at")
createdAt DateTime @default(now()) @map("created_at")
userId String @map("user_id")
id String @id @default(uuid())
provider AuthProvider
codeHash String @unique @map("code_hash")
expiresAt DateTime @map("expires_at")
consumedAt DateTime? @map("consumed_at")
createdAt DateTime @default(now()) @map("created_at")
userId String @map("user_id")
user User @relation(fields: [userId], references: [id], onDelete: Cascade)

View File

@@ -1,7 +1,8 @@
import { Module } from '@nestjs/common';
import { ConfigModule, ConfigService } from '@nestjs/config';
import { APP_GUARD } from '@nestjs/core';
import { EventEmitterModule } from '@nestjs/event-emitter';
import { ThrottlerModule } from '@nestjs/throttler';
import { ThrottlerGuard, ThrottlerModule } from '@nestjs/throttler';
import { AppController } from './app.controller';
import { AppService } from './app.service';
import { UsersModule } from './users/users.module';
@@ -11,13 +12,24 @@ import { RedisModule } from './database/redis.module';
import { WsModule } from './ws/ws.module';
import { FriendsModule } from './friends/friends.module';
import { DollsModule } from './dolls/dolls.module';
import { parseRedisRequired } from './common/config/env.utils';
/**
* Validates required environment variables.
* Throws an error if any required variables are missing or invalid.
* Returns the validated config.
*/
function validateEnvironment(config: Record<string, any>): Record<string, any> {
function getOptionalEnvString(
config: Record<string, unknown>,
key: string,
): string | undefined {
const value = config[key];
return typeof value === 'string' ? value : undefined;
}
function validateEnvironment(
config: Record<string, unknown>,
): Record<string, unknown> {
const requiredVars = ['JWT_SECRET', 'DATABASE_URL'];
const missingVars = requiredVars.filter((varName) => !config[varName]);
@@ -29,10 +41,44 @@ function validateEnvironment(config: Record<string, any>): Record<string, any> {
}
// Validate PORT if provided
if (config.PORT && isNaN(Number(config.PORT))) {
if (config.PORT !== undefined && !Number.isFinite(Number(config.PORT))) {
throw new Error('PORT must be a valid number');
}
if (config.NODE_ENV === 'production') {
if (
typeof config.JWT_SECRET !== 'string' ||
config.JWT_SECRET.length < 32
) {
throw new Error(
'JWT_SECRET must be at least 32 characters in production',
);
}
}
const redisRequired = parseRedisRequired({
nodeEnv: getOptionalEnvString(config, 'NODE_ENV'),
redisRequired: getOptionalEnvString(config, 'REDIS_REQUIRED'),
});
if (redisRequired && !config.REDIS_HOST) {
throw new Error(
'REDIS_REQUIRED is enabled but REDIS_HOST is not configured',
);
}
const redisConnectTimeout = getOptionalEnvString(
config,
'REDIS_CONNECT_TIMEOUT_MS',
);
if (
redisConnectTimeout !== undefined &&
(!Number.isFinite(Number(redisConnectTimeout)) ||
Number(redisConnectTimeout) <= 0)
) {
throw new Error('REDIS_CONNECT_TIMEOUT_MS must be a positive number');
}
validateOptionalProvider(config, 'GOOGLE');
validateOptionalProvider(config, 'DISCORD');
@@ -40,7 +86,7 @@ function validateEnvironment(config: Record<string, any>): Record<string, any> {
}
function validateOptionalProvider(
config: Record<string, any>,
config: Record<string, unknown>,
provider: 'GOOGLE' | 'DISCORD',
): void {
const vars = [
@@ -76,8 +122,8 @@ function validateOptionalProvider(
inject: [ConfigService],
useFactory: (config: ConfigService) => [
{
ttl: config.get('THROTTLE_TTL', 60000),
limit: config.get('THROTTLE_LIMIT', 10),
ttl: config.get('THROTTLE_TTL', 1000),
limit: config.get('THROTTLE_LIMIT', 5),
},
],
}),
@@ -91,6 +137,12 @@ function validateOptionalProvider(
DollsModule,
],
controllers: [AppController],
providers: [AppService],
providers: [
AppService,
{
provide: APP_GUARD,
useClass: ThrottlerGuard,
},
],
})
export class AppModule {}

View File

@@ -1,8 +1,19 @@
import { Injectable } from '@nestjs/common';
import { readFileSync } from 'fs';
import { join } from 'path';
import { PrismaService } from './database/prisma.service';
const appVersion =
process.env.APP_VERSION ?? process.env.npm_package_version ?? 'unknown';
const appVersion = (() => {
if (process.env.APP_VERSION) return process.env.APP_VERSION;
try {
const pkg = JSON.parse(
readFileSync(join(__dirname, '../../package.json'), 'utf-8'),
) as { version: string };
return pkg.version;
} catch {
return 'unknown';
}
})();
export type DatabaseHealth = 'OK' | 'DOWN';

View File

@@ -10,6 +10,7 @@ import { UsersModule } from '../users/users.module';
import { AuthController } from './auth.controller';
import { GoogleAuthGuard } from './guards/google-auth.guard';
import { DiscordAuthGuard } from './guards/discord-auth.guard';
import { AuthCleanupService } from './services/auth-cleanup.service';
@Module({
imports: [
@@ -26,6 +27,7 @@ import { DiscordAuthGuard } from './guards/discord-auth.guard';
DiscordAuthGuard,
AuthService,
JwtVerificationService,
AuthCleanupService,
],
exports: [AuthService, PassportModule, JwtVerificationService],
})

View File

@@ -50,6 +50,7 @@ describe('AuthService', () => {
},
user: {
update: jest.fn(),
findFirst: jest.fn(),
findUnique: jest.fn(),
create: jest.fn(),
},
@@ -424,5 +425,89 @@ describe('AuthService', () => {
data: expect.objectContaining({ providerEmail: 'jane@example.com' }),
});
});
it('derives username from email local-part when provider username is missing', async () => {
const state = service.startSso(
'google',
'http://127.0.0.1:43123/callback',
).state;
const profileWithoutUsername: SocialAuthProfile = {
...socialProfile,
email: 'Alice@example.com',
username: undefined,
};
const txUserCreate = jest.fn().mockResolvedValue({ id: 'user-1' });
const txIdentityCreate = jest.fn().mockResolvedValue(undefined);
mockPrismaService.authIdentity.findUnique.mockResolvedValue(null);
mockPrismaService.user.findFirst = jest.fn().mockResolvedValue(null);
mockPrismaService.$transaction.mockImplementation((callback) =>
Promise.resolve(
callback({
user: {
findUnique: jest.fn().mockResolvedValue(null),
create: txUserCreate,
},
authIdentity: {
create: txIdentityCreate,
},
}),
),
);
mockPrismaService.authExchangeCode.create.mockResolvedValue({
id: 'code-1',
});
await service.completeSso('google', state, profileWithoutUsername);
expect(txUserCreate).toHaveBeenCalledWith({
data: expect.objectContaining({ username: 'alice' }),
});
expect(txIdentityCreate).toHaveBeenCalledWith({
data: expect.objectContaining({ providerUsername: 'alice' }),
});
});
it('adds a numeric suffix when derived username is already taken', async () => {
const state = service.startSso(
'google',
'http://127.0.0.1:43123/callback',
).state;
const profileWithoutUsername: SocialAuthProfile = {
...socialProfile,
email: 'Alice@example.com',
username: undefined,
};
const txUserCreate = jest.fn().mockResolvedValue({ id: 'user-1' });
const txIdentityCreate = jest.fn().mockResolvedValue(undefined);
mockPrismaService.authIdentity.findUnique.mockResolvedValue(null);
mockPrismaService.user.findFirst = jest
.fn()
.mockResolvedValueOnce({ id: 'existing-user' })
.mockResolvedValueOnce(null);
mockPrismaService.$transaction.mockImplementation((callback) =>
Promise.resolve(
callback({
user: {
findUnique: jest.fn().mockResolvedValue(null),
create: txUserCreate,
},
authIdentity: {
create: txIdentityCreate,
},
}),
),
);
mockPrismaService.authExchangeCode.create.mockResolvedValue({
id: 'code-1',
});
await service.completeSso('google', state, profileWithoutUsername);
expect(txUserCreate).toHaveBeenCalledWith({
data: expect.objectContaining({ username: 'alice2' }),
});
});
});
});

View File

@@ -29,8 +29,10 @@ import {
asProviderName,
isLoopbackRedirect,
normalizeEmail,
normalizeUsername,
randomOpaqueToken,
sha256,
usernameFromEmail,
} from './auth.utils';
import type { SsoProvider } from './dto/sso-provider';
@@ -220,6 +222,11 @@ export class AuthService {
const normalizedProviderEmail = profile.email
? normalizeEmail(profile.email)
: null;
const resolvedUsername = await this.resolveUsername(
profile.username,
normalizedProviderEmail,
existingIdentity.user.id,
);
await this.prisma.authIdentity.update({
where: { id: existingIdentity.id },
@@ -228,7 +235,7 @@ export class AuthService {
? { providerEmail: normalizedProviderEmail }
: {}),
providerName: profile.displayName,
providerUsername: profile.username,
providerUsername: resolvedUsername,
providerPicture: profile.picture,
emailVerified: profile.emailVerified,
},
@@ -241,7 +248,7 @@ export class AuthService {
? { email: normalizedProviderEmail }
: {}),
name: profile.displayName,
username: profile.username,
username: resolvedUsername,
picture: profile.picture,
lastLoginAt: now,
},
@@ -255,6 +262,10 @@ export class AuthService {
}
const email = normalizeEmail(profile.email);
const resolvedUsername = await this.resolveUsername(
profile.username,
email,
);
if (!profile.emailVerified) {
throw new BadRequestException(
@@ -277,7 +288,7 @@ export class AuthService {
data: {
email,
name: profile.displayName,
username: profile.username,
username: resolvedUsername,
picture: profile.picture,
roles: [],
lastLoginAt: now,
@@ -291,7 +302,7 @@ export class AuthService {
providerSubject: profile.providerSubject,
providerEmail: email,
providerName: profile.displayName,
providerUsername: profile.username,
providerUsername: resolvedUsername,
providerPicture: profile.picture,
emailVerified: profile.emailVerified,
userId: user.id,
@@ -302,6 +313,58 @@ export class AuthService {
});
}
private async resolveUsername(
providerUsername: string | undefined,
email: string | null,
excludeUserId?: string,
): Promise<string> {
const candidates = [
providerUsername ? normalizeUsername(providerUsername) : '',
email ? usernameFromEmail(email) : '',
'friendoll',
].filter(
(value, index, all) => value.length > 0 && all.indexOf(value) === index,
);
for (const base of candidates) {
const available = await this.isUsernameAvailable(base, excludeUserId);
if (available) {
return base;
}
for (let suffix = 2; suffix < 10_000; suffix += 1) {
const maxBaseLength = Math.max(1, 24 - suffix.toString().length);
const candidate = `${base.slice(0, maxBaseLength)}${suffix}`;
const available = await this.isUsernameAvailable(
candidate,
excludeUserId,
);
if (available) {
return candidate;
}
}
}
throw new ServiceUnavailableException('Unable to assign a unique username');
}
private async isUsernameAvailable(
username: string,
excludeUserId?: string,
): Promise<boolean> {
const existing = await this.prisma.user.findFirst({
where: {
username,
...(excludeUserId ? { id: { not: excludeUserId } } : {}),
},
select: {
id: true,
},
});
return !existing;
}
private async issueTokens(
userId: string,
email: string,

View File

@@ -28,3 +28,19 @@ export function asProviderName(value: SsoProvider): 'GOOGLE' | 'DISCORD' {
export function normalizeEmail(email: string): string {
return email.trim().toLowerCase();
}
export function normalizeUsername(value: string): string {
return value
.trim()
.toLowerCase()
.normalize('NFKD')
.replace(/[\u0300-\u036f]/g, '')
.replace(/[^a-z0-9]+/g, '_')
.replace(/^_+|_+$/g, '')
.slice(0, 24);
}
export function usernameFromEmail(email: string): string {
const localPart = normalizeEmail(email).split('@')[0] ?? '';
return normalizeUsername(localPart);
}

View File

@@ -0,0 +1,159 @@
import {
Injectable,
Inject,
Logger,
OnModuleDestroy,
OnModuleInit,
} from '@nestjs/common';
import { ConfigService } from '@nestjs/config';
import { PrismaService } from '../../database/prisma.service';
import Redis from 'ioredis';
import {
parseBoolean,
parsePositiveInteger,
} from '../../common/config/env.utils';
import { REDIS_CLIENT } from '../../database/redis.module';
const MIN_CLEANUP_INTERVAL_MS = 60_000;
const DEFAULT_CLEANUP_INTERVAL_MS = 15 * 60_000;
const DEFAULT_REVOKED_RETENTION_DAYS = 7;
const CLEANUP_LOCK_KEY = 'lock:auth:cleanup';
const CLEANUP_LOCK_TTL_MS = 55_000;
@Injectable()
export class AuthCleanupService implements OnModuleInit, OnModuleDestroy {
private readonly logger = new Logger(AuthCleanupService.name);
private cleanupTimer: NodeJS.Timeout | null = null;
private isCleanupRunning = false;
constructor(
private readonly prisma: PrismaService,
private readonly configService: ConfigService,
@Inject(REDIS_CLIENT) private readonly redisClient: Redis | null,
) {}
onModuleInit(): void {
const enabled = parseBoolean(
this.configService.get<string>('AUTH_CLEANUP_ENABLED'),
true,
);
if (!enabled) {
this.logger.log('Auth cleanup task disabled');
return;
}
const configuredInterval = parsePositiveInteger(
this.configService.get<string>('AUTH_CLEANUP_INTERVAL_MS'),
DEFAULT_CLEANUP_INTERVAL_MS,
);
const cleanupIntervalMs = Math.max(
configuredInterval,
MIN_CLEANUP_INTERVAL_MS,
);
this.cleanupTimer = setInterval(() => {
void this.cleanupExpiredAuthData();
}, cleanupIntervalMs);
this.cleanupTimer.unref();
void this.cleanupExpiredAuthData();
this.logger.log(`Auth cleanup task scheduled every ${cleanupIntervalMs}ms`);
}
onModuleDestroy(): void {
if (!this.cleanupTimer) {
return;
}
clearInterval(this.cleanupTimer);
this.cleanupTimer = null;
}
private async cleanupExpiredAuthData(): Promise<void> {
if (this.isCleanupRunning) {
this.logger.warn(
'Skipping auth cleanup run because previous run is still in progress',
);
return;
}
this.isCleanupRunning = true;
const lockToken = `${process.pid}-${Date.now()}-${Math.random().toString(36).slice(2)}`;
let lockAcquired = false;
try {
if (this.redisClient) {
try {
const lockResult = await this.redisClient.set(
CLEANUP_LOCK_KEY,
lockToken,
'PX',
CLEANUP_LOCK_TTL_MS,
'NX',
);
if (lockResult !== 'OK') {
return;
}
lockAcquired = true;
} catch (error) {
this.logger.warn(
'Failed to acquire auth cleanup lock; running cleanup without distributed lock',
error as Error,
);
}
}
const now = new Date();
const revokedRetentionDays = parsePositiveInteger(
this.configService.get<string>('AUTH_SESSION_REVOKED_RETENTION_DAYS'),
DEFAULT_REVOKED_RETENTION_DAYS,
);
const revokedCutoff = new Date(
now.getTime() - revokedRetentionDays * 24 * 60 * 60 * 1000,
);
const [codes, sessions] = await Promise.all([
this.prisma.authExchangeCode.deleteMany({
where: {
OR: [{ expiresAt: { lt: now } }, { consumedAt: { not: null } }],
},
}),
this.prisma.authSession.deleteMany({
where: {
OR: [
{ expiresAt: { lt: now } },
{ revokedAt: { lt: revokedCutoff } },
],
},
}),
]);
const totalDeleted = codes.count + sessions.count;
if (totalDeleted > 0) {
this.logger.log(
`Auth cleanup removed ${totalDeleted} records (${codes.count} exchange codes, ${sessions.count} sessions)`,
);
}
} catch (error) {
this.logger.error('Auth cleanup task failed', error as Error);
} finally {
if (lockAcquired && this.redisClient) {
try {
const currentLockValue = await this.redisClient.get(CLEANUP_LOCK_KEY);
if (currentLockValue === lockToken) {
await this.redisClient.del(CLEANUP_LOCK_KEY);
}
} catch (error) {
this.logger.warn(
'Failed to release auth cleanup lock',
error as Error,
);
}
}
this.isCleanupRunning = false;
}
}
}

View File

@@ -0,0 +1,66 @@
export function parseBoolean(
value: string | undefined,
fallback: boolean,
): boolean {
if (value === undefined) {
return fallback;
}
const normalized = value.trim().toLowerCase();
if (['true', '1', 'yes', 'y', 'on'].includes(normalized)) {
return true;
}
if (['false', '0', 'no', 'n', 'off'].includes(normalized)) {
return false;
}
return fallback;
}
export function parsePositiveInteger(
value: string | undefined,
fallback: number,
): number {
if (!value) {
return fallback;
}
const parsed = Number(value);
if (!Number.isFinite(parsed) || parsed <= 0) {
return fallback;
}
return Math.floor(parsed);
}
export function parseCsvList(value: string | undefined): string[] {
if (!value) {
return [];
}
return value
.split(',')
.map((item) => item.trim())
.filter((item) => item.length > 0);
}
export function isLikelyHttpOrigin(origin: string): boolean {
try {
const parsed = new URL(origin);
return parsed.protocol === 'http:' || parsed.protocol === 'https:';
} catch {
return false;
}
}
export function parseRedisRequired(config: {
nodeEnv?: string;
redisRequired?: string;
}): boolean {
if (config.redisRequired === undefined) {
return config.nodeEnv === 'production';
}
return parseBoolean(config.redisRequired, false);
}

View File

@@ -40,6 +40,7 @@ export class PrismaService
implements OnModuleInit, OnModuleDestroy
{
private readonly logger = new Logger(PrismaService.name);
private readonly pool: Pool;
constructor(private configService: ConfigService) {
const databaseUrl = configService.get<string>('DATABASE_URL');
@@ -62,6 +63,8 @@ export class PrismaService
],
});
this.pool = pool;
// Log database queries in development mode
if (process.env.NODE_ENV === 'development') {
this.$on('query' as never, (e: QueryEvent) => {
@@ -101,6 +104,7 @@ export class PrismaService
async onModuleDestroy() {
try {
await this.$disconnect();
await this.pool.end();
this.logger.log('Successfully disconnected from database');
} catch (error) {
this.logger.error('Error disconnecting from database', error);

View File

@@ -1,47 +1,132 @@
import { Module, Global, Logger } from '@nestjs/common';
import {
Inject,
Injectable,
Logger,
Module,
Global,
OnModuleDestroy,
} from '@nestjs/common';
import { ConfigService } from '@nestjs/config';
import Redis from 'ioredis';
import {
parsePositiveInteger,
parseRedisRequired,
} from '../common/config/env.utils';
export const REDIS_CLIENT = 'REDIS_CLIENT';
export const REDIS_SUBSCRIBER_CLIENT = 'REDIS_SUBSCRIBER_CLIENT';
const DEFAULT_REDIS_STARTUP_RETRIES = 10;
@Injectable()
class RedisLifecycleService implements OnModuleDestroy {
private readonly logger = new Logger(RedisLifecycleService.name);
constructor(
@Inject(REDIS_CLIENT) private readonly redisClient: Redis | null,
@Inject(REDIS_SUBSCRIBER_CLIENT)
private readonly redisSubscriber: Redis | null,
) {}
async onModuleDestroy(): Promise<void> {
const clients = [this.redisClient, this.redisSubscriber].filter(
(client): client is Redis => client !== null,
);
if (clients.length === 0) {
return;
}
await Promise.all(
clients.map(async (client) => {
try {
await client.quit();
} catch (error) {
this.logger.warn(
'Redis quit failed, forcing disconnect',
error as Error,
);
client.disconnect();
}
}),
);
}
}
@Global()
@Module({
providers: [
{
provide: REDIS_CLIENT,
useFactory: (configService: ConfigService) => {
useFactory: async (configService: ConfigService) => {
const logger = new Logger('RedisModule');
const host = configService.get<string>('REDIS_HOST');
const port = configService.get<number>('REDIS_PORT');
const port = parsePositiveInteger(
configService.get<string>('REDIS_PORT'),
6379,
);
const password = configService.get<string>('REDIS_PASSWORD');
const connectTimeout = parsePositiveInteger(
configService.get<string>('REDIS_CONNECT_TIMEOUT_MS'),
5000,
);
const redisRequired = parseRedisRequired({
nodeEnv: configService.get<string>('NODE_ENV'),
redisRequired: configService.get<string>('REDIS_REQUIRED'),
});
const startupRetries = parsePositiveInteger(
configService.get<string>('REDIS_STARTUP_RETRIES'),
DEFAULT_REDIS_STARTUP_RETRIES,
);
// Fallback or "disabled" mode if no host is provided
if (!host) {
logger.warn(
'REDIS_HOST not defined. Redis features will be disabled or fall back to local memory.',
);
if (redisRequired) {
throw new Error(
'REDIS_REQUIRED is enabled but REDIS_HOST is not configured',
);
}
logger.warn('REDIS_HOST not defined. Redis features are disabled.');
return null;
}
const client = new Redis({
host,
port: port || 6379,
password: password,
// Retry strategy: keep trying to reconnect
port,
password,
lazyConnect: true,
connectTimeout,
maxRetriesPerRequest: 1,
enableOfflineQueue: false,
retryStrategy(times) {
if (times > startupRetries) {
return null;
}
const delay = Math.min(times * 50, 2000);
return delay;
},
});
client.on('error', (err) => {
logger.error('Redis connection error', err);
client.on('connect', () => {
logger.log(`Connected to Redis at ${host}:${port}`);
});
client.on('connect', () => {
logger.log(`Connected to Redis at ${host}:${port || 6379}`);
});
try {
await client.connect();
await client.ping();
} catch {
client.disconnect();
if (redisRequired) {
throw new Error(
`Failed to connect to required Redis at ${host}:${port}`,
);
}
logger.warn('Redis connection failed; Redis features are disabled.');
return null;
}
return client;
},
@@ -49,11 +134,26 @@ export const REDIS_SUBSCRIBER_CLIENT = 'REDIS_SUBSCRIBER_CLIENT';
},
{
provide: REDIS_SUBSCRIBER_CLIENT,
useFactory: (configService: ConfigService) => {
useFactory: async (configService: ConfigService) => {
const logger = new Logger('RedisSubscriberModule');
const host = configService.get<string>('REDIS_HOST');
const port = configService.get<number>('REDIS_PORT');
const port = parsePositiveInteger(
configService.get<string>('REDIS_PORT'),
6379,
);
const password = configService.get<string>('REDIS_PASSWORD');
const connectTimeout = parsePositiveInteger(
configService.get<string>('REDIS_CONNECT_TIMEOUT_MS'),
5000,
);
const redisRequired = parseRedisRequired({
nodeEnv: configService.get<string>('NODE_ENV'),
redisRequired: configService.get<string>('REDIS_REQUIRED'),
});
const startupRetries = parsePositiveInteger(
configService.get<string>('REDIS_STARTUP_RETRIES'),
DEFAULT_REDIS_STARTUP_RETRIES,
);
if (!host) {
return null;
@@ -61,9 +161,17 @@ export const REDIS_SUBSCRIBER_CLIENT = 'REDIS_SUBSCRIBER_CLIENT';
const client = new Redis({
host,
port: port || 6379,
password: password,
port,
password,
lazyConnect: true,
connectTimeout,
maxRetriesPerRequest: 1,
enableOfflineQueue: false,
retryStrategy(times) {
if (times > startupRetries) {
return null;
}
const delay = Math.min(times * 50, 2000);
return delay;
},
@@ -82,10 +190,29 @@ export const REDIS_SUBSCRIBER_CLIENT = 'REDIS_SUBSCRIBER_CLIENT';
logger.error('Redis subscriber connection error', err);
});
try {
await client.connect();
await client.ping();
} catch {
client.disconnect();
if (redisRequired) {
throw new Error(
`Failed to connect to required Redis subscriber at ${host}:${port}`,
);
}
logger.warn(
'Redis subscriber connection failed; cross-instance subscriptions are disabled.',
);
return null;
}
return client;
},
inject: [ConfigService],
},
RedisLifecycleService,
],
exports: [REDIS_CLIENT, REDIS_SUBSCRIBER_CLIENT],
})

View File

@@ -1,5 +1,4 @@
import { Test, TestingModule } from '@nestjs/testing';
import { ThrottlerModule } from '@nestjs/throttler';
import { FriendsController } from './friends.controller';
import { FriendsService } from './friends.service';
import { UsersService } from '../users/users.service';
@@ -18,6 +17,7 @@ describe('FriendsController', () => {
userId: 'user-1',
email: 'user1@example.com',
roles: [],
tokenType: 'access' as const,
};
const mockUser1 = {
@@ -83,14 +83,6 @@ describe('FriendsController', () => {
beforeEach(async () => {
const module: TestingModule = await Test.createTestingModule({
imports: [
ThrottlerModule.forRoot([
{
ttl: 60000,
limit: 10,
},
]),
],
controllers: [FriendsController],
providers: [
{ provide: FriendsService, useValue: mockFriendsService },

View File

@@ -7,8 +7,8 @@ import {
Body,
Query,
HttpCode,
UseGuards,
Logger,
UseGuards,
} from '@nestjs/common';
import {
ApiTags,
@@ -19,7 +19,6 @@ import {
ApiUnauthorizedResponse,
ApiQuery,
} from '@nestjs/swagger';
import { ThrottlerGuard, Throttle } from '@nestjs/throttler';
import { User, FriendRequest, Prisma } from '@prisma/client';
import { FriendsService } from './friends.service';
import { JwtAuthGuard } from '../auth/guards/jwt-auth.guard';
@@ -62,8 +61,6 @@ export class FriendsController {
) {}
@Get('search')
@UseGuards(ThrottlerGuard)
@Throttle({ default: { limit: 10, ttl: 60000 } })
@ApiOperation({
summary: 'Search users by username',
description: 'Search for users by username to send friend requests',

View File

@@ -2,6 +2,7 @@ import { NestFactory } from '@nestjs/core';
import { ValidationPipe, Logger } from '@nestjs/common';
import { ConfigService } from '@nestjs/config';
import { DocumentBuilder, SwaggerModule } from '@nestjs/swagger';
import helmet from 'helmet';
import { AppModule } from './app.module';
import { AllExceptionsFilter } from './common/filters/all-exceptions.filter';
import { RedisIoAdapter } from './ws/redis-io.adapter';
@@ -10,12 +11,28 @@ async function bootstrap() {
const logger = new Logger('Bootstrap');
const app = await NestFactory.create(AppModule);
const configService = app.get(ConfigService);
const nodeEnv = configService.get<string>('NODE_ENV') || 'development';
const isProduction = nodeEnv === 'production';
app.enableShutdownHooks();
app.use(
helmet({
contentSecurityPolicy: false,
crossOriginEmbedderPolicy: false,
}),
);
// Configure Redis Adapter for horizontal scaling (if enabled)
const redisIoAdapter = new RedisIoAdapter(app, configService);
await redisIoAdapter.connectToRedis();
app.useWebSocketAdapter(redisIoAdapter);
app.enableCors({
origin: true,
credentials: true,
});
// Enable global exception filter for consistent error responses
app.useGlobalFilters(new AllExceptionsFilter());
@@ -29,43 +46,54 @@ async function bootstrap() {
// Automatically transform payloads to DTO instances
transform: true,
// Provide detailed error messages
disableErrorMessages: false,
disableErrorMessages: isProduction,
}),
);
// Configure Swagger documentation
const config = new DocumentBuilder()
.setTitle('Friendolls API')
.setDescription(
'API for managing users in Friendolls application.\n\n' +
'Authentication is handled via Passport.js social sign-in for desktop clients.\n' +
'Desktop clients exchange one-time SSO codes for Friendolls JWT tokens.\n\n' +
'Include the JWT token in the Authorization header as: `Bearer <token>`',
)
.setVersion('1.0')
.addBearerAuth(
{
type: 'http',
scheme: 'bearer',
bearerFormat: 'JWT',
name: 'Authorization',
description: 'Enter Friendolls JWT access token',
in: 'header',
},
'bearer',
)
.addTag('users', 'User profile management endpoints')
.build();
if (!isProduction) {
const config = new DocumentBuilder()
.setTitle('Friendolls API')
.setDescription(
'API for managing users in Friendolls application.\n\n' +
'Authentication is handled via Passport.js social sign-in for desktop clients.\n' +
'Desktop clients exchange one-time SSO codes for Friendolls JWT tokens.\n\n' +
'Include the JWT token in the Authorization header as: `Bearer <token>`',
)
.setVersion('1.0')
.addBearerAuth(
{
type: 'http',
scheme: 'bearer',
bearerFormat: 'JWT',
name: 'Authorization',
description: 'Enter Friendolls JWT access token',
in: 'header',
},
'bearer',
)
.addTag('users', 'User profile management endpoints')
.build();
const document = SwaggerModule.createDocument(app, config);
SwaggerModule.setup('api', app, document);
const document = SwaggerModule.createDocument(app, config);
SwaggerModule.setup('api', app, document);
}
const host = process.env.HOST ?? 'localhost';
const port = process.env.PORT ?? 3000;
await app.listen(port);
const httpServer = app.getHttpServer() as {
once?: (event: 'close', listener: () => void) => void;
} | null;
httpServer?.once?.('close', () => {
void redisIoAdapter.close();
});
logger.log(`Application is running on: http://${host}:${port}`);
logger.log(`Swagger documentation available at: http://${host}:${port}/api`);
if (!isProduction) {
logger.log(
`Swagger documentation available at: http://${host}:${port}/api`,
);
}
}
void bootstrap();

View File

@@ -11,5 +11,7 @@ export type AuthenticatedSocket = BaseSocket<
userId?: string;
activeDollId?: string | null;
friends?: Set<string>; // Set of friend user IDs
senderName?: string;
senderNameCachedAt?: number;
}
>;

View File

@@ -22,6 +22,7 @@ describe('UsersController', () => {
const mockAuthUser: AuthenticatedUser = {
userId: 'uuid-123',
email: 'test@example.com',
tokenType: 'access',
roles: ['user'],
};

View File

@@ -4,10 +4,18 @@ import { createAdapter } from '@socket.io/redis-adapter';
import Redis from 'ioredis';
import { ConfigService } from '@nestjs/config';
import { INestApplicationContext, Logger } from '@nestjs/common';
import {
parsePositiveInteger,
parseRedisRequired,
} from '../common/config/env.utils';
const DEFAULT_REDIS_STARTUP_RETRIES = 10;
export class RedisIoAdapter extends IoAdapter {
private adapterConstructor: ReturnType<typeof createAdapter>;
private readonly logger = new Logger(RedisIoAdapter.name);
private pubClient: Redis | null = null;
private subClient: Redis | null = null;
constructor(
private app: INestApplicationContext,
@@ -18,41 +26,63 @@ export class RedisIoAdapter extends IoAdapter {
async connectToRedis(): Promise<void> {
const host = this.configService.get<string>('REDIS_HOST');
const port = this.configService.get<number>('REDIS_PORT');
const port = parsePositiveInteger(
this.configService.get<string>('REDIS_PORT'),
6379,
);
const password = this.configService.get<string>('REDIS_PASSWORD');
const startupRetries = parsePositiveInteger(
this.configService.get<string>('REDIS_STARTUP_RETRIES'),
DEFAULT_REDIS_STARTUP_RETRIES,
);
const redisRequired = parseRedisRequired({
nodeEnv: this.configService.get<string>('NODE_ENV'),
redisRequired: this.configService.get<string>('REDIS_REQUIRED'),
});
// Only set up Redis adapter if host is configured
if (!host) {
if (redisRequired) {
throw new Error(
'REDIS_REQUIRED is enabled but REDIS_HOST is not configured',
);
}
this.logger.log('Redis adapter disabled (REDIS_HOST not set)');
return;
}
this.logger.log(`Connecting Redis adapter to ${host}:${port || 6379}`);
this.logger.log(`Connecting Redis adapter to ${host}:${port}`);
try {
const connectTimeout = parsePositiveInteger(
this.configService.get<string>('REDIS_CONNECT_TIMEOUT_MS'),
5000,
);
const pubClient = new Redis({
host,
port: port || 6379,
password: password,
port,
password,
lazyConnect: true,
connectTimeout,
maxRetriesPerRequest: 1,
enableOfflineQueue: false,
retryStrategy(times) {
// Retry connecting but don't crash if Redis is temporarily down during startup
if (times > startupRetries) {
return null;
}
return Math.min(times * 50, 2000);
},
});
const subClient = pubClient.duplicate();
// Wait for connection to ensure it's valid
await new Promise<void>((resolve, reject) => {
pubClient.once('connect', () => {
this.logger.log('Redis Pub client connected');
resolve();
});
pubClient.once('error', (err) => {
this.logger.error('Redis Pub client error', err);
reject(err);
});
});
await pubClient.connect();
await subClient.connect();
await pubClient.ping();
await subClient.ping();
this.logger.log('Redis Pub/Sub clients connected');
// Handle subsequent errors gracefully
pubClient.on('error', (err) => {
@@ -73,21 +103,53 @@ export class RedisIoAdapter extends IoAdapter {
});
this.adapterConstructor = createAdapter(pubClient, subClient);
this.pubClient = pubClient;
this.subClient = subClient;
this.logger.log('Redis adapter initialized successfully');
} catch (error) {
await this.close();
this.logger.error('Failed to initialize Redis adapter', error);
// We don't throw here to allow the app to start without Redis if connection fails,
// though functionality will be degraded if multiple instances are running.
if (redisRequired) {
throw error;
}
}
}
createIOServer(port: number, options?: ServerOptions): any {
const cors = {
origin: true,
credentials: true,
};
// eslint-disable-next-line @typescript-eslint/no-unsafe-assignment
const server = super.createIOServer(port, options);
const server = super.createIOServer(port, {
...(options ?? {}),
cors,
});
if (this.adapterConstructor) {
// eslint-disable-next-line @typescript-eslint/no-unsafe-call, @typescript-eslint/no-unsafe-member-access
server.adapter(this.adapterConstructor);
}
return server;
}
async close(): Promise<void> {
const clients = [this.pubClient, this.subClient].filter(
(client): client is Redis => client !== null,
);
await Promise.all(
clients.map(async (client) => {
try {
await client.quit();
} catch {
client.disconnect();
}
}),
);
this.pubClient = null;
this.subClient = null;
}
}

View File

@@ -6,13 +6,11 @@ import { PrismaService } from '../../../database/prisma.service';
import { UserSocketService } from '../user-socket.service';
import { WsNotificationService } from '../ws-notification.service';
import { WS_EVENT } from '../ws-events';
import { UsersService } from '../../../users/users.service';
export class ConnectionHandler {
constructor(
private readonly jwtVerificationService: JwtVerificationService,
private readonly prisma: PrismaService,
private readonly usersService: UsersService,
private readonly userSocketService: UserSocketService,
private readonly wsNotificationService: WsNotificationService,
private readonly logger: Logger,
@@ -94,42 +92,42 @@ export class ConnectionHandler {
this.logger.log(
`WebSocket authenticated via initialize fallback (Pending Init): ${payload.sub}`,
);
this.logger.log(
`WebSocket authenticated via initialize fallback (Pending Init): ${payload.sub}`,
);
}
if (!userTokenData) {
throw new WsException('Unauthorized: No user data found');
}
const user = await this.usersService.findOne(userTokenData.userId);
// 2. Register socket mapping (Redis Write)
await this.userSocketService.setSocket(user.id, client.id);
client.data.userId = user.id;
// 3. Fetch initial state (DB Read)
const [userWithDoll, friends] = await Promise.all([
// 2. Fetch initial state (DB Read)
const [userState, friends] = await Promise.all([
this.prisma.user.findUnique({
where: { id: user.id },
select: { activeDollId: true },
where: { id: userTokenData.userId },
select: { id: true, name: true, username: true, activeDollId: true },
}),
this.prisma.friendship.findMany({
where: { userId: user.id },
where: { userId: userTokenData.userId },
select: { friendId: true },
}),
]);
client.data.activeDollId = userWithDoll?.activeDollId || null;
client.data.friends = new Set(friends.map((f) => f.friendId));
if (!userState) {
throw new WsException('Unauthorized: No user data found');
}
this.logger.log(`Client initialized: ${user.id} (${client.id})`);
// 3. Register socket mapping (Redis Write)
await this.userSocketService.setSocket(userState.id, client.id);
client.data.userId = userState.id;
client.data.activeDollId = userState.activeDollId || null;
client.data.friends = new Set(friends.map((f) => f.friendId));
client.data.senderName = userState.name || userState.username;
client.data.senderNameCachedAt = Date.now();
this.logger.log(`Client initialized: ${userState.id} (${client.id})`);
// 4. Notify client
client.emit(WS_EVENT.INITIALIZED, {
userId: user.id,
userId: userState.id,
activeDollId: client.data.activeDollId,
});
} catch (error) {
@@ -157,7 +155,9 @@ export class ConnectionHandler {
// Notify friends that this user has disconnected
const friends = client.data.friends;
if (friends) {
const friendIds = Array.from(friends);
const friendIds = Array.from(friends).filter(
(friendId): friendId is string => typeof friendId === 'string',
);
const friendSockets =
await this.userSocketService.getFriendsSockets(friendIds);
@@ -179,9 +179,5 @@ export class ConnectionHandler {
this.logger.log(
`Client id: ${client.id} disconnected (user: ${user?.userId || 'unknown'})`,
);
this.logger.log(
`Client id: ${client.id} disconnected (user: ${user?.userId || 'unknown'})`,
);
}
}

View File

@@ -9,6 +9,8 @@ import { WsNotificationService } from '../ws-notification.service';
import { WS_EVENT } from '../ws-events';
import { Validator } from '../utils/validation';
const SENDER_NAME_CACHE_TTL_MS = 10 * 60 * 1000;
export class InteractionHandler {
private readonly logger = new Logger(InteractionHandler.name);
@@ -18,6 +20,32 @@ export class InteractionHandler {
private readonly wsNotificationService: WsNotificationService,
) {}
private async resolveSenderName(
client: AuthenticatedSocket,
userId: string,
): Promise<string> {
const cachedName = client.data.senderName;
const cachedAt = client.data.senderNameCachedAt;
const cacheIsFresh =
cachedName &&
typeof cachedAt === 'number' &&
Date.now() - cachedAt < SENDER_NAME_CACHE_TTL_MS;
if (cacheIsFresh) {
return cachedName;
}
const sender = await this.prisma.user.findUnique({
where: { id: userId },
select: { name: true, username: true },
});
const senderName = sender?.name || sender?.username || 'Unknown';
client.data.senderName = senderName;
client.data.senderNameCachedAt = Date.now();
return senderName;
}
async handleSendInteraction(
client: AuthenticatedSocket,
data: SendInteractionDto,
@@ -39,7 +67,16 @@ export class InteractionHandler {
return;
}
// 2. Check if recipient is online
// 2. Validate text content length
if (data.type === 'text' && data.content && data.content.length > 50) {
client.emit(WS_EVENT.INTERACTION_DELIVERY_FAILED, {
recipientUserId: data.recipientUserId,
reason: 'Text content exceeds 50 characters',
});
return;
}
// 3. Check if recipient is online
const isOnline = await this.userSocketService.isUserOnline(
data.recipientUserId,
);
@@ -52,11 +89,7 @@ export class InteractionHandler {
}
// 3. Construct payload
const sender = await this.prisma.user.findUnique({
where: { id: currentUserId },
select: { name: true, username: true },
});
const senderName = sender?.name || sender?.username || 'Unknown';
const senderName = await this.resolveSenderName(client, currentUserId);
const payload: InteractionPayloadDto = {
senderUserId: currentUserId,

View File

@@ -3,7 +3,6 @@ import { Test, TestingModule } from '@nestjs/testing';
import { StateGateway } from './state.gateway';
import { AuthenticatedSocket } from '../../types/socket';
import { JwtVerificationService } from '../../auth/services/jwt-verification.service';
import { UsersService } from '../../users/users.service';
import { PrismaService } from '../../database/prisma.service';
import { UserSocketService } from './user-socket.service';
import { WsNotificationService } from './ws-notification.service';
@@ -23,6 +22,8 @@ type MockSocket = {
userId?: string;
activeDollId?: string | null;
friends?: Set<string>;
senderName?: string;
senderNameCachedAt?: number;
};
handshake?: any;
disconnect?: jest.Mock;
@@ -39,7 +40,6 @@ describe('StateGateway', () => {
sockets: { sockets: { size: number; get: jest.Mock } };
to: jest.Mock;
};
let mockUsersService: Partial<UsersService>;
let mockJwtVerificationService: Partial<JwtVerificationService>;
let mockPrismaService: Partial<PrismaService>;
let mockUserSocketService: Partial<UserSocketService>;
@@ -67,12 +67,6 @@ describe('StateGateway', () => {
}),
};
mockUsersService = {
findOne: jest.fn().mockResolvedValue({
id: 'user-id',
}),
};
mockJwtVerificationService = {
extractToken: jest.fn((handshake) => handshake.auth?.token),
verifyToken: jest.fn().mockReturnValue({
@@ -83,7 +77,12 @@ describe('StateGateway', () => {
mockPrismaService = {
user: {
findUnique: jest.fn().mockResolvedValue({ activeDollId: 'doll-123' }),
findUnique: jest.fn().mockResolvedValue({
id: 'user-id',
name: 'Test User',
username: 'test-user',
activeDollId: 'doll-123',
}),
} as any,
friendship: {
findMany: jest.fn().mockResolvedValue([]),
@@ -119,7 +118,6 @@ describe('StateGateway', () => {
const module: TestingModule = await Test.createTestingModule({
providers: [
StateGateway,
{ provide: UsersService, useValue: mockUsersService },
{
provide: JwtVerificationService,
useValue: mockJwtVerificationService,
@@ -190,7 +188,6 @@ describe('StateGateway', () => {
);
// Should NOT call these anymore in handleConnection
expect(mockUsersService.findOne).not.toHaveBeenCalled();
expect(mockUserSocketService.setSocket).not.toHaveBeenCalled();
// Should set data on client
@@ -244,6 +241,9 @@ describe('StateGateway', () => {
// Mock Prisma responses
(mockPrismaService.user!.findUnique as jest.Mock).mockResolvedValue({
id: 'user-id',
name: 'Test User',
username: 'test-user',
activeDollId: 'doll-123',
});
(mockPrismaService.friendship!.findMany as jest.Mock).mockResolvedValue([
@@ -255,32 +255,29 @@ describe('StateGateway', () => {
mockClient as unknown as AuthenticatedSocket,
);
// 1. Load User
expect(mockUsersService.findOne).toHaveBeenCalledWith('test-sub');
// 2. Set Socket
// 1. Set Socket
expect(mockUserSocketService.setSocket).toHaveBeenCalledWith(
'user-id',
'client1',
);
// 3. Fetch State (DB)
// 2. Fetch State (DB)
expect(mockPrismaService.user!.findUnique).toHaveBeenCalledWith({
where: { id: 'user-id' },
select: { activeDollId: true },
where: { id: 'test-sub' },
select: { id: true, name: true, username: true, activeDollId: true },
});
expect(mockPrismaService.friendship!.findMany).toHaveBeenCalledWith({
where: { userId: 'user-id' },
where: { userId: 'test-sub' },
select: { friendId: true },
});
// 4. Update Client Data
// 3. Update Client Data
expect(mockClient.data.userId).toBe('user-id');
expect(mockClient.data.activeDollId).toBe('doll-123');
expect(mockClient.data.friends).toContain('friend-1');
expect(mockClient.data.friends).toContain('friend-2');
// 5. Emit Initialized
// 4. Emit Initialized
expect(mockClient.emit).toHaveBeenCalledWith('initialized', {
userId: 'user-id',
activeDollId: 'doll-123',

View File

@@ -1,4 +1,4 @@
import { Logger, Inject } from '@nestjs/common';
import { Logger, Inject, OnModuleDestroy } from '@nestjs/common';
import {
OnGatewayConnection,
OnGatewayDisconnect,
@@ -22,7 +22,6 @@ import { PrismaService } from '../../database/prisma.service';
import { UserSocketService } from './user-socket.service';
import { WsNotificationService } from './ws-notification.service';
import { WS_EVENT, REDIS_CHANNEL } from './ws-events';
import { UsersService } from '../../users/users.service';
import { ConnectionHandler } from './connection/handler';
import { CursorHandler } from './cursor/handler';
import { StatusHandler } from './status/handler';
@@ -31,14 +30,13 @@ import { RedisHandler } from './utils/redis-handler';
import { Broadcaster } from './utils/broadcasting';
import { Throttler } from './utils/throttling';
@WebSocketGateway({
cors: {
origin: true,
credentials: true,
},
})
@WebSocketGateway()
export class StateGateway
implements OnGatewayInit, OnGatewayConnection, OnGatewayDisconnect
implements
OnGatewayInit,
OnGatewayConnection,
OnGatewayDisconnect,
OnModuleDestroy
{
private readonly logger = new Logger(StateGateway.name);
@@ -55,7 +53,6 @@ export class StateGateway
constructor(
private readonly jwtVerificationService: JwtVerificationService,
private readonly prisma: PrismaService,
private readonly usersService: UsersService,
private readonly userSocketService: UserSocketService,
private readonly wsNotificationService: WsNotificationService,
@Inject(REDIS_CLIENT) private readonly redisClient: Redis | null,
@@ -70,7 +67,6 @@ export class StateGateway
this.connectionHandler = new ConnectionHandler(
this.jwtVerificationService,
this.prisma,
this.usersService,
this.userSocketService,
this.wsNotificationService,
this.logger,
@@ -163,4 +159,10 @@ export class StateGateway
) {
await this.interactionHandler.handleSendInteraction(client, data);
}
onModuleDestroy() {
if (this.redisSubscriber) {
this.redisSubscriber.removeAllListeners('message');
}
}
}

View File

@@ -48,13 +48,27 @@ export class WsNotificationService {
action: 'add' | 'delete',
) {
if (this.redisClient) {
await this.redisClient.publish(
REDIS_CHANNEL.FRIEND_CACHE_UPDATE,
JSON.stringify({ userId, friendId, action }),
);
} else {
// Fallback: update locally
try {
await this.redisClient.publish(
REDIS_CHANNEL.FRIEND_CACHE_UPDATE,
JSON.stringify({ userId, friendId, action }),
);
return;
} catch (error) {
this.logger.warn(
'Redis publish failed for friend cache update; applying local cache update only',
error as Error,
);
}
}
try {
await this.updateFriendsCacheLocal(userId, friendId, action);
} catch (error) {
this.logger.error(
'Failed to apply local friend cache update',
error as Error,
);
}
}
@@ -89,13 +103,27 @@ export class WsNotificationService {
async publishActiveDollUpdate(userId: string, dollId: string | null) {
if (this.redisClient) {
await this.redisClient.publish(
REDIS_CHANNEL.ACTIVE_DOLL_UPDATE,
JSON.stringify({ userId, dollId }),
);
} else {
// Fallback: update locally
try {
await this.redisClient.publish(
REDIS_CHANNEL.ACTIVE_DOLL_UPDATE,
JSON.stringify({ userId, dollId }),
);
return;
} catch (error) {
this.logger.warn(
'Redis publish failed for active doll update; applying local cache update only',
error as Error,
);
}
}
try {
await this.updateActiveDollCache(userId, dollId);
} catch (error) {
this.logger.error(
'Failed to apply local active doll cache update',
error as Error,
);
}
}
}