blob: c1b99044c4cb685b111205f3566b4a706f95524b [file] [log] [blame]
Misha Krieger-Raynauldb933fbb2022-11-15 15:11:09 -05001/*
2 * Copyright (C) 2022 Savoir-faire Linux Inc.
3 *
4 * This program is free software; you can redistribute it and/or modify
5 * it under the terms of the GNU Affero General Public License as
6 * published by the Free Software Foundation; either version 3 of the
7 * License, or (at your option) any later version.
8 *
9 * This program is distributed in the hope that it will be useful,
10 * but WITHOUT ANY WARRANTY; without even the implied warranty of
11 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 * GNU Affero General Public License for more details.
13 *
14 * You should have received a copy of the GNU Affero General Public
15 * License along with this program. If not, see
16 * <https://www.gnu.org/licenses/>.
17 */
18import { IncomingMessage } from 'node:http';
19import { Duplex } from 'node:stream';
20
Misha Krieger-Raynauld20cf1c82022-11-23 20:26:50 -050021import { WebSocketMessage, WebSocketMessageTable, WebSocketMessageType } from 'jami-web-common';
Misha Krieger-Raynauldb933fbb2022-11-15 15:11:09 -050022import log from 'loglevel';
23import { Service } from 'typedi';
24import { URL } from 'whatwg-url';
25import * as WebSocket from 'ws';
26
27import { verifyJwt } from '../utils/jwt.js';
28
/**
 * Handler for one incoming message of type `T`.
 *
 * Receives the account ID of the socket the message arrived on and the
 * message payload typed by `WebSocketMessageTable[T]`.
 */
type WebSocketCallback<T extends WebSocketMessageType> = (
  accountId: string,
  data: WebSocketMessageTable[T],
) => void;

/** Registry holding one set of handlers per message type. */
type WebSocketCallbacks = {
  [MessageType in WebSocketMessageType]: Set<WebSocketCallback<MessageType>>;
};
34
/**
 * Authenticated WebSocket endpoint.
 *
 * Upgrade requests are authenticated with a JWT read from the `accessToken`
 * query parameter. Accepted sockets are grouped by account ID so that `send`
 * can reach every socket of an account. Incoming messages are dispatched to
 * the callbacks registered with `bind` for their message type.
 */
@Service()
export class WebSocketServer {
  // `noServer: true`: the ws server is not attached to an HTTP server, so
  // connections only arrive through `handleUpgrade` in `upgrade()` below.
  private readonly wss = new WebSocket.WebSocketServer({ noServer: true });
  /** Open sockets per account ID; an entry is deleted once its last socket closes. */
  private readonly sockets = new Map<string, WebSocket.WebSocket[]>();
  /** Registered message handlers, one set per message type. */
  private readonly callbacks: WebSocketCallbacks;

  constructor() {
    // Pre-create an empty handler set for every message type so that `bind`
    // and message dispatch never hit a missing entry.
    this.callbacks = {} as WebSocketCallbacks;
    for (const messageType of Object.values(WebSocketMessageType)) {
      this.callbacks[messageType] = new Set<WebSocketCallback<typeof messageType>>();
    }

    // `upgrade()` emits 'connection' with the authenticated account ID as an extra argument.
    this.wss.on('connection', (ws: WebSocket.WebSocket, _request: IncomingMessage, accountId: string) => {
      this.handleConnection(ws, accountId);
    });
  }

  /**
   * Authenticates an HTTP upgrade request and, on success, completes the WebSocket handshake.
   *
   * The JWT is read from the `accessToken` query parameter. If the token is missing,
   * fails verification, or has no string `accountId` in its payload, the request is
   * answered with `401 Unauthorized` and the socket is destroyed.
   *
   * @param request - the HTTP request asking for the upgrade
   * @param socket - the request's underlying socket
   * @param head - passed through unchanged to `handleUpgrade`
   */
  async upgrade(request: IncomingMessage, socket: Duplex, head: Buffer): Promise<void> {
    // Do not use parseURL because it returns a URLRecord and not a URL.
    // `request.url` is only a path, so a dummy base is needed; only the query string is read.
    const url = new URL(request.url ?? '/', 'http://localhost/');
    const token = url.searchParams.get('accessToken') ?? undefined;
    if (token === undefined) {
      this.rejectUnauthorized(socket);
      return;
    }

    try {
      const { payload } = await verifyJwt(token);
      const accountId = payload.accountId;
      // Sockets are keyed by account ID, so a token without a usable one must be rejected
      // rather than filed under `undefined`.
      if (typeof accountId !== 'string') {
        throw new Error('Access token payload has no string accountId');
      }
      log.info('Authentication successful for account', accountId);
      this.wss.handleUpgrade(request, socket, head, (ws) => {
        this.wss.emit('connection', ws, request, accountId);
      });
    } catch (e) {
      log.debug('Authentication failed:', e);
      this.rejectUnauthorized(socket);
    }
  }

  /**
   * Registers `callback` for every incoming message of the given type.
   *
   * Handlers are stored in a Set, so binding the same function twice for one type has no extra effect.
   */
  bind<T extends WebSocketMessageType>(type: T, callback: WebSocketCallback<T>): void {
    this.callbacks[type].add(callback);
  }

  /**
   * Sends a `{ type, data }` JSON message to every registered socket of `accountId`.
   *
   * @returns `false` if the account has no registered socket, `true` otherwise
   */
  send<T extends WebSocketMessageType>(accountId: string, type: T, data: WebSocketMessageTable[T]): boolean {
    const accountSockets = this.sockets.get(accountId);
    if (accountSockets === undefined) {
      return false;
    }

    const webSocketMessageString = JSON.stringify({ type, data });
    for (const accountSocket of accountSockets) {
      accountSocket.send(webSocketMessageString);
    }
    return true;
  }

  /** Tracks a freshly authenticated socket under its account and wires its event handlers. */
  private handleConnection(ws: WebSocket.WebSocket, accountId: string): void {
    log.info('New connection for account', accountId);
    const accountSockets = this.sockets.get(accountId);
    if (accountSockets) {
      accountSockets.push(ws);
    } else {
      this.sockets.set(accountId, [ws]);
    }

    ws.on('message', (messageString: string) => this.handleMessage(accountId, messageString));

    // An EventEmitter with no 'error' listener throws when 'error' is emitted,
    // which would take down the whole process for a single misbehaving socket.
    ws.on('error', (error: Error) => {
      log.warn('WebSocket error for account', accountId, error);
    });

    ws.on('close', () => {
      log.info('Closing connection for account', accountId);
      this.removeSocket(accountId, ws);
    });
  }

  /**
   * Parses one raw message received from `accountId` and dispatches it to the callbacks
   * bound to its type. Malformed messages are logged and dropped instead of throwing.
   */
  private handleMessage<T extends WebSocketMessageType>(accountId: string, messageString: string): void {
    let message: WebSocketMessage<T> | null;
    try {
      // Client-controlled input: JSON.parse throws on malformed text.
      message = JSON.parse(messageString);
    } catch (e) {
      log.warn('WebSocket message is not valid JSON:', e);
      return;
    }

    // `null` and primitives are valid JSON but not messages; reading `.type` on `null` would throw.
    if (
      typeof message !== 'object' ||
      message === null ||
      message.type === undefined ||
      message.data === undefined
    ) {
      log.warn('WebSocket message is not a valid WebSocketMessage (missing type or data fields)');
      return;
    }

    if (!Object.values(WebSocketMessageType).includes(message.type)) {
      log.warn(`Invalid WebSocket message type: ${message.type}`);
      return;
    }

    for (const callback of this.callbacks[message.type]) {
      // Isolate handlers: one throwing must neither skip the others nor escape the socket listener.
      try {
        callback(accountId, message.data);
      } catch (e) {
        log.error(`WebSocket callback for message type ${message.type} failed:`, e);
      }
    }
  }

  /** Forgets `ws` for `accountId`, deleting the account entry once it has no sockets left. */
  private removeSocket(accountId: string, ws: WebSocket.WebSocket): void {
    const accountSockets = this.sockets.get(accountId);
    if (accountSockets === undefined) {
      return;
    }

    const index = accountSockets.indexOf(ws);
    if (index === -1) {
      return;
    }
    accountSockets.splice(index, 1);
    if (accountSockets.length === 0) {
      this.sockets.delete(accountId);
    }
  }

  /** Answers a rejected upgrade with a bare `401 Unauthorized` response and destroys the socket. */
  private rejectUnauthorized(socket: Duplex): void {
    socket.write('HTTP/1.1 401 Unauthorized\r\n\r\n');
    socket.destroy();
  }
}