Malloy
Loading...
Searching...
No Matches
connection.hpp
1#pragma once
2
3#include "types.hpp"
4#include "../error.hpp"
5#include "../type_traits.hpp"
6#include "../utils.hpp"
7#include "../detail/action_queue.hpp"
8#include "../http/request.hpp"
9#include "../websocket/stream.hpp"
10
11#include <boost/asio/io_context.hpp>
12#include <boost/asio/post.hpp>
13#include <boost/beast/core/error.hpp>
14#include <fmt/format.h>
15#include <spdlog/spdlog.h>
16
17#include <concepts>
18#include <functional>
19#include <memory>
20
22{
23
34 template<bool isClient>
35 class connection :
36 public std::enable_shared_from_this<connection<isClient>>
37 {
38 using ws_executor_t = std::invoke_result_t<decltype(&stream::get_executor), stream*>;
40
41 public:
42 using handler_t = std::function<void(const malloy::http::request<>&, const std::shared_ptr<connection>&)>;
43
47 enum class state
48 {
49 handshaking,
50 active,
51 closing,
52 closed,
53 inactive, // Initial state
54 };
55
59 virtual
60 ~connection() noexcept
61 {
62 m_logger->trace("destructor()");
63 }
64
72 [[nodiscard]]
73 std::shared_ptr<spdlog::logger>
74 logger() const noexcept
75 {
76 return m_logger;
77 }
78
82 void
83 set_binary(const bool enabled)
84 {
85 m_ws.set_binary(enabled);
86 }
87
91 [[nodiscard]]
92 bool
94 {
95 return m_ws.binary();
96 }
97
104 static
105 std::shared_ptr<connection>
106 make(const std::shared_ptr<spdlog::logger> logger, stream&& ws, const std::string& agent_string)
107 {
108 // We have to emulate make_shared here because the ctor is private
109 connection* me = nullptr;
110 try {
111 me = new connection{logger, std::move(ws), agent_string};
112 return std::shared_ptr<connection>{me};
113 }
114 catch (...) {
115 delete me;
116 throw;
117 }
118 }
119
128 template<concepts::accept_handler Callback>
129 void
130 connect(const boost::asio::ip::tcp::resolver::results_type& target, const std::string& resource, Callback&& done)
131 requires(isClient)
132 {
133 m_logger->trace("connect()");
134
135 if (m_state != state::inactive)
136 throw std::logic_error{"connect() called on already active websocket connection"};
137
138 // Set the timeout for the operation
139 m_ws.get_lowest_layer([&, me = this->shared_from_this(), this, done = std::forward<Callback>(done), resource](auto& sock) mutable {
140 sock.expires_after(std::chrono::seconds(30));
141
142 // Make the connection on the IP address we get from a lookup
143 sock.async_connect(
144 target,
145 [this, me, target, done = std::forward<Callback>(done), resource](auto ec, boost::asio::ip::tcp::resolver::results_type::endpoint_type ep) mutable {
146 if (ec) {
147 done(ec);
148 }
149 else {
150 me->on_connect(
151 ec,
152 ep,
153 resource,
154 [this, done = std::forward<Callback>(done)](auto ec) mutable {
155 go_active();
156 std::invoke(std::forward<decltype(done)>(done), ec);
157 }
158 );
159 }
160 });
161 });
162 }
163
175 template<class Body, class Fields, std::invocable<> Callback>
176 void
177 accept(const boost::beast::http::request<Body, Fields>& req, Callback&& done)
178 requires(!isClient)
179 {
180 m_logger->trace("accept()");
181
182 if (m_state != state::inactive)
183 throw std::logic_error{"accept() called on already active websocket connection"};
184
185 // Update state
186 m_state = state::handshaking;
187
188 setup_connection();
189
190 // Accept the websocket handshake
191 m_ws.async_accept(req, [this, me = this->shared_from_this(), done = std::forward<decltype(done)>(done)](malloy::error_code ec) mutable {
192 m_logger->trace("on_accept()");
193
194 // Check for errors
195 if (ec) {
196 m_logger->error("on_accept(): {}", ec.message());
197 return;
198 }
199
200 // We're good to go
201 go_active();
202
203 std::invoke(std::forward<decltype(done)>(done));
204 });
205 }
206
216 void
217 disconnect(boost::beast::websocket::close_reason why = boost::beast::websocket::normal)
218 {
219 m_logger->trace("disconnect()");
220
221 if (m_state == state::closed || m_state == state::closing)
222 return;
223
224 auto build_act = [this, why, me = this->shared_from_this()](const auto& on_done) mutable {
225 // Check we haven't been beaten to it
226 if (m_state == state::closed || m_state == state::closing) {
227 on_done();
228 return;
229 }
230
231 do_disconnect(why, on_done);
232 };
233
234 // We queue in both read and write, and whichever gets there first wins
235 m_write_queue.push(build_act);
236 m_read_queue.push(build_act);
237 }
238
246 void
247 force_disconnect(boost::beast::websocket::close_reason why = boost::beast::websocket::normal)
248 {
249 m_logger->trace("force_disconnect()");
250
251 if (m_state == state::inactive)
252 throw std::logic_error{"force_disconnect() called on inactive websocket connection"};
253
254 else if (m_state == state::closed || m_state == state::closing)
255 return; // Already disconnecting
256
257 do_disconnect(why, []{});
258 }
259
272 void
274 {
275 m_logger->trace("read()");
276
277 m_read_queue.push(
278 [
279 this,
280 me = this->shared_from_this(),
281 buff = &buff, // Capturing reference by value copies the object
282 done = std::forward<decltype(done)>(done)
283 ]
284 (const auto& on_done) mutable
285 {
286 assert(buff != nullptr);
287 m_ws.async_read(*buff, [this, me, on_done, done = std::forward<decltype(done)>(done)](auto ec, auto size) mutable {
288 std::invoke(std::forward<decltype(done)>(done), ec, size);
289 on_done();
290 });
291 }
292 );
293 }
294
305 template<concepts::async_read_handler Callback>
306 void
307 send(const concepts::const_buffer_sequence auto& payload, Callback&& done)
308 {
309 m_logger->trace("send(). payload size: {}", payload.size());
310
311 m_write_queue.push([buff = payload, done = std::forward<Callback>(done), this, me = this->shared_from_this()](const auto& on_done) mutable {
312 m_ws.async_write(buff, [this, me, on_done, done = std::forward<decltype(done)>(done)](auto ec, auto size) mutable {
313 on_write(ec, size);
314 std::invoke(std::forward<Callback>(done), ec, size);
315 on_done();
316 });
317 });
318 }
319
320 private:
321 enum class sending_state
322 {
323 idling,
324 sending
325 };
326
327 enum sending_state m_sending_state = sending_state::idling;
328 std::shared_ptr<spdlog::logger> m_logger;
329 stream m_ws;
330 std::string m_agent_string;
331 act_queue_t m_write_queue;
332 act_queue_t m_read_queue;
333 std::atomic<state> m_state{ state::inactive };
334
335 connection(
336 std::shared_ptr<spdlog::logger> logger, stream&& ws, std::string agent_str) :
337 m_logger(std::move(logger)),
338 m_ws{std::move(ws)},
339 m_agent_string{std::move(agent_str)},
340 m_write_queue{boost::asio::make_strand(m_ws.get_executor())},
341 m_read_queue{boost::asio::make_strand(m_ws.get_executor())}
342 {
343 // Sanity check logger
344 if (!m_logger)
345 throw std::invalid_argument("no valid logger provided.");
346 }
347
348 void
349 go_active()
350 {
351 m_logger->trace("go_active()");
352
353 // Update state
354 m_state = state::active;
355
356 // Start/run action queues
357 m_read_queue.run();
358 m_write_queue.run();
359 }
360
361 void
362 setup_connection()
363 {
364 m_logger->trace("setup_connection()");
365
366 // Set suggested timeout settings for the websocket
367 m_ws.set_option(
368 boost::beast::websocket::stream_base::timeout::suggested(
369 isClient ? boost::beast::role_type::client : boost::beast::role_type::server)
370 );
371
372 // Set agent string/field
373 const auto agent_field = isClient ? malloy::http::field::user_agent : malloy::http::field::server;
374 m_ws.set_option(
375 boost::beast::websocket::stream_base::decorator(
376 [this, agent_field](boost::beast::websocket::request_type& req) {
377 req.set(agent_field, m_agent_string);
378 }
379 )
380 );
381 }
382
383 void
384 do_disconnect(boost::beast::websocket::close_reason why, const std::invocable<> auto& on_done)
385 {
386 m_logger->trace("do_disconnect()");
387
388 // Update state
389 m_state = state::closing;
390
391 m_ws.async_close(why, [me = this->shared_from_this(), this, on_done](auto ec) {
392 if (ec)
393 m_logger->error("could not close websocket: '{}'", ec.message()); // TODO: See #40
394 else
395 on_close();
396
397 on_done();
398 });
399 }
400
401 void
402 on_connect(
403 boost::beast::error_code ec,
404 boost::asio::ip::tcp::resolver::results_type::endpoint_type ep,
405 const std::string& resource,
406 concepts::accept_handler auto&& on_handshake)
407 {
408 m_logger->trace("on_connect()");
409
410 if (ec) {
411 m_logger->error("on_connect(): {}", ec.message());
412 return;
413 }
414
415 m_ws.get_lowest_layer([](auto& s) { s.expires_never(); }); // websocket has its own timeout system that conflicts
416
417 // Update the m_host string. This will provide the value of the
418 // Host HTTP header during the WebSocket handshake.
419 // See https://tools.ietf.org/html/rfc7230#section-5.4
420 const std::string host = fmt::format("{}:{}", ep.address().to_string(), ep.port());
421
422#if MALLOY_FEATURE_TLS
423 if constexpr (isClient) {
424 if (m_ws.is_tls()) {
425 // TODO: Should this be a separate method?
426 m_ws.async_handshake_tls(
427 boost::asio::ssl::stream_base::handshake_type::client,
428 [on_handshake = std::forward<decltype(on_handshake)>(on_handshake), resource, host, me = this->shared_from_this()](auto ec) mutable
429 {
430 if (ec)
431 on_handshake(ec);
432
433 me->on_ready_for_handshake(host, resource, std::forward<decltype(on_handshake)>(on_handshake));
434 }
435 );
436 return;
437 }
438 }
439#endif
440 on_ready_for_handshake(host, resource, std::forward<decltype(on_handshake)>(on_handshake));
441 }
442
443 void
444 on_ready_for_handshake(const std::string& host, const std::string& resource, concepts::accept_handler auto&& on_handshake)
445 {
446 m_logger->trace("on_ready_for_handshake()");
447
448 // Turn off the timeout on the tcp_stream, because
449 // the websocket stream has its own timeout system.
450 m_ws.get_lowest_layer([](auto& s) { s.expires_never(); });
451 setup_connection();
452
453 // Perform the websocket handshake
454 m_ws.async_handshake(
455 host,
456 resource,
457 std::forward<decltype(on_handshake)>(on_handshake)
458 );
459 }
460
461 void
462 on_write(auto ec, auto size)
463 {
464 m_logger->trace("on_write() wrote: '{}' bytes", size);
465
466 if (ec) {
467 m_logger->error("on_write failed for websocket connection: '{}'", ec.message());
468 return;
469 }
470 }
471
472 void
473 on_close()
474 {
475 m_logger->trace("on_close()");
476
477 m_state = state::closed;
478 }
479 };
480
481} // namespace malloy::websocket
void push(act_t act)
Add an action to the queue.
Definition: action_queue.hpp:48
Definition: request.hpp:19
Represents a connection via the WebSocket protocol.
Definition: connection.hpp:37
static std::shared_ptr< connection > make(const std::shared_ptr< spdlog::logger > logger, stream &&ws, const std::string &agent_string)
Construct a new connection object.
Definition: connection.hpp:106
void send(const concepts::const_buffer_sequence auto &payload, Callback &&done)
Send the contents of a buffer to the remote peer.
Definition: connection.hpp:307
state
Definition: connection.hpp:48
void force_disconnect(boost::beast::websocket::close_reason why=boost::beast::websocket::normal)
Same as disconnect, but bypasses all queues and runs immediately.
Definition: connection.hpp:247
void set_binary(const bool enabled)
Definition: connection.hpp:83
void read(concepts::dynamic_buffer auto &buff, concepts::async_read_handler auto &&done)
Read a complete message into a buffer.
Definition: connection.hpp:273
bool binary()
Definition: connection.hpp:93
std::shared_ptr< spdlog::logger > logger() const noexcept
Definition: connection.hpp:74
void accept(const boost::beast::http::request< Body, Fields > &req, Callback &&done)
Accept an incoming connection.
Definition: connection.hpp:177
void connect(const boost::asio::ip::tcp::resolver::results_type &target, const std::string &resource, Callback &&done)
Connect to a remote (websocket) endpoint.
Definition: connection.hpp:130
virtual ~connection() noexcept
Definition: connection.hpp:60
void disconnect(boost::beast::websocket::close_reason why=boost::beast::websocket::normal)
Disconnect/stop/close the connection.
Definition: connection.hpp:217
Websocket stream. May use TLS.
Definition: stream.hpp:50
bool binary() const
Checks whether outgoing messages will be indicated as text or binary.
Definition: stream.hpp:222
void get_lowest_layer(Func &&visitor)
Access get_lowest_layer of wrapped stream type.
Definition: stream.hpp:244
void set_binary(const bool enabled)
Controls whether outgoing messages will be indicated as text or binary.
Definition: stream.hpp:204
auto get_executor()
Get executor of the underlying stream.
Definition: stream.hpp:260
Definition: type_traits.hpp:44
Definition: type_traits.hpp:35
Definition: type_traits.hpp:41
Definition: connection.hpp:22
boost::beast::error_code error_code
Error code used to signify errors without throwing. Truthy means it holds an error.
Definition: error.hpp:9