Misha Krieger-Raynauld | b933fbb | 2022-11-15 15:11:09 -0500 | [diff] [blame] | 1 | /* |
| 2 | * Copyright (C) 2022 Savoir-faire Linux Inc. |
| 3 | * |
| 4 | * This program is free software; you can redistribute it and/or modify |
| 5 | * it under the terms of the GNU Affero General Public License as |
| 6 | * published by the Free Software Foundation; either version 3 of the |
| 7 | * License, or (at your option) any later version. |
| 8 | * |
| 9 | * This program is distributed in the hope that it will be useful, |
| 10 | * but WITHOUT ANY WARRANTY; without even the implied warranty of |
| 11 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the |
| 12 | * GNU Affero General Public License for more details. |
| 13 | * |
| 14 | * You should have received a copy of the GNU Affero General Public |
| 15 | * License along with this program. If not, see |
| 16 | * <https://www.gnu.org/licenses/>. |
| 17 | */ |
| 18 | import { IncomingMessage } from 'node:http'; |
| 19 | import { Duplex } from 'node:stream'; |
| 20 | |
simon | a5c54ef | 2022-11-18 05:26:06 -0500 | [diff] [blame] | 21 | import { |
| 22 | buildWebSocketCallbacks, |
| 23 | WebSocketCallback, |
| 24 | WebSocketCallbacks, |
| 25 | WebSocketMessage, |
| 26 | WebSocketMessageTable, |
| 27 | WebSocketMessageType, |
| 28 | } from 'jami-web-common'; |
Misha Krieger-Raynauld | b933fbb | 2022-11-15 15:11:09 -0500 | [diff] [blame] | 29 | import log from 'loglevel'; |
| 30 | import { Service } from 'typedi'; |
| 31 | import { URL } from 'whatwg-url'; |
| 32 | import * as WebSocket from 'ws'; |
| 33 | |
| 34 | import { verifyJwt } from '../utils/jwt.js'; |
| 35 | |
/**
 * Authenticated WebSocket endpoint.
 *
 * Upgrade requests are authenticated in {@link WebSocketServer.upgrade} with a JWT passed as the `accessToken` query
 * parameter. Accepted sockets are grouped by the token's account ID so that {@link WebSocketServer.send} reaches every
 * connection of that account. Incoming messages are dispatched by their `type` field to the callbacks registered with
 * {@link WebSocketServer.bind}.
 */
@Service()
export class WebSocketServer {
  private wss = new WebSocket.WebSocketServer({ noServer: true });
  // Open sockets per account ID; an account may hold several connections at once
  private sockets = new Map<string, WebSocket.WebSocket[]>();
  // Callbacks registered through bind(), indexed by message type
  private callbacks: WebSocketCallbacks = buildWebSocketCallbacks();

  constructor() {
    this.wss.on('connection', (ws: WebSocket.WebSocket, _request: IncomingMessage, accountId: string) => {
      log.info('New connection for account', accountId);
      const accountSockets = this.sockets.get(accountId);
      if (accountSockets) {
        accountSockets.push(ws);
      } else {
        this.sockets.set(accountId, [ws]);
      }

      ws.on('message', (rawMessage: WebSocket.RawData) => this.handleMessage(rawMessage));

      // Without an 'error' listener, EventEmitter rethrows the error and the whole process crashes (e.g. when a
      // client sends an invalid frame). Errors are only logged here; unregistering the socket is left to 'close'.
      ws.on('error', (e: Error) => {
        log.warn('WebSocket error for account', accountId, e);
      });

      ws.on('close', () => {
        log.info('Closing connection for account', accountId);
        const accountSockets = this.sockets.get(accountId);
        if (accountSockets === undefined) {
          return;
        }

        const index = accountSockets.indexOf(ws);
        if (index !== -1) {
          accountSockets.splice(index, 1);
          if (accountSockets.length === 0) {
            this.sockets.delete(accountId);
          }
        }
      });
    });
  }

  /**
   * Authenticates an HTTP upgrade request and, on success, completes the WebSocket handshake.
   *
   * The access token is read from the `accessToken` query parameter and verified with `verifyJwt`; its payload must
   * contain a string `accountId`. Missing or invalid tokens are answered with `401 Unauthorized` and the socket is
   * destroyed.
   *
   * @param request - The HTTP request that triggered the upgrade.
   * @param socket - The raw network socket between the server and the client.
   * @param head - The first packet of the upgraded stream, forwarded to `handleUpgrade`.
   */
  async upgrade(request: IncomingMessage, socket: Duplex, head: Buffer): Promise<void> {
    // Node's HTTP server detaches its own 'error' listener before emitting 'upgrade'; without one, a socket error
    // (e.g. ECONNRESET) while the token is being verified asynchronously would crash the process.
    const onSocketError = (e: Error): void => {
      log.debug('Socket error during WebSocket upgrade:', e);
    };
    socket.on('error', onSocketError);

    // Do not use parseURL because it returns a URLRecord and not a URL
    const url = new URL(request.url ?? '/', 'http://localhost/');
    const token = url.searchParams.get('accessToken') ?? undefined;
    if (token === undefined) {
      this.rejectUnauthorized(socket);
      return;
    }

    let accountId: string;
    try {
      const { payload } = await verifyJwt(token);
      const claimedAccountId = payload.accountId;
      // Never register sockets under a non-string key such as `undefined`
      if (typeof claimedAccountId !== 'string') {
        throw new Error('Access token payload has no valid accountId');
      }
      accountId = claimedAccountId;
    } catch (e) {
      log.debug('Authentication failed:', e);
      this.rejectUnauthorized(socket);
      return;
    }

    log.info('Authentication successful for account', accountId);
    // Handled outside the try-block so a failure during the handshake is not answered with a 401 on an upgraded
    // socket; from here on, ws installs its own socket error handling.
    socket.removeListener('error', onSocketError);
    this.wss.handleUpgrade(request, socket, head, (ws) => {
      this.wss.emit('connection', ws, request, accountId);
    });
  }

  /**
   * Registers a callback invoked with the `data` of every incoming message of the given type, from any account.
   *
   * @param type - The message type to listen for.
   * @param callback - Called once per matching message.
   */
  bind<T extends WebSocketMessageType>(type: T, callback: WebSocketCallback<T>): void {
    this.callbacks[type].add(callback);
  }

  /**
   * Sends a message to every socket currently registered for an account.
   *
   * @param accountId - The account whose connections should receive the message.
   * @param type - The message type.
   * @param data - The message payload, serialized to JSON together with `type`.
   * @returns `true` if the account had at least one registered socket, `false` otherwise. Delivery itself is not
   * confirmed.
   */
  send<T extends WebSocketMessageType>(accountId: string, type: T, data: WebSocketMessageTable[T]): boolean {
    const accountSockets = this.sockets.get(accountId);
    if (accountSockets === undefined) {
      return false;
    }

    const webSocketMessageString = JSON.stringify({ type, data });
    for (const accountSocket of accountSockets) {
      accountSocket.send(webSocketMessageString);
    }
    return true;
  }

  /**
   * Parses a raw incoming message and invokes every callback bound to its type.
   *
   * Malformed messages are logged and dropped. Errors thrown by a callback are logged so they neither stop the
   * remaining callbacks nor escape into the socket's event emitter.
   */
  private handleMessage<T extends WebSocketMessageType>(rawMessage: WebSocket.RawData): void {
    let message: WebSocketMessage<T> | null;
    try {
      message = JSON.parse(rawMessage.toString());
    } catch (e) {
      log.warn('WebSocket message is not valid JSON:', e);
      return;
    }

    // JSON.parse may also yield null or a primitive, on which the field checks would throw or be meaningless
    if (typeof message !== 'object' || message === null || message.type === undefined || message.data === undefined) {
      log.warn('WebSocket message is not a valid WebSocketMessage (missing type or data fields)');
      return;
    }

    if (!Object.values(WebSocketMessageType).includes(message.type)) {
      log.warn(`Invalid WebSocket message type: ${message.type}`);
      return;
    }

    const callbacks = this.callbacks[message.type];
    for (const callback of callbacks) {
      try {
        callback(message.data);
      } catch (e) {
        log.error(`Error in WebSocket callback for message type ${message.type}:`, e);
      }
    }
  }

  /** Answers a pending upgrade request with `401 Unauthorized` and closes the socket. */
  private rejectUnauthorized(socket: Duplex): void {
    socket.write('HTTP/1.1 401 Unauthorized\r\n\r\n');
    socket.destroy();
  }
}