Misha Krieger-Raynauld | b933fbb | 2022-11-15 15:11:09 -0500 | [diff] [blame] | 1 | /* |
| 2 | * Copyright (C) 2022 Savoir-faire Linux Inc. |
| 3 | * |
| 4 | * This program is free software; you can redistribute it and/or modify |
| 5 | * it under the terms of the GNU Affero General Public License as |
| 6 | * published by the Free Software Foundation; either version 3 of the |
| 7 | * License, or (at your option) any later version. |
| 8 | * |
| 9 | * This program is distributed in the hope that it will be useful, |
| 10 | * but WITHOUT ANY WARRANTY; without even the implied warranty of |
| 11 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the |
| 12 | * GNU Affero General Public License for more details. |
| 13 | * |
| 14 | * You should have received a copy of the GNU Affero General Public |
| 15 | * License along with this program. If not, see |
| 16 | * <https://www.gnu.org/licenses/>. |
| 17 | */ |
| 18 | import { IncomingMessage } from 'node:http'; |
| 19 | import { Duplex } from 'node:stream'; |
| 20 | |
Misha Krieger-Raynauld | 20cf1c8 | 2022-11-23 20:26:50 -0500 | [diff] [blame] | 21 | import { WebSocketMessage, WebSocketMessageTable, WebSocketMessageType } from 'jami-web-common'; |
Misha Krieger-Raynauld | b933fbb | 2022-11-15 15:11:09 -0500 | [diff] [blame] | 22 | import log from 'loglevel'; |
| 23 | import { Service } from 'typedi'; |
| 24 | import { URL } from 'whatwg-url'; |
| 25 | import * as WebSocket from 'ws'; |
| 26 | |
| 27 | import { verifyJwt } from '../utils/jwt.js'; |
| 28 | |
/**
 * Handler for one message type. It receives the ID of the account the message came from, together with
 * the message payload typed for that message type.
 */
type WebSocketCallback<TMessage extends WebSocketMessageType> = (
  accountId: string,
  data: WebSocketMessageTable[TMessage]
) => void;

/** Registered handlers, one set per message type, each set typed for its own message type. */
type WebSocketCallbacks = {
  [MessageType in WebSocketMessageType]: Set<WebSocketCallback<MessageType>>;
};
| 34 | |
/**
 * Creates an empty callback set for every message type, so that looking up the callbacks of any known
 * message type always finds a set.
 */
const buildWebSocketCallbacks = (): WebSocketCallbacks => {
  // Every set starts out empty, so one shared element type is enough while filling them in. This avoids
  // an `any` cast per entry. The single assertion below restores the per-message-type typing of
  // `WebSocketCallbacks` once every key has been populated.
  const webSocketCallbacks: Partial<Record<WebSocketMessageType, Set<unknown>>> = {};
  for (const messageType of Object.values(WebSocketMessageType)) {
    webSocketCallbacks[messageType] = new Set();
  }
  return webSocketCallbacks as WebSocketCallbacks;
};
| 44 | |
/**
 * Authenticated WebSocket endpoint.
 *
 * Clients authenticate during the HTTP upgrade with a JWT passed in the `accessToken` query parameter.
 * Open sockets are grouped by the `accountId` claim of that token, so `send()` reaches every connection
 * of an account. Incoming messages are dispatched to the callbacks registered with `bind()` for their type.
 */
@Service()
export class WebSocketServer {
  // Created in `noServer` mode: connections only arrive through `upgrade()`, which authenticates first
  private readonly wss = new WebSocket.WebSocketServer({ noServer: true });
  // Open sockets keyed by account ID; an account may have several connections at once
  private readonly sockets = new Map<string, WebSocket.WebSocket[]>();
  // Callbacks registered with `bind()`, one set per message type
  private readonly callbacks: WebSocketCallbacks;

  constructor() {
    this.callbacks = buildWebSocketCallbacks();

    // `upgrade()` emits 'connection' itself and passes the authenticated account ID as an extra argument
    this.wss.on('connection', (ws: WebSocket.WebSocket, _request: IncomingMessage, accountId: string) => {
      log.info('New connection for account', accountId);
      const accountSockets = this.sockets.get(accountId);
      if (accountSockets) {
        accountSockets.push(ws);
      } else {
        this.sockets.set(accountId, [ws]);
      }

      ws.on('message', (rawMessage: WebSocket.RawData) => {
        this.handleMessage(accountId, rawMessage);
      });

      // `ws` emits 'error' for protocol violations such as invalid frames. Node's EventEmitter rethrows an
      // 'error' event that has no listener, which would let any client crash the server. `ws` already
      // tears the connection down after such errors, so logging is enough here.
      ws.on('error', (error: Error) => {
        log.warn('WebSocket error for account', accountId, error);
      });

      ws.on('close', () => {
        log.info('Closing connection for account', accountId);
        const accountSockets = this.sockets.get(accountId);
        if (accountSockets === undefined) {
          return;
        }

        const index = accountSockets.indexOf(ws);
        if (index !== -1) {
          accountSockets.splice(index, 1);
          if (accountSockets.length === 0) {
            this.sockets.delete(accountId);
          }
        }
      });
    });
  }

  /**
   * Authenticates an HTTP upgrade request and, on success, completes the WebSocket handshake.
   *
   * The JWT is read from the `accessToken` query parameter. The request is answered with
   * `401 Unauthorized` and the socket is destroyed in these cases:
   * - the token is missing;
   * - the token fails verification;
   * - the token's payload has no string `accountId`.
   *
   * @param request - the incoming upgrade request
   * @param socket - the socket of the upgrade request
   * @param head - forwarded unchanged to `handleUpgrade`
   */
  async upgrade(request: IncomingMessage, socket: Duplex, head: Buffer): Promise<void> {
    // Do not use parseURL because it returns a URLRecord and not a URL
    const url = new URL(request.url ?? '/', 'http://localhost/');
    const token = url.searchParams.get('accessToken');
    if (token === null) {
      socket.write('HTTP/1.1 401 Unauthorized\r\n\r\n');
      socket.destroy();
      return;
    }

    try {
      const { payload } = await verifyJwt(token);
      const accountId = payload.accountId;
      // Without a string account ID the connection could not be attributed to an account
      if (typeof accountId !== 'string') {
        throw new Error('JWT payload does not contain a string accountId');
      }
      log.info('Authentication successful for account', accountId);
      this.wss.handleUpgrade(request, socket, head, (ws) => {
        this.wss.emit('connection', ws, request, accountId);
      });
    } catch (e) {
      log.debug('Authentication failed:', e);
      socket.write('HTTP/1.1 401 Unauthorized\r\n\r\n');
      socket.destroy();
    }
  }

  /**
   * Registers a callback for every incoming message of the given type.
   *
   * The callback receives the sender's account ID and the message data. Callbacks are stored in a `Set`,
   * so binding the same function twice has no additional effect.
   */
  bind<T extends WebSocketMessageType>(type: T, callback: WebSocketCallback<T>): void {
    this.callbacks[type].add(callback);
  }

  /**
   * Sends a message to every socket currently registered for an account.
   *
   * @returns `false` if the account has no registered socket, `true` otherwise
   */
  send<T extends WebSocketMessageType>(accountId: string, type: T, data: WebSocketMessageTable[T]): boolean {
    const accountSockets = this.sockets.get(accountId);
    if (accountSockets === undefined) {
      return false;
    }

    const webSocketMessageString = JSON.stringify({ type, data });
    for (const accountSocket of accountSockets) {
      accountSocket.send(webSocketMessageString);
    }
    return true;
  }

  /**
   * Decodes one incoming message and forwards it to the callbacks bound to its type.
   *
   * This method must not throw: it runs inside a `ws` event listener, where an exception would be uncaught
   * and could take down the whole process. Invalid messages are logged and dropped instead.
   *
   * The type parameter only exists so that `this.callbacks[message.type]` and `message.data` are typed
   * for the same message type.
   */
  private handleMessage<T extends WebSocketMessageType>(accountId: string, rawMessage: WebSocket.RawData): void {
    let parsed: unknown;
    try {
      // Presumably a Buffer, since the socket's default `nodebuffer` binary type is never changed here
      parsed = JSON.parse(rawMessage.toString());
    } catch (e) {
      log.warn('WebSocket message is not valid JSON:', e);
      return;
    }

    // JSON.parse may also yield `null` or a primitive; reading fields of `null` would throw
    const message = parsed as WebSocketMessage<T> | null;
    if (typeof message !== 'object' || message === null || message.type === undefined || message.data === undefined) {
      log.warn('WebSocket message is not a valid WebSocketMessage (missing type or data fields)');
      return;
    }

    if (!Object.values(WebSocketMessageType).includes(message.type)) {
      // Pass the untrusted value as a separate argument: string-interpolating an object such as
      // {"toString": 1} throws a TypeError during primitive conversion
      log.warn('Invalid WebSocket message type:', message.type);
      return;
    }

    for (const callback of this.callbacks[message.type]) {
      // Isolate callbacks so one failing handler neither crashes the server nor skips the others
      try {
        callback(accountId, message.data);
      } catch (e) {
        log.error(`Callback for WebSocket message type ${message.type} failed:`, e);
      }
    }
  }
}