CppServer  1.0.4.0
C++ Server Library
ws.cpp
Go to the documentation of this file.
1 
#include "server/ws/ws.h"

#include "string/encoding.h"
#include "string/format.h"
#include "string/string_utils.h"

#include <algorithm>
#include <cstdlib>
#include <random>
#include <string>
#include <string_view>

#include <openssl/sha.h>
17 
18 namespace CppServer {
19 namespace WS {
20 
22 {
23  std::generate(_ws_nonce.begin(), _ws_nonce.end(), []() { return (uint8_t)std::rand(); });
24 }
25 
26 bool WebSocket::PerformClientUpgrade(const HTTP::HTTPResponse& response, const CppCommon::UUID& id)
27 {
28  if (response.status() != 101)
29  return false;
30 
31  bool error = false;
32  bool accept = false;
33  bool connection = false;
34  bool upgrade = false;
35 
36  // Validate WebSocket handshake headers
37  for (size_t i = 0; i < response.headers(); ++i)
38  {
39  auto header = response.header(i);
40  auto key = std::get<0>(header);
41  auto value = std::get<1>(header);
42 
43  if (CppCommon::StringUtils::CompareNoCase(key, "Connection"))
44  {
45  if (!CppCommon::StringUtils::CompareNoCase(value, "Upgrade"))
46  {
47  error = true;
48  onWSError("Invalid WebSocket handshaked response: 'Connection' header value must be 'Upgrade'");
49  break;
50  }
51 
52  connection = true;
53  }
54  else if (CppCommon::StringUtils::CompareNoCase(key, "Upgrade"))
55  {
56  if (!CppCommon::StringUtils::CompareNoCase(value, "websocket"))
57  {
58  error = true;
59  onWSError("Invalid WebSocket handshaked response: 'Upgrade' header value must be 'websocket'");
60  break;
61  }
62 
63  upgrade = true;
64  }
65  else if (CppCommon::StringUtils::CompareNoCase(key, "Sec-WebSocket-Accept"))
66  {
67  // Calculate the original WebSocket hash
68  std::string wskey = CppCommon::Encoding::Base64Encode(ws_nonce()) + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
69  char wshash[SHA_DIGEST_LENGTH];
70  SHA1((const unsigned char*)wskey.data(), wskey.size(), (unsigned char*)wshash);
71 
72  // Get the received WebSocket hash
73  wskey = CppCommon::Encoding::Base64Decode(value);
74 
75  // Compare original and received hashes
76  if (std::strncmp(wskey.data(), wshash, std::min(wskey.size(), sizeof(wshash))) != 0)
77  {
78  error = true;
79  onWSError("Invalid WebSocket handshaked response: 'Sec-WebSocket-Accept' value validation failed");
80  break;
81  }
82 
83  accept = true;
84  }
85  }
86 
87  // Failed to perform WebSocket handshake
88  if (!accept || !connection || !upgrade)
89  {
90  if (!error)
91  onWSError("Invalid WebSocket response");
92  return false;
93  }
94 
95  // WebSocket successfully handshaked!
96  _ws_handshaked = true;
97  *((uint32_t*)_ws_send_mask) = rand();
98  onWSConnected(response);
99 
100  return true;
101 }
102 
104 {
105  if (request.method() != "GET")
106  return false;
107 
108  bool error = false;
109  bool connection = false;
110  bool upgrade = false;
111  bool ws_key = false;
112  bool ws_version = false;
113 
114  std::string accept;
115 
116  // Validate WebSocket handshake headers
117  for (size_t i = 0; i < request.headers(); ++i)
118  {
119  auto header = request.header(i);
120  auto key = std::get<0>(header);
121  auto value = std::get<1>(header);
122 
123  if (CppCommon::StringUtils::CompareNoCase(key, "Connection"))
124  {
125  if (!CppCommon::StringUtils::CompareNoCase(value, "Upgrade") && !CppCommon::StringUtils::CompareNoCase(CppCommon::StringUtils::RemoveBlank(value), "keep-alive,Upgrade"))
126  {
127  error = true;
128  response.MakeErrorResponse(400, "Invalid WebSocket handshaked request: 'Connection' header value must be 'Upgrade' or 'keep-alive, Upgrade'");
129  break;
130  }
131 
132  connection = true;
133  }
134  else if (CppCommon::StringUtils::CompareNoCase(key, "Upgrade"))
135  {
136  if (!CppCommon::StringUtils::CompareNoCase(value, "websocket"))
137  {
138  error = true;
139  response.MakeErrorResponse(400, "Invalid WebSocket handshaked request: 'Upgrade' header value must be 'websocket'");
140  break;
141  }
142 
143  upgrade = true;
144  }
145  else if (CppCommon::StringUtils::CompareNoCase(key, "Sec-WebSocket-Key"))
146  {
147  if (value.empty())
148  {
149  error = true;
150  response.MakeErrorResponse(400, "Invalid WebSocket handshaked request: 'Sec-WebSocket-Key' header value must be non empty");
151  break;
152  }
153 
154  // Calculate WebSocket accept value
155  std::string wskey = std::string(value) + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
156  char wshash[SHA_DIGEST_LENGTH];
157  SHA1((const unsigned char*)wskey.data(), wskey.size(), (unsigned char*)wshash);
158 
159  accept = CppCommon::Encoding::Base64Encode(std::string(wshash, sizeof(wshash)));
160 
161  ws_key = true;
162  }
163  else if (CppCommon::StringUtils::CompareNoCase(key, "Sec-WebSocket-Version"))
164  {
165  if (!CppCommon::StringUtils::CompareNoCase(value, "13"))
166  {
167  error = true;
168  response.MakeErrorResponse(400, "Invalid WebSocket handshaked request: 'Sec-WebSocket-Version' header value must be '13'");
169  break;
170  }
171 
172  ws_version = true;
173  }
174  }
175 
176  // Filter out non WebSocket handshake requests
177  if (!connection && !upgrade && !ws_key && !ws_version)
178  return false;
179 
180  // Failed to perform WebSocket handshake
181  if (!connection || !upgrade || !ws_key || !ws_version)
182  {
183  if (!error)
184  response.MakeErrorResponse(400, "Invalid WebSocket response");
185  SendResponse(response);
186  return false;
187  }
188 
189  // Prepare WebSocket upgrade success response
190  response.Clear();
191  response.SetBegin(101, "HTTP/1.1");
192  response.SetHeader("Connection", "Upgrade");
193  response.SetHeader("Upgrade", "websocket");
194  response.SetHeader("Sec-WebSocket-Accept", accept);
195  response.SetBody();
196 
197  // Validate WebSocket upgrade request and response
198  if (!onWSConnecting(request, response))
199  return false;
200 
201  // Send WebSocket upgrade response
202  SendResponse(response);
203 
204  // WebSocket successfully handshaked!
205  _ws_handshaked = true;
206  *((uint32_t*)_ws_send_mask) = 0;
207  onWSConnected(request);
208 
209  return true;
210 }
211 
212 void WebSocket::PrepareSendFrame(uint8_t opcode, bool mask, const void* buffer, size_t size, int status)
213 {
214  // Check if we need to store additional 2 bytes of close status frame
215  bool store_status = ((opcode & WS_CLOSE) == WS_CLOSE) && ((size > 0) || (status != 0));
216  if (store_status)
217  size += 2;
218 
219  // Clear the previous WebSocket send buffer
220  _ws_send_buffer.clear();
221 
222  // Append WebSocket frame opcode
223  _ws_send_buffer.push_back(opcode);
224 
225  // Append WebSocket frame size
226  if (size <= 125)
227  _ws_send_buffer.push_back((size & 0xFF) | (mask ? 0x80 : 0));
228  else if (size <= 65535)
229  {
230  _ws_send_buffer.push_back(126 | (mask ? 0x80 : 0));
231  _ws_send_buffer.push_back((size >> 8) & 0xFF);
232  _ws_send_buffer.push_back(size & 0xFF);
233  }
234  else
235  {
236  _ws_send_buffer.push_back(127 | (mask ? 0x80 : 0));
237  for (int i = 7; i >= 0; --i)
238  _ws_send_buffer.push_back((size >> (8 * i)) & 0xFF);
239  }
240 
241  if (mask)
242  {
243  // Append WebSocket frame mask
244  _ws_send_buffer.push_back(_ws_send_mask[0]);
245  _ws_send_buffer.push_back(_ws_send_mask[1]);
246  _ws_send_buffer.push_back(_ws_send_mask[2]);
247  _ws_send_buffer.push_back(_ws_send_mask[3]);
248  }
249 
250  // Resize WebSocket frame buffer
251  size_t offset = _ws_send_buffer.size();
252  _ws_send_buffer.resize(offset + size);
253 
254  size_t index = 0;
255  const uint8_t* data = (const uint8_t*)buffer;
256 
257  // Append WebSocket close status
258  // RFC 6455: If there is a body, the first two bytes of the body MUST
259  // be a 2-byte unsigned integer (in network byte order) representing
260  // a status code with value code.
261  if (store_status)
262  {
263  index += 2;
264  _ws_send_buffer[offset + 0] = ((status >> 8) & 0xFF) ^ _ws_send_mask[0];
265  _ws_send_buffer[offset + 1] = (status & 0xFF) ^ _ws_send_mask[1];
266  }
267 
268  // Mask WebSocket frame content
269  for (size_t i = index; i < size; ++i)
270  _ws_send_buffer[offset + i] = data[i - index] ^ _ws_send_mask[i % 4];
271 }
272 
// PrepareReceiveFrame(): incremental WebSocket frame parser. Consumes `size` bytes from `buffer`,
// appending header and payload bytes to _ws_receive_frame_buffer until a whole frame
// (_ws_header_size + _ws_payload_size bytes) is buffered. The frame payload is then presumably
// appended (unmasked) to _ws_receive_final_buffer, which the close handler below reads. When the
// FIN bit is set, the message is dispatched by opcode in the switch below. Parser state is kept in
// members, so a call that runs out of input mid-frame returns and parsing continues on the next call.
// NOTE(review): this listing omits lines 333-334, 351-352, 369-370, 376, 403, 406, 420, 426 and 449
// (extraction gaps), so this copy is not compilable as shown — confirm against the repository.
273 void WebSocket::PrepareReceiveFrame(const void* buffer, size_t size)
274 {
275  const uint8_t* data = (const uint8_t*)buffer;
276 
277  // Clear received data after WebSocket frame was processed
// NOTE(review): the (uint32_t*) store into the uint8_t[4] mask below violates strict aliasing and
// may be misaligned; std::fill_n over the array would be well-defined.
278  if (_ws_frame_received)
279  {
280  _ws_frame_received = false;
281  _ws_header_size = 0;
282  _ws_payload_size = 0;
283  _ws_receive_frame_buffer.clear();
284  *((uint32_t*)_ws_receive_mask) = 0;
285  }
286  if (_ws_final_received)
287  {
288  _ws_final_received = false;
289  _ws_receive_final_buffer.clear();
290  }
291 
292  while (size > 0)
293  {
294  // Clear received data after WebSocket frame was processed
295  if (_ws_frame_received)
296  {
297  _ws_frame_received = false;
298  _ws_header_size = 0;
299  _ws_payload_size = 0;
300  _ws_receive_frame_buffer.clear();
301  *((uint32_t*)_ws_receive_mask) = 0;
302  }
303  if (_ws_final_received)
304  {
305  _ws_final_received = false;
306  _ws_receive_final_buffer.clear();
307  }
308 
309  // Prepare WebSocket frame opcode and mask flag
// NOTE(review): this loop, and the extended-length and mask loops below, always append a fixed
// number of bytes starting from i = 0. If an earlier call stopped partway through the field, more
// bytes than the field holds are appended. For the mask, _ws_receive_mask[i] is then filled from
// the wrong offset. Consider appending only the bytes still missing.
310  if (_ws_receive_frame_buffer.size() < 2)
311  {
312  for (size_t i = 0; i < 2; ++i, ++data, --size)
313  {
314  if (size == 0)
315  return;
316  _ws_receive_frame_buffer.push_back(*data);
317  }
318  }
319 
// Byte 0: bit 7 = FIN, low nibble = opcode. Byte 1: bit 7 = MASK, bits 0-6 = payload length.
320  uint8_t opcode = _ws_receive_frame_buffer[0] & 0x0F;
321  bool fin = ((_ws_receive_frame_buffer[0] >> 7) & 0x01) != 0;
322  bool mask = ((_ws_receive_frame_buffer[1] >> 7) & 0x01) != 0;
323  size_t payload = _ws_receive_frame_buffer[1] & (~0x80);
324 
325  // Prepare WebSocket opcode
// Continuation frames (opcode 0) keep the opcode of the message they continue.
// NOTE(review): RFC 6455 §5.4 allows control frames between the fragments of a message. Such a
// frame overwrites _ws_opcode here, so later continuation frames would be dispatched as that
// control opcode — confirm.
326  _ws_opcode = (opcode != 0) ? opcode : _ws_opcode;
327 
328  // Prepare WebSocket frame size
// Header size = 2 base bytes + 0/2/8 extended-length bytes + 4 mask bytes when MASK is set
329  if (payload <= 125)
330  {
331  _ws_header_size = 2 + (mask ? 4 : 0);
332  _ws_payload_size = payload;
335  }
336  else if (payload == 126)
337  {
338  if (_ws_receive_frame_buffer.size() < 4)
339  {
340  for (size_t i = 0; i < 2; ++i, ++data, --size)
341  {
342  if (size == 0)
343  return;
344  _ws_receive_frame_buffer.push_back(*data);
345  }
346  }
347 
// 16-bit big-endian extended payload length
348  payload = (((size_t)_ws_receive_frame_buffer[2] << 8) | ((size_t)_ws_receive_frame_buffer[3] << 0));
349  _ws_header_size = 4 + (mask ? 4 : 0);
350  _ws_payload_size = payload;
353  }
354  else if (payload == 127)
355  {
356  if (_ws_receive_frame_buffer.size() < 10)
357  {
358  for (size_t i = 0; i < 8; ++i, ++data, --size)
359  {
360  if (size == 0)
361  return;
362  _ws_receive_frame_buffer.push_back(*data);
363  }
364  }
365 
// 64-bit big-endian extended payload length.
// NOTE(review): accepted as-is. RFC 6455 §5.2 requires the most significant bit to be 0, and
// nothing here bounds the size. On targets with a 32-bit size_t, the shifts by 32..56 are
// undefined behaviour.
366  payload = (((size_t)_ws_receive_frame_buffer[2] << 56) | ((size_t)_ws_receive_frame_buffer[3] << 48) | ((size_t)_ws_receive_frame_buffer[4] << 40) | ((size_t)_ws_receive_frame_buffer[5] << 32) | ((size_t)_ws_receive_frame_buffer[6] << 24) | ((size_t)_ws_receive_frame_buffer[7] << 16) | ((size_t)_ws_receive_frame_buffer[8] << 8) | ((size_t)_ws_receive_frame_buffer[9] << 0));
367  _ws_header_size = 10 + (mask ? 4 : 0);
368  _ws_payload_size = payload;
371  }
372 
373  // Prepare WebSocket frame mask
// NOTE(review): the guarding condition at listing line 376 is missing here, so the `{` below reads
// as a plain nested block in this copy.
374  if (mask)
375  {
377  {
378  for (size_t i = 0; i < 4; ++i, ++data, --size)
379  {
380  if (size == 0)
381  return;
382  _ws_receive_frame_buffer.push_back(*data);
383  _ws_receive_mask[i] = *data;
384  }
385  }
386  }
387 
// Bytes still missing for this frame, and how many of them this call can supply
388  size_t total = _ws_header_size + _ws_payload_size;
389  size_t length = std::min(total - _ws_receive_frame_buffer.size(), size);
390 
391  // Prepare WebSocket frame payload
392  _ws_receive_frame_buffer.insert(_ws_receive_frame_buffer.end(), data, data + length);
393  data += length;
394  size -= length;
395 
396  // Process WebSocket frame
// The whole frame (header + payload) is now buffered
397  if (_ws_receive_frame_buffer.size() == total)
398  {
399  // Unmask WebSocket frame content
// NOTE(review): the loop body (listing line 403) and the else statement (line 406) are missing here.
400  if (mask)
401  {
402  for (size_t i = 0; i < _ws_payload_size; ++i)
404  }
405  else
407 
408  _ws_frame_received = true;
409 
410  // Finalize WebSocket frame
411  if (fin)
412  {
413  _ws_final_received = true;
414 
415  switch (_ws_opcode)
416  {
417  case WS_PING:
418  {
419  // Call the WebSocket ping handler
421  break;
422  }
423  case WS_PONG:
424  {
425  // Call the WebSocket pong handler
427  break;
428  }
429  case WS_CLOSE:
430  {
// Close payload: optional 2-byte big-endian status code, then the reason bytes.
// 1000 is reported when no status code is present.
431  size_t sindex = 0;
432  int status = 1000;
433 
434  // Read WebSocket close status
435  if (_ws_receive_final_buffer.size() >= 2)
436  {
437  sindex += 2;
438  status = ((_ws_receive_final_buffer[0] << 8) | (_ws_receive_final_buffer[1] << 0));
439  }
440 
441  // Call the WebSocket close handler
442  onWSClose(_ws_receive_final_buffer.data() + sindex, _ws_receive_final_buffer.size() - sindex, status);
443  break;
444  }
445  case WS_BINARY:
446  case WS_TEXT:
447  {
448  // Call the WebSocket received handler
449-missing: see NOTE at top of function
457 
459 {
460  if (_ws_frame_received)
461  return 0;
462 
463  // Required WebSocket frame opcode and mask flag
464  if (_ws_receive_frame_buffer.size() < 2)
465  return 2 - _ws_receive_frame_buffer.size();
466 
467  bool mask = ((_ws_receive_frame_buffer[1] >> 7) & 0x01) != 0;
468  size_t payload = _ws_receive_frame_buffer[1] & (~0x80);
469 
470  // Required WebSocket frame size
471  if ((payload == 126) && (_ws_receive_frame_buffer.size() < 4))
472  return 4 - _ws_receive_frame_buffer.size();
473  if ((payload == 127) && (_ws_receive_frame_buffer.size() < 10))
474  return 10 - _ws_receive_frame_buffer.size();
475 
476  // Required WebSocket frame mask
477  if ((mask) && (_ws_receive_frame_buffer.size() < _ws_header_size))
479 
480  // Required WebSocket frame payload
482 }
483 
485 {
486  _ws_frame_received = false;
487  _ws_final_received = false;
488  _ws_header_size = 0;
489  _ws_payload_size = 0;
490  _ws_receive_frame_buffer.clear();
491  _ws_receive_final_buffer.clear();
492  *((uint32_t*)_ws_receive_mask) = 0;
493 
494  std::scoped_lock locker(_ws_send_lock);
495 
496  _ws_send_buffer.clear();
497  *((uint32_t*)_ws_send_mask) = 0;
498 }
499 
500 } // namespace WS
501 } // namespace CppServer
size_t headers() const noexcept
Get the HTTP request headers count.
Definition: http_request.h:64
std::string_view method() const noexcept
Get the HTTP request method.
Definition: http_request.h:58
std::tuple< std::string_view, std::string_view > header(size_t i) const noexcept
Get the HTTP request header by index.
HTTPResponse & SetBody(std::string_view body="")
Set the HTTP response body.
HTTPResponse & SetHeader(std::string_view key, std::string_view value)
Set the HTTP response header.
HTTPResponse & Clear()
Clear the HTTP response cache.
size_t headers() const noexcept
Get the HTTP response headers count.
Definition: http_response.h:73
int status() const noexcept
Get the HTTP response status.
Definition: http_response.h:67
HTTPResponse & MakeErrorResponse(std::string_view content="", std::string_view content_type="text/plain; charset=UTF-8")
Make ERROR response.
HTTPResponse & SetBegin(int status, std::string_view protocol="HTTP/1.1")
Set the HTTP response begin with a given status and protocol.
std::tuple< std::string_view, std::string_view > header(size_t i) const noexcept
Get the HTTP response header by index.
size_t _ws_payload_size
Received frame payload size.
Definition: ws.h:180
static const uint8_t WS_CLOSE
Close frame.
Definition: ws.h:39
void PrepareSendFrame(uint8_t opcode, bool mask, const void *buffer, size_t size, int status=0)
Prepare WebSocket send frame.
Definition: ws.cpp:212
static const uint8_t WS_TEXT
Text frame.
Definition: ws.h:35
std::vector< uint8_t > _ws_receive_frame_buffer
Receive frame buffer.
Definition: ws.h:182
virtual void onWSClose(const void *buffer, size_t size, int status=1000)
Handle WebSocket client close notification.
Definition: ws.h:147
bool _ws_frame_received
Received frame flag.
Definition: ws.h:174
static const uint8_t WS_BINARY
Binary frame.
Definition: ws.h:37
std::string_view ws_nonce() const noexcept
Get the WebSocket random nonce.
Definition: ws.h:54
static const uint8_t WS_PING
Ping frame.
Definition: ws.h:41
void InitWSNonce()
Initialize WebSocket random nonce.
Definition: ws.cpp:21
void PrepareReceiveFrame(const void *buffer, size_t size)
Prepare WebSocket receive frame.
Definition: ws.cpp:273
virtual void onWSConnecting(HTTP::HTTPRequest &request)
Handle WebSocket client connecting notification.
Definition: ws.h:107
virtual void onWSReceived(const void *buffer, size_t size)
Handle WebSocket received notification.
Definition: ws.h:139
virtual void onWSPing(const void *buffer, size_t size)
Handle WebSocket ping notification.
Definition: ws.h:153
std::vector< uint8_t > _ws_receive_final_buffer
Receive final buffer.
Definition: ws.h:184
uint8_t _ws_receive_mask[4]
Receive mask.
Definition: ws.h:186
std::array< uint8_t, 16 > _ws_nonce
WebSocket random nonce of 16 bytes.
Definition: ws.h:196
bool PerformClientUpgrade(const HTTP::HTTPResponse &response, const CppCommon::UUID &id)
Perform WebSocket client upgrade.
Definition: ws.cpp:26
std::mutex _ws_send_lock
Send buffer lock.
Definition: ws.h:189
virtual void onWSConnected(const HTTP::HTTPResponse &response)
Handle WebSocket client connected notification.
Definition: ws.h:112
uint8_t _ws_opcode
Received frame opcode.
Definition: ws.h:172
void ClearWSBuffers()
Clear WebSocket send/receive buffers.
Definition: ws.cpp:484
virtual void onWSPong(const void *buffer, size_t size)
Handle WebSocket pong notification.
Definition: ws.h:159
static const uint8_t WS_PONG
Pong frame.
Definition: ws.h:43
virtual void onWSError(const std::string &message)
Handle WebSocket error notification.
Definition: ws.h:165
size_t RequiredReceiveFrameSize()
Required WebSocket receive frame size.
Definition: ws.cpp:458
std::vector< uint8_t > _ws_send_buffer
Send buffer.
Definition: ws.h:191
uint8_t _ws_send_mask[4]
Send mask.
Definition: ws.h:193
virtual void SendResponse(const HTTP::HTTPResponse &response)
Send WebSocket server upgrade response.
Definition: ws.h:202
bool _ws_final_received
Received final flag.
Definition: ws.h:176
size_t _ws_header_size
Received frame header size.
Definition: ws.h:178
bool PerformServerUpgrade(const HTTP::HTTPRequest &request, HTTP::HTTPResponse &response)
Perform WebSocket server upgrade.
Definition: ws.cpp:103
bool _ws_handshaked
Handshaked flag.
Definition: ws.h:169
C++ Server project definitions.
Definition: asio.h:56
WebSocket C++ Library definition.