blob: 310dbd5b6de0e71da4db18904d19fb7417417377 [file] [log] [blame]
/*
 * Copyright (C) 2022 Savoir-faire Linux Inc.
 *
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU Affero General Public License as
 * published by the Free Software Foundation; either version 3 of the
 * License, or (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU Affero General Public License for more details.
 *
 * You should have received a copy of the GNU Affero General Public
 * License along with this program. If not, see
 * <https://www.gnu.org/licenses/>.
 */
18import { IncomingMessage } from 'node:http';
19import { Duplex } from 'node:stream';
20
Misha Krieger-Raynauld20cf1c82022-11-23 20:26:50 -050021import { WebSocketMessage, WebSocketMessageTable, WebSocketMessageType } from 'jami-web-common';
Misha Krieger-Raynauldb933fbb2022-11-15 15:11:09 -050022import log from 'loglevel';
23import { Service } from 'typedi';
24import { URL } from 'whatwg-url';
25import * as WebSocket from 'ws';
26
27import { verifyJwt } from '../utils/jwt.js';
28
/**
 * Handler invoked for an incoming WebSocket message of type `TType`.
 *
 * Receives the id of the account whose socket delivered the message and the
 * message's `data` payload, typed according to `WebSocketMessageTable`.
 */
type WebSocketCallback<TType extends WebSocketMessageType> = {
  (accountId: string, data: WebSocketMessageTable[TType]): void;
};
30
/**
 * Callback registry: for every message type, the set of handlers to invoke
 * when a message of that type is received.
 */
type WebSocketCallbacks = {
  [TType in WebSocketMessageType]: Set<WebSocketCallback<TType>>;
};
34
/**
 * Creates a callback registry holding one empty handler set for each value of
 * `WebSocketMessageType`.
 *
 * @returns A registry in which every message type maps to a fresh, empty `Set`.
 */
const buildWebSocketCallbacks = (): WebSocketCallbacks => {
  const registry = {} as WebSocketCallbacks;
  Object.values(WebSocketMessageType).forEach((type) => {
    // TODO: type this properly to prevent mistakes
    // The end result of the function is still typed properly
    registry[type] = new Set() as any;
  });
  return registry;
};
44
/**
 * WebSocket endpoint of the server.
 *
 * Authenticates HTTP upgrade requests with a JWT passed in the `accessToken`
 * query parameter, keeps track of the open sockets of every account, and
 * dispatches each incoming message to the callbacks registered through `bind`.
 */
@Service()
export class WebSocketServer {
  // `noServer: true`: upgrades are handed to this instance explicitly through `upgrade()`
  private wss = new WebSocket.WebSocketServer({ noServer: true });
  /** Open sockets keyed by account id; an entry is removed once its last socket closes. */
  private sockets = new Map<string, WebSocket.WebSocket[]>();
  /** Registered message handlers, one set per message type. */
  private callbacks: WebSocketCallbacks;

  constructor() {
    this.callbacks = buildWebSocketCallbacks();

    this.wss.on('connection', (ws: WebSocket.WebSocket, _request: IncomingMessage, accountId: string) => {
      log.info('New connection for account', accountId);
      const accountSockets = this.sockets.get(accountId);
      if (accountSockets) {
        accountSockets.push(ws);
      } else {
        this.sockets.set(accountId, [ws]);
      }

      ws.on('message', <T extends WebSocketMessageType>(messageString: string) => {
        // The payload comes from the client: a malformed frame must be rejected here, because an
        // exception escaping this listener would surface as an uncaught exception.
        let parsed: unknown;
        try {
          parsed = JSON.parse(messageString);
        } catch (e) {
          log.warn('WebSocket message is not valid JSON:', e);
          return;
        }

        // `null` and primitives are valid JSON but have no `type`/`data` fields to read
        if (typeof parsed !== 'object' || parsed === null) {
          log.warn('WebSocket message is not a valid WebSocketMessage (missing type or data fields)');
          return;
        }

        const message = parsed as WebSocketMessage<T>;
        if (message.type === undefined || message.data === undefined) {
          log.warn('WebSocket message is not a valid WebSocketMessage (missing type or data fields)');
          return;
        }

        if (!Object.values(WebSocketMessageType).includes(message.type)) {
          log.warn(`Invalid WebSocket message type: ${message.type}`);
          return;
        }

        const callbacks = this.callbacks[message.type];
        for (const callback of callbacks) {
          callback(accountId, message.data);
        }
      });

      // An 'error' event without a listener is thrown by EventEmitter; log it instead.
      // The socket is unregistered by the 'close' handler below.
      ws.on('error', (error: Error) => {
        log.warn('WebSocket error for account', accountId, error);
      });

      ws.on('close', () => {
        log.info('Closing connection for account', accountId);
        const accountSockets = this.sockets.get(accountId);
        if (accountSockets === undefined) {
          return;
        }

        const index = accountSockets.indexOf(ws);
        if (index !== -1) {
          accountSockets.splice(index, 1);
          if (accountSockets.length === 0) {
            this.sockets.delete(accountId);
          }
        }
      });
    });
  }

  /**
   * Handles an HTTP upgrade request.
   *
   * The request must carry a JWT in its `accessToken` query parameter. When the token verifies
   * and its payload contains a string `accountId`, the connection is upgraded and registered for
   * that account; otherwise the socket is answered with `401 Unauthorized` and destroyed.
   *
   * @param request - The incoming HTTP upgrade request.
   * @param socket - The network socket of the request.
   * @param head - The first packet of the upgraded stream, forwarded to `handleUpgrade`.
   */
  async upgrade(request: IncomingMessage, socket: Duplex, head: Buffer): Promise<void> {
    // Do not use parseURL because it returns a URLRecord and not a URL
    const url = new URL(request.url ?? '/', 'http://localhost/');
    const token = url.searchParams.get('accessToken') ?? undefined;
    if (token === undefined) {
      this.rejectUnauthorized(socket);
      return;
    }

    try {
      const { payload } = await verifyJwt(token);
      const accountId = payload.accountId;
      // Without this check the connection would be registered under a non-string key
      if (typeof accountId !== 'string') {
        log.debug('Authentication failed: token payload has no string accountId');
        this.rejectUnauthorized(socket);
        return;
      }

      log.info('Authentication successful for account', accountId);
      this.wss.handleUpgrade(request, socket, head, (ws) => {
        this.wss.emit('connection', ws, request, accountId);
      });
    } catch (e) {
      log.debug('Authentication failed:', e);
      this.rejectUnauthorized(socket);
    }
  }

  /**
   * Registers a callback to run for every incoming message of the given type.
   *
   * @param type - The message type to listen for.
   * @param callback - Invoked with the sender's account id and the message data.
   */
  bind<T extends WebSocketMessageType>(type: T, callback: WebSocketCallback<T>): void {
    this.callbacks[type].add(callback);
  }

  /**
   * Sends a message to every open socket of an account.
   *
   * @param accountId - The account to send the message to.
   * @param type - The message type.
   * @param data - The message payload; serialized as JSON together with `type`.
   * @returns `false` when the account has no registered socket, `true` otherwise.
   */
  send<T extends WebSocketMessageType>(accountId: string, type: T, data: WebSocketMessageTable[T]): boolean {
    const accountSockets = this.sockets.get(accountId);
    if (accountSockets === undefined) {
      return false;
    }

    // Serialize once and reuse the string for every socket of the account
    const webSocketMessageString = JSON.stringify({ type, data });
    for (const accountSocket of accountSockets) {
      accountSocket.send(webSocketMessageString);
    }
    return true;
  }

  /** Answers a not-yet-upgraded socket with `401 Unauthorized` and destroys it. */
  private rejectUnauthorized(socket: Duplex): void {
    socket.write('HTTP/1.1 401 Unauthorized\r\n\r\n');
    socket.destroy();
  }
}
139}