CppServer 1.0.5.0
C++ Server Library
Loading...
Searching...
No Matches
ws.cpp
Go to the documentation of this file.
1
9#include "server/ws/ws.h"
10
11#include "string/encoding.h"
12#include "string/format.h"
13#include "string/string_utils.h"
14
15#include <algorithm>
16#include <openssl/sha.h>
17
18namespace CppServer {
19namespace WS {
20
22{
23 std::generate(_ws_nonce.begin(), _ws_nonce.end(), []() { return (uint8_t)std::rand(); });
24}
25
26bool WebSocket::PerformClientUpgrade(const HTTP::HTTPResponse& response, const CppCommon::UUID& id)
27{
28 if (response.status() != 101)
29 return false;
30
31 bool error = false;
32 bool accept = false;
33 bool connection = false;
34 bool upgrade = false;
35
36 // Validate WebSocket handshake headers
37 for (size_t i = 0; i < response.headers(); ++i)
38 {
39 auto header = response.header(i);
40 auto key = std::get<0>(header);
41 auto value = std::get<1>(header);
42
43 if (CppCommon::StringUtils::CompareNoCase(key, "Connection"))
44 {
45 if (!CppCommon::StringUtils::CompareNoCase(value, "Upgrade"))
46 {
47 error = true;
48 onWSError("Invalid WebSocket handshaked response: 'Connection' header value must be 'Upgrade'");
49 break;
50 }
51
52 connection = true;
53 }
54 else if (CppCommon::StringUtils::CompareNoCase(key, "Upgrade"))
55 {
56 if (!CppCommon::StringUtils::CompareNoCase(value, "websocket"))
57 {
58 error = true;
59 onWSError("Invalid WebSocket handshaked response: 'Upgrade' header value must be 'websocket'");
60 break;
61 }
62
63 upgrade = true;
64 }
65 else if (CppCommon::StringUtils::CompareNoCase(key, "Sec-WebSocket-Accept"))
66 {
67 // Calculate the original WebSocket hash
68 std::string wskey = CppCommon::Encoding::Base64Encode(ws_nonce()) + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
69 char wshash[SHA_DIGEST_LENGTH];
70 SHA1((const unsigned char*)wskey.data(), wskey.size(), (unsigned char*)wshash);
71
72 // Get the received WebSocket hash
73 wskey = CppCommon::Encoding::Base64Decode(value);
74
75 // Compare original and received hashes
76 if (std::strncmp(wskey.data(), wshash, std::min(wskey.size(), sizeof(wshash))) != 0)
77 {
78 error = true;
79 onWSError("Invalid WebSocket handshaked response: 'Sec-WebSocket-Accept' value validation failed");
80 break;
81 }
82
83 accept = true;
84 }
85 }
86
87 // Failed to perform WebSocket handshake
88 if (!accept || !connection || !upgrade)
89 {
90 if (!error)
91 onWSError("Invalid WebSocket response");
92 return false;
93 }
94
95 // WebSocket successfully handshaked!
96 _ws_handshaked = true;
97 *((uint32_t*)_ws_send_mask) = rand();
98 onWSConnected(response);
99
100 return true;
101}
102
104{
105 if (request.method() != "GET")
106 return false;
107
108 bool error = false;
109 bool connection = false;
110 bool upgrade = false;
111 bool ws_key = false;
112 bool ws_version = false;
113
114 std::string accept;
115
116 // Validate WebSocket handshake headers
117 for (size_t i = 0; i < request.headers(); ++i)
118 {
119 auto header = request.header(i);
120 auto key = std::get<0>(header);
121 auto value = std::get<1>(header);
122
123 if (CppCommon::StringUtils::CompareNoCase(key, "Connection"))
124 {
125 if (!CppCommon::StringUtils::CompareNoCase(value, "Upgrade") && !CppCommon::StringUtils::CompareNoCase(CppCommon::StringUtils::RemoveBlank(value), "keep-alive,Upgrade"))
126 {
127 error = true;
128 response.MakeErrorResponse(400, "Invalid WebSocket handshaked request: 'Connection' header value must be 'Upgrade' or 'keep-alive, Upgrade'");
129 break;
130 }
131
132 connection = true;
133 }
134 else if (CppCommon::StringUtils::CompareNoCase(key, "Upgrade"))
135 {
136 if (!CppCommon::StringUtils::CompareNoCase(value, "websocket"))
137 {
138 error = true;
139 response.MakeErrorResponse(400, "Invalid WebSocket handshaked request: 'Upgrade' header value must be 'websocket'");
140 break;
141 }
142
143 upgrade = true;
144 }
145 else if (CppCommon::StringUtils::CompareNoCase(key, "Sec-WebSocket-Key"))
146 {
147 if (value.empty())
148 {
149 error = true;
150 response.MakeErrorResponse(400, "Invalid WebSocket handshaked request: 'Sec-WebSocket-Key' header value must be non empty");
151 break;
152 }
153
154 // Calculate WebSocket accept value
155 std::string wskey = std::string(value) + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
156 char wshash[SHA_DIGEST_LENGTH];
157 SHA1((const unsigned char*)wskey.data(), wskey.size(), (unsigned char*)wshash);
158
159 accept = CppCommon::Encoding::Base64Encode(std::string(wshash, sizeof(wshash)));
160
161 ws_key = true;
162 }
163 else if (CppCommon::StringUtils::CompareNoCase(key, "Sec-WebSocket-Version"))
164 {
165 if (!CppCommon::StringUtils::CompareNoCase(value, "13"))
166 {
167 error = true;
168 response.MakeErrorResponse(400, "Invalid WebSocket handshaked request: 'Sec-WebSocket-Version' header value must be '13'");
169 break;
170 }
171
172 ws_version = true;
173 }
174 }
175
176 // Filter out non WebSocket handshake requests
177 if (!connection && !upgrade && !ws_key && !ws_version)
178 return false;
179
180 // Failed to perform WebSocket handshake
181 if (!connection || !upgrade || !ws_key || !ws_version)
182 {
183 if (!error)
184 response.MakeErrorResponse(400, "Invalid WebSocket response");
185 SendResponse(response);
186 return false;
187 }
188
189 // Prepare WebSocket upgrade success response
190 response.Clear();
191 response.SetBegin(101, "HTTP/1.1");
192 response.SetHeader("Connection", "Upgrade");
193 response.SetHeader("Upgrade", "websocket");
194 response.SetHeader("Sec-WebSocket-Accept", accept);
195 response.SetBody();
196
197 // Validate WebSocket upgrade request and response
198 if (!onWSConnecting(request, response))
199 return false;
200
201 // Send WebSocket upgrade response
202 SendResponse(response);
203
204 // WebSocket successfully handshaked!
205 _ws_handshaked = true;
206 *((uint32_t*)_ws_send_mask) = 0;
207 onWSConnected(request);
208
209 return true;
210}
211
212void WebSocket::PrepareSendFrame(uint8_t opcode, bool mask, const void* buffer, size_t size, int status)
213{
214 // Check if we need to store additional 2 bytes of close status frame
215 bool store_status = ((opcode & WS_CLOSE) == WS_CLOSE) && ((size > 0) || (status != 0));
216 if (store_status)
217 size += 2;
218
219 // Clear the previous WebSocket send buffer
220 _ws_send_buffer.clear();
221
222 // Append WebSocket frame opcode
223 _ws_send_buffer.push_back(opcode);
224
225 // Append WebSocket frame size
226 if (size <= 125)
227 _ws_send_buffer.push_back((size & 0xFF) | (mask ? 0x80 : 0));
228 else if (size <= 65535)
229 {
230 _ws_send_buffer.push_back(126 | (mask ? 0x80 : 0));
231 _ws_send_buffer.push_back((size >> 8) & 0xFF);
232 _ws_send_buffer.push_back(size & 0xFF);
233 }
234 else
235 {
236 _ws_send_buffer.push_back(127 | (mask ? 0x80 : 0));
237 for (int i = 7; i >= 0; --i)
238 _ws_send_buffer.push_back((size >> (8 * i)) & 0xFF);
239 }
240
241 if (mask)
242 {
243 // Append WebSocket frame mask
244 _ws_send_buffer.push_back(_ws_send_mask[0]);
245 _ws_send_buffer.push_back(_ws_send_mask[1]);
246 _ws_send_buffer.push_back(_ws_send_mask[2]);
247 _ws_send_buffer.push_back(_ws_send_mask[3]);
248 }
249
250 // Resize WebSocket frame buffer
251 size_t offset = _ws_send_buffer.size();
252 _ws_send_buffer.resize(offset + size);
253
254 size_t index = 0;
255 const uint8_t* data = (const uint8_t*)buffer;
256
257 // Append WebSocket close status
258 // RFC 6455: If there is a body, the first two bytes of the body MUST
259 // be a 2-byte unsigned integer (in network byte order) representing
260 // a status code with value code.
261 if (store_status)
262 {
263 index += 2;
264 _ws_send_buffer[offset + 0] = ((status >> 8) & 0xFF) ^ _ws_send_mask[0];
265 _ws_send_buffer[offset + 1] = (status & 0xFF) ^ _ws_send_mask[1];
266 }
267
268 // Mask WebSocket frame content
269 for (size_t i = index; i < size; ++i)
270 _ws_send_buffer[offset + i] = data[i - index] ^ _ws_send_mask[i % 4];
271}
272
273void WebSocket::PrepareReceiveFrame(const void* buffer, size_t size)
274{
275 const uint8_t* data = (const uint8_t*)buffer;
276
277 // Clear received data after WebSocket frame was processed
279 {
280 _ws_frame_received = false;
281 _ws_header_size = 0;
284 *((uint32_t*)_ws_receive_mask) = 0;
285 }
287 {
288 _ws_final_received = false;
290 }
291
292 while (size > 0)
293 {
294 // Clear received data after WebSocket frame was processed
296 {
297 _ws_frame_received = false;
298 _ws_header_size = 0;
301 *((uint32_t*)_ws_receive_mask) = 0;
302 }
304 {
305 _ws_final_received = false;
307 }
308
309 // Prepare WebSocket frame opcode and mask flag
310 if (_ws_receive_frame_buffer.size() < 2)
311 {
312 for (size_t i = 0; i < 2; ++i, ++data, --size)
313 {
314 if (size == 0)
315 return;
316 _ws_receive_frame_buffer.push_back(*data);
317 }
318 }
319
320 uint8_t opcode = _ws_receive_frame_buffer[0] & 0x0F;
321 bool fin = ((_ws_receive_frame_buffer[0] >> 7) & 0x01) != 0;
322 bool mask = ((_ws_receive_frame_buffer[1] >> 7) & 0x01) != 0;
323 size_t payload = _ws_receive_frame_buffer[1] & (~0x80);
324
325 // Prepare WebSocket opcode
326 _ws_opcode = (opcode != 0) ? opcode : _ws_opcode;
327
328 // Prepare WebSocket frame size
329 if (payload <= 125)
330 {
331 _ws_header_size = 2 + (mask ? 4 : 0);
332 _ws_payload_size = payload;
335 }
336 else if (payload == 126)
337 {
338 if (_ws_receive_frame_buffer.size() < 4)
339 {
340 for (size_t i = 0; i < 2; ++i, ++data, --size)
341 {
342 if (size == 0)
343 return;
344 _ws_receive_frame_buffer.push_back(*data);
345 }
346 }
347
348 payload = (((size_t)_ws_receive_frame_buffer[2] << 8) | ((size_t)_ws_receive_frame_buffer[3] << 0));
349 _ws_header_size = 4 + (mask ? 4 : 0);
350 _ws_payload_size = payload;
353 }
354 else if (payload == 127)
355 {
356 if (_ws_receive_frame_buffer.size() < 10)
357 {
358 for (size_t i = 0; i < 8; ++i, ++data, --size)
359 {
360 if (size == 0)
361 return;
362 _ws_receive_frame_buffer.push_back(*data);
363 }
364 }
365
366 payload = (((size_t)_ws_receive_frame_buffer[2] << 56) | ((size_t)_ws_receive_frame_buffer[3] << 48) | ((size_t)_ws_receive_frame_buffer[4] << 40) | ((size_t)_ws_receive_frame_buffer[5] << 32) | ((size_t)_ws_receive_frame_buffer[6] << 24) | ((size_t)_ws_receive_frame_buffer[7] << 16) | ((size_t)_ws_receive_frame_buffer[8] << 8) | ((size_t)_ws_receive_frame_buffer[9] << 0));
367 _ws_header_size = 10 + (mask ? 4 : 0);
368 _ws_payload_size = payload;
371 }
372
373 // Prepare WebSocket frame mask
374 if (mask)
375 {
377 {
378 for (size_t i = 0; i < 4; ++i, ++data, --size)
379 {
380 if (size == 0)
381 return;
382 _ws_receive_frame_buffer.push_back(*data);
383 _ws_receive_mask[i] = *data;
384 }
385 }
386 }
387
388 size_t total = _ws_header_size + _ws_payload_size;
389 size_t length = std::min(total - _ws_receive_frame_buffer.size(), size);
390
391 // Prepare WebSocket frame payload
392 _ws_receive_frame_buffer.insert(_ws_receive_frame_buffer.end(), data, data + length);
393 data += length;
394 size -= length;
395
396 // Process WebSocket frame
397 if (_ws_receive_frame_buffer.size() == total)
398 {
399 // Unmask WebSocket frame content
400 if (mask)
401 {
402 for (size_t i = 0; i < _ws_payload_size; ++i)
404 }
405 else
407
408 _ws_frame_received = true;
409
410 // Finalize WebSocket frame
411 if (fin)
412 {
413 _ws_final_received = true;
414
415 switch (_ws_opcode)
416 {
417 case WS_PING:
418 {
419 // Call the WebSocket ping handler
421 break;
422 }
423 case WS_PONG:
424 {
425 // Call the WebSocket pong handler
427 break;
428 }
429 case WS_CLOSE:
430 {
431 size_t sindex = 0;
432 int status = 1000;
433
434 // Read WebSocket close status
435 if (_ws_receive_final_buffer.size() >= 2)
436 {
437 sindex += 2;
438 status = ((_ws_receive_final_buffer[0] << 8) | (_ws_receive_final_buffer[1] << 0));
439 }
440
441 // Call the WebSocket close handler
442 onWSClose(_ws_receive_final_buffer.data() + sindex, _ws_receive_final_buffer.size() - sindex, status);
443 break;
444 }
445 case WS_BINARY:
446 case WS_TEXT:
447 {
448 // Call the WebSocket received handler
450 break;
451 }
452 }
453 }
454 }
455 }
456}
457
459{
461 return 0;
462
463 // Required WebSocket frame opcode and mask flag
464 if (_ws_receive_frame_buffer.size() < 2)
465 return 2 - _ws_receive_frame_buffer.size();
466
467 bool mask = ((_ws_receive_frame_buffer[1] >> 7) & 0x01) != 0;
468 size_t payload = _ws_receive_frame_buffer[1] & (~0x80);
469
470 // Required WebSocket frame size
471 if ((payload == 126) && (_ws_receive_frame_buffer.size() < 4))
472 return 4 - _ws_receive_frame_buffer.size();
473 if ((payload == 127) && (_ws_receive_frame_buffer.size() < 10))
474 return 10 - _ws_receive_frame_buffer.size();
475
476 // Required WebSocket frame mask
477 if ((mask) && (_ws_receive_frame_buffer.size() < _ws_header_size))
479
480 // Required WebSocket frame payload
482}
483
485{
486 _ws_frame_received = false;
487 _ws_final_received = false;
488 _ws_header_size = 0;
492 *((uint32_t*)_ws_receive_mask) = 0;
493
494 std::scoped_lock locker(_ws_send_lock);
495
496 _ws_send_buffer.clear();
497 *((uint32_t*)_ws_send_mask) = 0;
498}
499
500} // namespace WS
501} // namespace CppServer
size_t headers() const noexcept
Get the HTTP request headers count.
std::string_view method() const noexcept
Get the HTTP request method.
std::tuple< std::string_view, std::string_view > header(size_t i) const noexcept
Get the HTTP request header by index.
HTTPResponse & SetBody(std::string_view body="")
Set the HTTP response body.
HTTPResponse & SetHeader(std::string_view key, std::string_view value)
Set the HTTP response header.
HTTPResponse & Clear()
Clear the HTTP response cache.
size_t headers() const noexcept
Get the HTTP response headers count.
int status() const noexcept
Get the HTTP response status.
HTTPResponse & MakeErrorResponse(std::string_view content="", std::string_view content_type="text/plain; charset=UTF-8")
Make ERROR response.
HTTPResponse & SetBegin(int status, std::string_view protocol="HTTP/1.1")
Set the HTTP response begin with a given status and protocol.
std::tuple< std::string_view, std::string_view > header(size_t i) const noexcept
Get the HTTP response header by index.
size_t _ws_payload_size
Received frame payload size.
Definition ws.h:180
static const uint8_t WS_CLOSE
Close frame.
Definition ws.h:39
void PrepareSendFrame(uint8_t opcode, bool mask, const void *buffer, size_t size, int status=0)
Prepare WebSocket send frame.
Definition ws.cpp:212
static const uint8_t WS_TEXT
Text frame.
Definition ws.h:35
std::vector< uint8_t > _ws_receive_frame_buffer
Receive frame buffer.
Definition ws.h:182
virtual void onWSClose(const void *buffer, size_t size, int status=1000)
Handle WebSocket client close notification.
Definition ws.h:147
bool _ws_frame_received
Received frame flag.
Definition ws.h:174
static const uint8_t WS_BINARY
Binary frame.
Definition ws.h:37
std::string_view ws_nonce() const noexcept
Get the WebSocket random nonce.
Definition ws.h:54
static const uint8_t WS_PING
Ping frame.
Definition ws.h:41
void InitWSNonce()
Initialize WebSocket random nonce.
Definition ws.cpp:21
void PrepareReceiveFrame(const void *buffer, size_t size)
Prepare WebSocket receive frame.
Definition ws.cpp:273
virtual void onWSConnecting(HTTP::HTTPRequest &request)
Handle WebSocket client connecting notification.
Definition ws.h:107
virtual void onWSReceived(const void *buffer, size_t size)
Handle WebSocket received notification.
Definition ws.h:139
virtual void onWSPing(const void *buffer, size_t size)
Handle WebSocket ping notification.
Definition ws.h:153
std::vector< uint8_t > _ws_receive_final_buffer
Receive final buffer.
Definition ws.h:184
uint8_t _ws_receive_mask[4]
Receive mask.
Definition ws.h:186
std::array< uint8_t, 16 > _ws_nonce
WebSocket random nonce of 16 bytes.
Definition ws.h:196
bool PerformClientUpgrade(const HTTP::HTTPResponse &response, const CppCommon::UUID &id)
Perform WebSocket client upgrade.
Definition ws.cpp:26
std::mutex _ws_send_lock
Send buffer lock.
Definition ws.h:189
virtual void onWSConnected(const HTTP::HTTPResponse &response)
Handle WebSocket client connected notification.
Definition ws.h:112
uint8_t _ws_opcode
Received frame opcode.
Definition ws.h:172
void ClearWSBuffers()
Clear WebSocket send/receive buffers.
Definition ws.cpp:484
virtual void onWSPong(const void *buffer, size_t size)
Handle WebSocket pong notification.
Definition ws.h:159
static const uint8_t WS_PONG
Pong frame.
Definition ws.h:43
virtual void onWSError(const std::string &message)
Handle WebSocket error notification.
Definition ws.h:165
size_t RequiredReceiveFrameSize()
Required WebSocket receive frame size.
Definition ws.cpp:458
std::vector< uint8_t > _ws_send_buffer
Send buffer.
Definition ws.h:191
uint8_t _ws_send_mask[4]
Send mask.
Definition ws.h:193
virtual void SendResponse(const HTTP::HTTPResponse &response)
Send WebSocket server upgrade response.
Definition ws.h:202
bool _ws_final_received
Received final flag.
Definition ws.h:176
size_t _ws_header_size
Received frame header size.
Definition ws.h:178
bool PerformServerUpgrade(const HTTP::HTTPRequest &request, HTTP::HTTPResponse &response)
Perform WebSocket server upgrade.
Definition ws.cpp:103
bool _ws_handshaked
Handshaked flag.
Definition ws.h:169
C++ Server project definitions.
Definition asio.h:56
WebSocket C++ Library definition.