This commit is contained in:
2025-12-04 23:43:41 +08:00
parent f5c573c52f
commit f04ffea612
12 changed files with 435 additions and 72 deletions

View File

@@ -1,12 +1,28 @@
import { Test, TestingModule } from '@nestjs/testing';
import { StateGateway } from './state.gateway';
import { Socket } from 'socket.io';
import { AuthenticatedSocket } from '../../types/socket';
import { AuthService } from '../../auth/auth.service';
import { JwtVerificationService } from '../../auth/services/jwt-verification.service';
interface MockSocket extends Partial<AuthenticatedSocket> {
id: string;
data: {
user?: {
keycloakSub: string;
};
};
handshake?: any;
disconnect?: jest.Mock;
}
describe('StateGateway', () => {
let gateway: StateGateway;
let mockLoggerLog: jest.SpyInstance;
let mockLoggerDebug: jest.SpyInstance;
let mockServer: any;
let mockLoggerWarn: jest.SpyInstance;
let mockServer: { sockets: { sockets: { size: number } } };
let mockAuthService: Partial<AuthService>;
let mockJwtVerificationService: Partial<JwtVerificationService>;
beforeEach(async () => {
mockServer = {
@@ -17,23 +33,40 @@ describe('StateGateway', () => {
},
};
mockAuthService = {
syncUserFromToken: jest.fn().mockResolvedValue({
id: 'user-id',
keycloakSub: 'test-sub',
}),
};
mockJwtVerificationService = {
extractToken: jest.fn((handshake) => handshake.auth?.token),
verifyToken: jest.fn().mockResolvedValue({
sub: 'test-sub',
email: 'test@example.com',
}),
};
const module: TestingModule = await Test.createTestingModule({
providers: [StateGateway],
providers: [
StateGateway,
{ provide: AuthService, useValue: mockAuthService },
{
provide: JwtVerificationService,
useValue: mockJwtVerificationService,
},
],
}).compile();
gateway = module.get<StateGateway>(StateGateway);
gateway.io = mockServer;
gateway.io = mockServer as any;
// Spy on logger methods
mockLoggerLog = jest
// gateway is private in `state.gateway.ts` so we have to
// eslint-disable-next-line @typescript-eslint/no-unsafe-member-access
.spyOn((gateway as any).logger, 'log')
.mockImplementation();
mockLoggerLog = jest.spyOn(gateway['logger'], 'log').mockImplementation();
mockLoggerDebug = jest
// eslint-disable-next-line @typescript-eslint/no-unsafe-member-access
.spyOn((gateway as any).logger, 'debug')
.spyOn(gateway['logger'], 'debug')
.mockImplementation();
mockLoggerWarn = jest.spyOn(gateway['logger'], 'warn').mockImplementation();
});
afterEach(() => {
@@ -53,45 +86,127 @@ describe('StateGateway', () => {
});
describe('handleConnection', () => {
it('should log client connection and number of connected clients', () => {
const mockClient = { id: 'client1' } as Socket;
it('should log client connection and sync user when authenticated', async () => {
const mockClient: MockSocket = {
id: 'client1',
data: { user: { keycloakSub: 'test-sub' } },
handshake: {
auth: { token: 'mock-token' },
headers: {},
},
};
gateway.handleConnection(mockClient);
await gateway.handleConnection(
mockClient as unknown as AuthenticatedSocket,
);
expect(mockJwtVerificationService.extractToken).toHaveBeenCalledWith(
mockClient.handshake,
);
expect(mockJwtVerificationService.verifyToken).toHaveBeenCalledWith(
'mock-token',
);
expect(mockAuthService.syncUserFromToken).toHaveBeenCalledWith(
expect.objectContaining({
keycloakSub: 'test-sub',
}),
);
expect(mockLoggerLog).toHaveBeenCalledWith(
`Client id: ${mockClient.id} connected`,
`Client id: ${mockClient.id} connected (user: test-sub)`,
);
expect(mockLoggerDebug).toHaveBeenCalledWith(
'Number of connected clients: 5',
);
});
it('should disconnect client when no token provided', async () => {
const mockClient: MockSocket = {
id: 'client1',
data: {},
handshake: {
auth: {},
headers: {},
},
disconnect: jest.fn(),
};
(mockJwtVerificationService.extractToken as jest.Mock).mockReturnValue(
undefined,
);
await gateway.handleConnection(
mockClient as unknown as AuthenticatedSocket,
);
expect(mockLoggerWarn).toHaveBeenCalledWith(
'WebSocket connection attempt without token',
);
expect(mockClient.disconnect).toHaveBeenCalled();
});
});
describe('handleDisconnect', () => {
it('should log client disconnection', () => {
const mockClient = { id: 'client1' } as Socket;
const mockClient: MockSocket = {
id: 'client1',
data: { user: { keycloakSub: 'test-sub' } },
};
gateway.handleDisconnect(mockClient);
gateway.handleDisconnect(mockClient as unknown as AuthenticatedSocket);
expect(mockLoggerLog).toHaveBeenCalledWith(
`Cliend id:${mockClient.id} disconnected`,
`Client id: ${mockClient.id} disconnected (user: test-sub)`,
);
});
it('should handle disconnection when no user data', () => {
const mockClient: MockSocket = {
id: 'client1',
data: {},
};
gateway.handleDisconnect(mockClient as unknown as AuthenticatedSocket);
expect(mockLoggerLog).toHaveBeenCalledWith(
`Client id: ${mockClient.id} disconnected (user: unknown)`,
);
});
});
describe('handleCursorReportPosition', () => {
it('should log message received from client', () => {
const mockClient = { id: 'client1' } as Socket;
it('should log message received from authenticated client', () => {
const mockClient: MockSocket = {
id: 'client1',
data: { user: { keycloakSub: 'test-sub' } },
};
const data = { x: 100, y: 200 };
gateway.handleCursorReportPosition(mockClient, data);
gateway.handleCursorReportPosition(
mockClient as unknown as AuthenticatedSocket,
data,
);
expect(mockLoggerLog).toHaveBeenCalledWith(
`Message received from client id: ${mockClient.id}`,
`Message received from client id: ${mockClient.id} (user: test-sub)`,
);
expect(mockLoggerDebug).toHaveBeenCalledWith(
`Payload: ${JSON.stringify(data, null, 0)}`,
);
});
it('should throw exception when client is not authenticated', () => {
const mockClient: MockSocket = {
id: 'client1',
data: {},
};
const data = { x: 100, y: 200 };
expect(() => {
gateway.handleCursorReportPosition(
mockClient as unknown as AuthenticatedSocket,
data,
);
}).toThrow('Unauthorized');
});
});
});

View File

@@ -6,11 +6,25 @@ import {
SubscribeMessage,
WebSocketGateway,
WebSocketServer,
WsException,
} from '@nestjs/websockets';
import { Server, Socket } from 'socket.io';
import type { Server } from 'socket.io';
import type { AuthenticatedSocket } from '../../types/socket';
import { AuthService } from '../../auth/auth.service';
import { JwtVerificationService } from '../../auth/services/jwt-verification.service';
import { CursorPositionDto } from '../dto/cursor-position.dto';
@WebSocketGateway()
const WS_EVENT = {
CURSOR_REPORT_POSITION: 'cursor-report-position',
} as const;
@WebSocketGateway({
cors: {
origin: true,
credentials: true,
},
})
export class StateGateway
implements OnGatewayInit, OnGatewayConnection, OnGatewayDisconnect
{
@@ -18,24 +32,84 @@ export class StateGateway
@WebSocketServer() io: Server;
constructor(
private readonly authService: AuthService,
private readonly jwtVerificationService: JwtVerificationService,
) {}
afterInit() {
this.logger.log('Initialized');
}
handleConnection(client: Socket) {
const { sockets } = this.io.sockets;
async handleConnection(client: AuthenticatedSocket) {
try {
this.logger.debug(
`Connection attempt - handshake auth: ${JSON.stringify(client.handshake.auth)}`,
);
this.logger.debug(
`Connection attempt - handshake headers: ${JSON.stringify(client.handshake.headers)}`,
);
this.logger.log(`Client id: ${client.id} connected`);
this.logger.debug(`Number of connected clients: ${sockets.size}`);
const token = this.jwtVerificationService.extractToken(client.handshake);
if (!token) {
this.logger.warn('WebSocket connection attempt without token');
client.disconnect();
return;
}
const payload = await this.jwtVerificationService.verifyToken(token);
if (!payload.sub) {
throw new WsException('Invalid token: missing subject');
}
client.data.user = {
keycloakSub: payload.sub,
email: payload.email,
name: payload.name,
username: payload.preferred_username,
picture: payload.picture,
};
this.logger.log(`WebSocket authenticated: ${payload.sub}`);
await this.authService.syncUserFromToken(client.data.user);
const { sockets } = this.io.sockets;
this.logger.log(
`Client id: ${client.id} connected (user: ${payload.sub})`,
);
this.logger.debug(`Number of connected clients: ${sockets.size}`);
} catch (error: unknown) {
const errorMessage =
error instanceof Error ? error.message : 'Unknown error';
this.logger.error(`Connection error: ${errorMessage}`);
client.disconnect();
}
}
handleDisconnect(client: Socket) {
this.logger.log(`Cliend id:${client.id} disconnected`);
handleDisconnect(client: AuthenticatedSocket) {
const user = client.data.user;
this.logger.log(
`Client id: ${client.id} disconnected (user: ${user?.keycloakSub || 'unknown'})`,
);
}
@SubscribeMessage('cursor-report-position')
handleCursorReportPosition(client: Socket, data: any) {
this.logger.log(`Message received from client id: ${client.id}`);
@SubscribeMessage(WS_EVENT.CURSOR_REPORT_POSITION)
handleCursorReportPosition(
client: AuthenticatedSocket,
data: CursorPositionDto,
) {
const user = client.data.user;
if (!user) {
throw new WsException('Unauthorized');
}
this.logger.log(
`Message received from client id: ${client.id} (user: ${user.keycloakSub})`,
);
this.logger.debug(`Payload: ${JSON.stringify(data, null, 0)}`);
}
}