blob: 6750129214726a3f6a766d2bda3645768977e433 [file] [log] [blame]
Misha Krieger-Raynauldb933fbb2022-11-15 15:11:09 -05001/*
2 * Copyright (C) 2022 Savoir-faire Linux Inc.
3 *
4 * This program is free software; you can redistribute it and/or modify
5 * it under the terms of the GNU Affero General Public License as
6 * published by the Free Software Foundation; either version 3 of the
7 * License, or (at your option) any later version.
8 *
9 * This program is distributed in the hope that it will be useful,
10 * but WITHOUT ANY WARRANTY; without even the implied warranty of
11 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 * GNU Affero General Public License for more details.
13 *
14 * You should have received a copy of the GNU Affero General Public
15 * License along with this program. If not, see
16 * <https://www.gnu.org/licenses/>.
17 */
18import { IncomingMessage } from 'node:http';
19import { Duplex } from 'node:stream';
20
21import { WebSocketCallbacks, WebSocketMessage, WebSocketMessageTable, WebSocketMessageType } from 'jami-web-common';
22import log from 'loglevel';
23import { Service } from 'typedi';
24import { URL } from 'whatwg-url';
25import * as WebSocket from 'ws';
26
27import { verifyJwt } from '../utils/jwt.js';
28
/**
 * Authenticated WebSocket endpoint shared by all accounts.
 *
 * Clients connect through {@link WebSocketServer.upgrade}, authenticating with an `accessToken` query parameter.
 * Every open socket is tracked per account so that {@link WebSocketServer.send} can fan a message out to all of an
 * account's connections. Incoming messages are dispatched to the callbacks registered with
 * {@link WebSocketServer.bind} for their message type.
 */
@Service()
export class WebSocketServer {
  // `noServer` mode: the HTTP server owns the listening socket and hands upgrade requests to `upgrade()`.
  private readonly wss = new WebSocket.WebSocketServer({ noServer: true });
  /** Open sockets keyed by account ID; an account's entry is removed once its last socket closes. */
  private readonly sockets = new Map<string, WebSocket.WebSocket[]>();
  /** Callbacks registered through `bind()`, one list per message type. */
  private readonly callbacks: WebSocketCallbacks = {
    [WebSocketMessageType.ConversationMessage]: [],
    [WebSocketMessageType.ConversationView]: [],
    [WebSocketMessageType.WebRTCOffer]: [],
    [WebSocketMessageType.WebRTCAnswer]: [],
    [WebSocketMessageType.IceCandidate]: [],
  };

  constructor() {
    // `accountId` is the extra argument passed by `upgrade()` when it emits 'connection'.
    this.wss.on('connection', (ws: WebSocket.WebSocket, _request: IncomingMessage, accountId: string) => {
      log.info('New connection for account', accountId);
      this.addSocket(accountId, ws);

      ws.on('message', (rawData: Buffer | ArrayBuffer | Buffer[]) => this.handleMessage(rawData));

      // An 'error' event with no listener is thrown by Node's EventEmitter, which would take down the whole
      // process. NOTE(review): cleanup relies on `ws` emitting 'close' after an error, as ws 8 does.
      ws.on('error', (error: Error) => {
        log.error('WebSocket error for account', accountId, error);
      });

      ws.on('close', () => {
        log.info('Closing connection for account', accountId);
        this.removeSocket(accountId, ws);
      });
    });
  }

  /**
   * Authenticates an HTTP upgrade request and, on success, completes the WebSocket handshake.
   *
   * The JWT is read from the `accessToken` query parameter. The request is answered with `401 Unauthorized` and
   * the socket is destroyed in three cases:
   * - the token is missing or empty;
   * - `verifyJwt` rejects the token;
   * - the token's payload has no string `accountId`.
   *
   * Accepted connections are emitted as `'connection'` together with the authenticated account ID.
   *
   * @param request - the HTTP upgrade request
   * @param socket - the raw socket of the request
   * @param head - forwarded unchanged to `handleUpgrade`
   */
  async upgrade(request: IncomingMessage, socket: Duplex, head: Buffer): Promise<void> {
    // The JWT check below is asynchronous and `handleUpgrade` installs its own socket error handler only once it
    // runs. Until then, an 'error' on the raw socket (e.g. the client hanging up) would be an unhandled event.
    const onSocketError = (error: Error): void => {
      log.debug('Socket error during WebSocket upgrade:', error);
    };
    socket.on('error', onSocketError);

    // Do not use parseURL because it returns a URLRecord and not a URL
    const url = new URL(request.url ?? '/', 'http://localhost/');
    const token = url.searchParams.get('accessToken');
    if (!token) {
      WebSocketServer.rejectUpgrade(socket);
      return;
    }

    let accountId: string;
    try {
      const { payload } = await verifyJwt(token);
      if (typeof payload.accountId !== 'string') {
        throw new Error('JWT payload has no accountId');
      }
      accountId = payload.accountId;
    } catch (e) {
      log.debug('Authentication failed:', e);
      WebSocketServer.rejectUpgrade(socket);
      return;
    }

    log.info('Authentication successful for account', accountId);
    socket.removeListener('error', onSocketError);
    // Kept outside the try block so that a failure during the handshake is not misreported as an authentication
    // failure (which would write a 401 onto a socket that may already have been upgraded).
    this.wss.handleUpgrade(request, socket, head, (ws) => {
      this.wss.emit('connection', ws, request, accountId);
    });
  }

  /**
   * Registers `callback` to receive the `data` of every incoming message of the given `type`, from any account.
   * Callbacks run in registration order. This class offers no way to unregister one.
   */
  bind<T extends WebSocketMessageType>(type: T, callback: (data: WebSocketMessageTable[T]) => void): void {
    this.callbacks[type].push(callback);
  }

  /**
   * Sends `{ type, data }` as JSON to every tracked socket of `accountId`.
   *
   * @returns `false` if the account has no tracked socket, `true` otherwise. `true` only means the message was
   *   handed to `ws`; it does not confirm delivery.
   */
  send<T extends WebSocketMessageType>(accountId: string, type: T, data: WebSocketMessageTable[T]): boolean {
    const accountSockets = this.sockets.get(accountId);
    if (accountSockets === undefined) {
      return false;
    }

    const webSocketMessageString = JSON.stringify({ type, data });
    for (const accountSocket of accountSockets) {
      accountSocket.send(webSocketMessageString);
    }
    return true;
  }

  /** Records `ws` as one of the open sockets of `accountId`. */
  private addSocket(accountId: string, ws: WebSocket.WebSocket): void {
    const accountSockets = this.sockets.get(accountId);
    if (accountSockets) {
      accountSockets.push(ws);
    } else {
      this.sockets.set(accountId, [ws]);
    }
  }

  /** Forgets `ws` and drops the account's entry once it has no sockets left. */
  private removeSocket(accountId: string, ws: WebSocket.WebSocket): void {
    const accountSockets = this.sockets.get(accountId);
    if (accountSockets === undefined) {
      return;
    }

    const index = accountSockets.indexOf(ws);
    if (index === -1) {
      return;
    }
    accountSockets.splice(index, 1);
    if (accountSockets.length === 0) {
      this.sockets.delete(accountId);
    }
  }

  /**
   * Parses one incoming message and passes its `data` to every callback bound to its `type`.
   *
   * The following input is logged and dropped instead of thrown:
   * - invalid JSON;
   * - a payload that is not an object;
   * - a missing `type` or `data` field;
   * - an unknown `type`.
   *
   * This runs inside a `ws` event listener, where a thrown exception would escape as an uncaught error. `T` only
   * ties `message.type` to `message.data` so that the callback lookup type-checks per message type.
   */
  private handleMessage<T extends WebSocketMessageType>(rawData: Buffer | ArrayBuffer | Buffer[]): void {
    let parsed: unknown;
    try {
      parsed = JSON.parse(WebSocketServer.decodeRawData(rawData));
    } catch (e) {
      log.warn('WebSocket message is not valid JSON:', e);
      return;
    }

    // `JSON.parse` may yield null or a primitive; reading `.type` from null would throw.
    if (typeof parsed !== 'object' || parsed === null || !('type' in parsed) || !('data' in parsed)) {
      log.warn('WebSocket message is not a valid WebSocketMessage (missing type or data fields)');
      return;
    }
    const message = parsed as WebSocketMessage<T>;

    if (!Object.values(WebSocketMessageType).includes(message.type)) {
      log.warn(`Invalid WebSocket message type: ${message.type}`);
      return;
    }

    const callbacks = this.callbacks[message.type];
    for (const callback of callbacks) {
      // Isolate callbacks: one that throws must not skip the others or escape into the `ws` listener.
      try {
        callback(message.data);
      } catch (e) {
        log.error(`WebSocket callback for message type ${message.type} failed:`, e);
      }
    }
  }

  /**
   * Decodes a `ws` message payload as UTF-8 text. `ws` delivers either a Buffer, an ArrayBuffer or an array of
   * Buffer fragments, depending on the socket's `binaryType`.
   */
  private static decodeRawData(rawData: Buffer | ArrayBuffer | Buffer[]): string {
    if (Buffer.isBuffer(rawData)) {
      return rawData.toString('utf8');
    }
    if (Array.isArray(rawData)) {
      return Buffer.concat(rawData).toString('utf8');
    }
    return Buffer.from(rawData).toString('utf8');
  }

  /** Answers an upgrade request with `401 Unauthorized` and destroys the socket. */
  private static rejectUpgrade(socket: Duplex): void {
    socket.write('HTTP/1.1 401 Unauthorized\r\n\r\n');
    socket.destroy();
  }
}
127}