CppServer 1.0.6.0
C++ Server Library
Loading...
Searching...
No Matches
ws.cpp
Go to the documentation of this file.
1
8
9#include "server/ws/ws.h"
10
11#include "string/encoding.h"
12#include "string/format.h"
13#include "string/string_utils.h"
14
15#include <algorithm>
16#include <openssl/sha.h>
17
18namespace CppServer {
19namespace WS {
20
22{
23 std::generate(_ws_nonce.begin(), _ws_nonce.end(), []() { return (uint8_t)std::rand(); });
24}
25
26bool WebSocket::PerformClientUpgrade(const HTTP::HTTPResponse& response, const CppCommon::UUID& id)
27{
28 if (response.status() != 101)
29 return false;
30
31 bool error = false;
32 bool accept = false;
33 bool connection = false;
34 bool upgrade = false;
35
36 // Validate WebSocket handshake headers
37 for (size_t i = 0; i < response.headers(); ++i)
38 {
39 auto header = response.header(i);
40 auto key = std::get<0>(header);
41 auto value = std::get<1>(header);
42
43 if (CppCommon::StringUtils::CompareNoCase(key, "Connection"))
44 {
45 if (!CppCommon::StringUtils::CompareNoCase(value, "Upgrade"))
46 {
47 error = true;
48 onWSError("Invalid WebSocket handshaked response: 'Connection' header value must be 'Upgrade'");
49 break;
50 }
51
52 connection = true;
53 }
54 else if (CppCommon::StringUtils::CompareNoCase(key, "Upgrade"))
55 {
56 if (!CppCommon::StringUtils::CompareNoCase(value, "websocket"))
57 {
58 error = true;
59 onWSError("Invalid WebSocket handshaked response: 'Upgrade' header value must be 'websocket'");
60 break;
61 }
62
63 upgrade = true;
64 }
65 else if (CppCommon::StringUtils::CompareNoCase(key, "Sec-WebSocket-Accept"))
66 {
67 // Calculate the original WebSocket hash
68 std::string wskey = CppCommon::Encoding::Base64Encode(ws_nonce()) + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
69 char wshash[SHA_DIGEST_LENGTH];
70 SHA1((const unsigned char*)wskey.data(), wskey.size(), (unsigned char*)wshash);
71
72 // Get the received WebSocket hash
73 wskey = CppCommon::Encoding::Base64Decode(value);
74
75 // Compare original and received hashes
76 if (std::strncmp(wskey.data(), wshash, std::min(wskey.size(), sizeof(wshash))) != 0)
77 {
78 error = true;
79 onWSError("Invalid WebSocket handshaked response: 'Sec-WebSocket-Accept' value validation failed");
80 break;
81 }
82
83 accept = true;
84 }
85 }
86
87 // Failed to perform WebSocket handshake
88 if (!accept || !connection || !upgrade)
89 {
90 if (!error)
91 onWSError("Invalid WebSocket response");
92 return false;
93 }
94
95 // WebSocket successfully handshaked!
96 _ws_handshaked = true;
97 *((uint32_t*)_ws_send_mask) = rand();
98 onWSConnected(response);
99
100 return true;
101}
102
104{
105 if (request.method() != "GET")
106 return false;
107
108 bool error = false;
109 bool connection = false;
110 bool upgrade = false;
111 bool ws_key = false;
112 bool ws_version = false;
113
114 std::string accept;
115
116 // Validate WebSocket handshake headers
117 for (size_t i = 0; i < request.headers(); ++i)
118 {
119 auto header = request.header(i);
120 auto key = std::get<0>(header);
121 auto value = std::get<1>(header);
122
123 if (CppCommon::StringUtils::CompareNoCase(key, "Connection"))
124 {
125 if (!CppCommon::StringUtils::CompareNoCase(value, "Upgrade") && !CppCommon::StringUtils::CompareNoCase(CppCommon::StringUtils::RemoveBlank(value), "keep-alive,Upgrade"))
126 {
127 error = true;
128 response.MakeErrorResponse(400, "Invalid WebSocket handshaked request: 'Connection' header value must be 'Upgrade' or 'keep-alive, Upgrade'");
129 break;
130 }
131
132 connection = true;
133 }
134 else if (CppCommon::StringUtils::CompareNoCase(key, "Upgrade"))
135 {
136 if (!CppCommon::StringUtils::CompareNoCase(value, "websocket"))
137 {
138 error = true;
139 response.MakeErrorResponse(400, "Invalid WebSocket handshaked request: 'Upgrade' header value must be 'websocket'");
140 break;
141 }
142
143 upgrade = true;
144 }
145 else if (CppCommon::StringUtils::CompareNoCase(key, "Sec-WebSocket-Key"))
146 {
147 if (value.empty())
148 {
149 error = true;
150 response.MakeErrorResponse(400, "Invalid WebSocket handshaked request: 'Sec-WebSocket-Key' header value must be non empty");
151 break;
152 }
153
154 // Calculate WebSocket accept value
155 std::string wskey = std::string(value) + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
156 char wshash[SHA_DIGEST_LENGTH];
157 SHA1((const unsigned char*)wskey.data(), wskey.size(), (unsigned char*)wshash);
158
159 accept = CppCommon::Encoding::Base64Encode(std::string(wshash, sizeof(wshash)));
160
161 ws_key = true;
162 }
163 else if (CppCommon::StringUtils::CompareNoCase(key, "Sec-WebSocket-Version"))
164 {
165 if (!CppCommon::StringUtils::CompareNoCase(value, "13"))
166 {
167 error = true;
168 response.MakeErrorResponse(400, "Invalid WebSocket handshaked request: 'Sec-WebSocket-Version' header value must be '13'");
169 break;
170 }
171
172 ws_version = true;
173 }
174 }
175
176 // Filter out non WebSocket handshake requests
177 if (!connection && !upgrade && !ws_key && !ws_version)
178 return false;
179
180 // Failed to perform WebSocket handshake
181 if (!connection || !upgrade || !ws_key || !ws_version)
182 {
183 if (!error)
184 response.MakeErrorResponse(400, "Invalid WebSocket response");
185 SendResponse(response);
186 return false;
187 }
188
189 // Prepare WebSocket upgrade success response
190 response.Clear();
191 response.SetBegin(101, "HTTP/1.1");
192 response.SetHeader("Connection", "Upgrade");
193 response.SetHeader("Upgrade", "websocket");
194 response.SetHeader("Sec-WebSocket-Accept", accept);
195
196 // Validate WebSocket upgrade request and response
197 if (!onWSConnecting(request, response))
198 return false;
199
200 // Set body of the WebSocket upgrade response
201 response.SetBody();
202
203 // Send WebSocket upgrade response
204 SendResponse(response);
205
206 // WebSocket successfully handshaked!
207 _ws_handshaked = true;
208 *((uint32_t*)_ws_send_mask) = 0;
209 onWSConnected(request);
210
211 return true;
212}
213
214void WebSocket::PrepareSendFrame(uint8_t opcode, bool mask, const void* buffer, size_t size, int status)
215{
216 // Check if we need to store additional 2 bytes of close status frame
217 bool store_status = ((opcode & WS_CLOSE) == WS_CLOSE) && ((size > 0) || (status != 0));
218 if (store_status)
219 size += 2;
220
221 // Clear the previous WebSocket send buffer
222 _ws_send_buffer.clear();
223
224 // Append WebSocket frame opcode
225 _ws_send_buffer.push_back(opcode);
226
227 // Append WebSocket frame size
228 if (size <= 125)
229 _ws_send_buffer.push_back((size & 0xFF) | (mask ? 0x80 : 0));
230 else if (size <= 65535)
231 {
232 _ws_send_buffer.push_back(126 | (mask ? 0x80 : 0));
233 _ws_send_buffer.push_back((size >> 8) & 0xFF);
234 _ws_send_buffer.push_back(size & 0xFF);
235 }
236 else
237 {
238 _ws_send_buffer.push_back(127 | (mask ? 0x80 : 0));
239 for (int i = 7; i >= 0; --i)
240 _ws_send_buffer.push_back((size >> (8 * i)) & 0xFF);
241 }
242
243 if (mask)
244 {
245 // Append WebSocket frame mask
246 _ws_send_buffer.push_back(_ws_send_mask[0]);
247 _ws_send_buffer.push_back(_ws_send_mask[1]);
248 _ws_send_buffer.push_back(_ws_send_mask[2]);
249 _ws_send_buffer.push_back(_ws_send_mask[3]);
250 }
251
252 // Resize WebSocket frame buffer
253 size_t offset = _ws_send_buffer.size();
254 _ws_send_buffer.resize(offset + size);
255
256 size_t index = 0;
257 const uint8_t* data = (const uint8_t*)buffer;
258
259 // Append WebSocket close status
260 // RFC 6455: If there is a body, the first two bytes of the body MUST
261 // be a 2-byte unsigned integer (in network byte order) representing
262 // a status code with value code.
263 if (store_status)
264 {
265 index += 2;
266 _ws_send_buffer[offset + 0] = ((status >> 8) & 0xFF) ^ _ws_send_mask[0];
267 _ws_send_buffer[offset + 1] = (status & 0xFF) ^ _ws_send_mask[1];
268 }
269
270 // Mask WebSocket frame content
271 for (size_t i = index; i < size; ++i)
272 _ws_send_buffer[offset + i] = data[i - index] ^ _ws_send_mask[i % 4];
273}
274
// PrepareReceiveFrame: feed 'size' raw bytes from 'buffer' into the incremental WebSocket frame parser.
// NOTE(review): this listing is a Doxygen extract. Several source lines are missing, including:
// - the guards before the two "Clear received data" blocks;
// - the mask-loop guard;
// - the unmasking/copy statements;
// - the ping/pong/received handler calls.
// The comments below describe only the statements that are visible.
//
// Visible behavior:
// - Header bytes accumulate in _ws_receive_frame_buffer. If the input runs out in the middle of a
//   header section, the function returns and relies on the next call to continue.
// - Byte 0 supplies the FIN bit (bit 7) and the 4-bit opcode.
// - Byte 1 supplies the MASK bit (bit 7) and the 7-bit length. 126 selects a 16-bit extended length
//   and 127 a 64-bit one, both big-endian.
// - Opcode 0 (RFC 6455 continuation) keeps the previously stored _ws_opcode.
// - Once 'total' bytes are buffered and FIN is set, handling is selected by _ws_opcode. For close
//   frames, the first two bytes of _ws_receive_final_buffer are read as a big-endian status code
//   (default 1000) and the remainder is passed to onWSClose.
//
// NOTE(review): possible defects, to be confirmed against the full source:
// - Partial header sections: each header loop always consumes a fixed 2/2/8/4 bytes, whatever part
//   of that section an earlier call already buffered. The mask index also restarts at 0, so a header
//   split mid-section across calls is mis-parsed. Starting each loop at the count already buffered
//   (e.g. 'for (size_t i = _ws_receive_frame_buffer.size(); i < 2; ...)') would avoid this.
// - Unbounded peer length: the 64-bit payload length comes from the peer. Nothing visible here bounds
//   it, and '_ws_header_size + _ws_payload_size' can wrap.
275void WebSocket::PrepareReceiveFrame(const void* buffer, size_t size)
276{
277 const uint8_t* data = (const uint8_t*)buffer;
278
279 // Clear received data after WebSocket frame was processed
281 {
282 _ws_frame_received = false;
283 _ws_header_size = 0;
286 *((uint32_t*)_ws_receive_mask) = 0;
287 }
289 {
290 _ws_final_received = false;
292 }
293
294 while (size > 0)
295 {
296 // Clear received data after WebSocket frame was processed
298 {
299 _ws_frame_received = false;
300 _ws_header_size = 0;
303 *((uint32_t*)_ws_receive_mask) = 0;
304 }
306 {
307 _ws_final_received = false;
309 }
310
311 // Prepare WebSocket frame opcode and mask flag
312 if (_ws_receive_frame_buffer.size() < 2)
313 {
// NOTE(review): always reads 2 bytes, even if 1 byte was already buffered by a previous call
314 for (size_t i = 0; i < 2; ++i, ++data, --size)
315 {
316 if (size == 0)
317 return;
318 _ws_receive_frame_buffer.push_back(*data);
319 }
320 }
321
322 uint8_t opcode = _ws_receive_frame_buffer[0] & 0x0F;
323 bool fin = ((_ws_receive_frame_buffer[0] >> 7) & 0x01) != 0;
324 bool mask = ((_ws_receive_frame_buffer[1] >> 7) & 0x01) != 0;
325 size_t payload = (size_t)_ws_receive_frame_buffer[1] & (~0x80);
326
327 // Prepare WebSocket opcode
328 _ws_opcode = (opcode != 0) ? opcode : _ws_opcode;
329
330 // Prepare WebSocket frame size
331 if (payload <= 125)
332 {
333 _ws_header_size = 2 + (mask ? 4 : 0);
334 _ws_payload_size = payload;
337 }
338 else if (payload == 126)
339 {
340 if (_ws_receive_frame_buffer.size() < 4)
341 {
// NOTE(review): always reads 2 length bytes, even if one was already buffered by a previous call
342 for (size_t i = 0; i < 2; ++i, ++data, --size)
343 {
344 if (size == 0)
345 return;
346 _ws_receive_frame_buffer.push_back(*data);
347 }
348 }
349
350 payload = (((size_t)_ws_receive_frame_buffer[2] << 8) | ((size_t)_ws_receive_frame_buffer[3] << 0));
351 _ws_header_size = 4 + (mask ? 4 : 0);
352 _ws_payload_size = payload;
355 }
356 else if (payload == 127)
357 {
358 if (_ws_receive_frame_buffer.size() < 10)
359 {
// NOTE(review): always reads 8 length bytes, even if some were already buffered by a previous call
360 for (size_t i = 0; i < 8; ++i, ++data, --size)
361 {
362 if (size == 0)
363 return;
364 _ws_receive_frame_buffer.push_back(*data);
365 }
366 }
367
// NOTE(review): shifting a size_t by 32..56 bits is undefined behavior where size_t is 32 bits;
// assembling into uint64_t (and rejecting lengths that do not fit size_t) would be portable
368 payload = (((size_t)_ws_receive_frame_buffer[2] << 56) | ((size_t)_ws_receive_frame_buffer[3] << 48) | ((size_t)_ws_receive_frame_buffer[4] << 40) | ((size_t)_ws_receive_frame_buffer[5] << 32) | ((size_t)_ws_receive_frame_buffer[6] << 24) | ((size_t)_ws_receive_frame_buffer[7] << 16) | ((size_t)_ws_receive_frame_buffer[8] << 8) | ((size_t)_ws_receive_frame_buffer[9] << 0));
369 _ws_header_size = 10 + (mask ? 4 : 0);
370 _ws_payload_size = payload;
373 }
374
375 // Prepare WebSocket frame mask
376 if (mask)
377 {
379 {
// NOTE(review): the mask index restarts at 0 on every call, so a mask key split across two calls
// is stored at the wrong positions (and extra bytes are consumed)
380 for (size_t i = 0; i < 4; ++i, ++data, --size)
381 {
382 if (size == 0)
383 return;
384 _ws_receive_frame_buffer.push_back(*data);
385 _ws_receive_mask[i] = *data;
386 }
387 }
388 }
389
// Bytes still missing for the complete frame (header + payload); take at most what this call has.
// NOTE(review): header + payload can wrap for a peer-supplied 64-bit length, and if the buffer ever
// exceeds 'total' (see the header-loop notes) the subtraction underflows
390 size_t total = _ws_header_size + _ws_payload_size;
391 size_t length = std::min(total - _ws_receive_frame_buffer.size(), size);
392
393 // Prepare WebSocket frame payload
394 _ws_receive_frame_buffer.insert(_ws_receive_frame_buffer.end(), data, data + length);
395 data += length;
396 size -= length;
397
398 // Process WebSocket frame
399 if (_ws_receive_frame_buffer.size() == total)
400 {
401 // Unmask WebSocket frame content
402 if (mask)
403 {
404 for (size_t i = 0; i < _ws_payload_size; ++i)
406 }
407 else
409
410 _ws_frame_received = true;
411
412 // Finalize WebSocket frame
413 if (fin)
414 {
415 _ws_final_received = true;
416
// Dispatch on the message opcode; the ping/pong/received handler call lines are missing from this
// extract (presumably onWSPing / onWSPong / onWSReceived on _ws_receive_final_buffer — confirm)
417 switch (_ws_opcode)
418 {
419 case WS_PING:
420 {
421 // Call the WebSocket ping handler
423 break;
424 }
425 case WS_PONG:
426 {
427 // Call the WebSocket pong handler
429 break;
430 }
431 case WS_CLOSE:
432 {
433 size_t sindex = 0;
434 int status = 1000;
435
436 // Read WebSocket close status
437 if (_ws_receive_final_buffer.size() >= 2)
438 {
439 sindex += 2;
440 status = (((int)_ws_receive_final_buffer[0] << 8) | ((int)_ws_receive_final_buffer[1] << 0));
441 }
442
443 // Call the WebSocket close handler
444 onWSClose(_ws_receive_final_buffer.data() + sindex, _ws_receive_final_buffer.size() - sindex, status);
445 break;
446 }
447 case WS_BINARY:
448 case WS_TEXT:
449 {
450 // Call the WebSocket received handler
452 break;
453 }
454 }
455 }
456 }
457 }
458}
459
// RequiredReceiveFrameSize (member list: 'size_t RequiredReceiveFrameSize()'). Several lines are
// missing from this Doxygen extract: the signature line, the guard before the first 'return 0',
// and the mask and payload return statements.
// Reports how many more bytes the receive parser needs to complete the next section of the current
// frame, in this order:
// - the 2-byte base header;
// - the 16-bit (length code 126) or 64-bit (length code 127) extended length;
// - the mask key.
// NOTE(review): presumably 0 is returned when a frame has already been received — confirm against
// the missing guard.
461{
463 return 0;
464
465 // Required WebSocket frame opcode and mask flag
466 if (_ws_receive_frame_buffer.size() < 2)
467 return 2 - _ws_receive_frame_buffer.size();
468
469 bool mask = ((_ws_receive_frame_buffer[1] >> 7) & 0x01) != 0;
470 size_t payload = (size_t)_ws_receive_frame_buffer[1] & (~0x80);
471
472 // Required WebSocket frame size
473 if ((payload == 126) && (_ws_receive_frame_buffer.size() < 4))
474 return 4 - _ws_receive_frame_buffer.size();
475 if ((payload == 127) && (_ws_receive_frame_buffer.size() < 10))
476 return 10 - _ws_receive_frame_buffer.size();
477
// Relies on _ws_header_size having been set by PrepareReceiveFrame for the current frame
478 // Required WebSocket frame mask
479 if ((mask) && (_ws_receive_frame_buffer.size() < _ws_header_size))
481
482 // Required WebSocket frame payload
484}
485
// ClearWSBuffers (member list: 'void ClearWSBuffers()'; the signature line and lines 491-493 are
// missing from this Doxygen extract): reset the receive parser state, then clear the send buffer and
// send mask.
// Only the send-side reset runs under _ws_send_lock; the receive-side fields are reset before the
// lock is acquired.
// NOTE(review): '*((uint32_t*)mask) = 0' writes a uint8_t[4] through a uint32_t lvalue (strict
// aliasing / alignment hazard); std::fill_n(mask, 4, 0) would be the portable equivalent.
487{
488 _ws_frame_received = false;
489 _ws_final_received = false;
490 _ws_header_size = 0;
494 *((uint32_t*)_ws_receive_mask) = 0;
495
496 std::scoped_lock locker(_ws_send_lock);
497
498 _ws_send_buffer.clear();
499 *((uint32_t*)_ws_send_mask) = 0;
500}
501
502} // namespace WS
503} // namespace CppServer
size_t headers() const noexcept
Get the HTTP request headers count.
std::string_view method() const noexcept
Get the HTTP request method.
std::tuple< std::string_view, std::string_view > header(size_t i) const noexcept
Get the HTTP request header by index.
HTTPResponse & SetBody(std::string_view body="")
Set the HTTP response body.
HTTPResponse & SetHeader(std::string_view key, std::string_view value)
Set the HTTP response header.
HTTPResponse & Clear()
Clear the HTTP response cache.
size_t headers() const noexcept
Get the HTTP response headers count.
int status() const noexcept
Get the HTTP response status.
HTTPResponse & MakeErrorResponse(std::string_view content="", std::string_view content_type="text/plain; charset=UTF-8")
Make ERROR response.
HTTPResponse & SetBegin(int status, std::string_view protocol="HTTP/1.1")
Set the HTTP response begin with a given status and protocol.
std::tuple< std::string_view, std::string_view > header(size_t i) const noexcept
Get the HTTP response header by index.
size_t _ws_payload_size
Received frame payload size.
Definition ws.h:182
static const uint8_t WS_CLOSE
Close frame.
Definition ws.h:39
void PrepareSendFrame(uint8_t opcode, bool mask, const void *buffer, size_t size, int status=0)
Prepare WebSocket send frame.
Definition ws.cpp:214
static const uint8_t WS_TEXT
Text frame.
Definition ws.h:35
std::vector< uint8_t > _ws_receive_frame_buffer
Receive frame buffer.
Definition ws.h:184
virtual void onWSClose(const void *buffer, size_t size, int status=1000)
Handle WebSocket client close notification.
Definition ws.h:149
bool _ws_frame_received
Received frame flag.
Definition ws.h:176
static const uint8_t WS_BINARY
Binary frame.
Definition ws.h:37
std::string_view ws_nonce() const noexcept
Get the WebSocket random nonce.
Definition ws.h:54
static const uint8_t WS_PING
Ping frame.
Definition ws.h:41
void InitWSNonce()
Initialize WebSocket random nonce.
Definition ws.cpp:21
void PrepareReceiveFrame(const void *buffer, size_t size)
Prepare WebSocket receive frame.
Definition ws.cpp:275
virtual void onWSConnecting(HTTP::HTTPRequest &request)
Handle WebSocket client connecting notification.
Definition ws.h:107
virtual void onWSReceived(const void *buffer, size_t size)
Handle WebSocket received notification.
Definition ws.h:141
virtual void onWSPing(const void *buffer, size_t size)
Handle WebSocket ping notification.
Definition ws.h:155
std::vector< uint8_t > _ws_receive_final_buffer
Receive final buffer.
Definition ws.h:186
uint8_t _ws_receive_mask[4]
Receive mask.
Definition ws.h:188
std::array< uint8_t, 16 > _ws_nonce
WebSocket random nonce of 16 bytes.
Definition ws.h:198
bool PerformClientUpgrade(const HTTP::HTTPResponse &response, const CppCommon::UUID &id)
Perform WebSocket client upgrade.
Definition ws.cpp:26
std::mutex _ws_send_lock
Send buffer lock.
Definition ws.h:191
virtual void onWSConnected(const HTTP::HTTPResponse &response)
Handle WebSocket client connected notification.
Definition ws.h:112
uint8_t _ws_opcode
Received frame opcode.
Definition ws.h:174
void ClearWSBuffers()
Clear WebSocket send/receive buffers.
Definition ws.cpp:486
virtual void onWSPong(const void *buffer, size_t size)
Handle WebSocket pong notification.
Definition ws.h:161
static const uint8_t WS_PONG
Pong frame.
Definition ws.h:43
virtual void onWSError(const std::string &message)
Handle WebSocket error notification.
Definition ws.h:167
size_t RequiredReceiveFrameSize()
Required WebSocket receive frame size.
Definition ws.cpp:460
std::vector< uint8_t > _ws_send_buffer
Send buffer.
Definition ws.h:193
uint8_t _ws_send_mask[4]
Send mask.
Definition ws.h:195
virtual void SendResponse(const HTTP::HTTPResponse &response)
Send WebSocket server upgrade response.
Definition ws.h:204
bool _ws_final_received
Received final flag.
Definition ws.h:178
size_t _ws_header_size
Received frame header size.
Definition ws.h:180
bool PerformServerUpgrade(const HTTP::HTTPRequest &request, HTTP::HTTPResponse &response)
Perform WebSocket server upgrade.
Definition ws.cpp:103
bool _ws_handshaked
Handshaked flag.
Definition ws.h:171
WebSocket definitions.
C++ Server project definitions.
Definition asio.h:56
WebSocket C++ Library definition.