Malloy
Loading...
Searching...
No Matches
connection.hpp
1#pragma once
2
3#include "../type_traits.hpp"
4#include "../../core/type_traits.hpp"
5#include "../../core/http/request.hpp"
6#include "../../core/http/response.hpp"
7#include "../../core/http/type_traits.hpp"
8
9#include <boost/asio/strand.hpp>
10#include <boost/beast/core.hpp>
11#include <spdlog/logger.h>
12
13#include <future>
14#include <optional>
15
16namespace malloy::client::http
17{
18
24 template<class Derived, malloy::http::concepts::body ReqBody, concepts::response_filter Filter, typename Callback>
26 {
27 public:
28 using resp_t = typename Filter::response_type; // Response type as declared by the filter (Filter::response_type)
29 using callback_t = Callback; // Callback stored by run() and invoked with the parsed response after the body is read
30
31 connection(std::shared_ptr<spdlog::logger> logger, boost::asio::io_context& io_ctx, const std::uint64_t body_limit) :
32 m_logger(std::move(logger)),
33 m_resolver(boost::asio::make_strand(io_ctx))
34 {
35 // Sanity check
36 if (!m_logger)
37 throw std::invalid_argument("no valid logger provided.");
38
39 // Set body limit
40 m_parser.body_limit(body_limit);
41 }
42
43 // Start the asynchronous operation.
// Takes ownership of the request (moved into m_req), the error channel, the response
// callback and the filter, then asynchronously resolves the request's Host header field
// together with the port reported by m_req.port(). Resolution completes in on_resolve();
// the bound handler holds derived().shared_from_this(), keeping this connection alive.
// err_channel: promise that a later handler sets with the outcome (the error of a failed
//              step, or the shutdown result in on_read()).
// cb:          callback invoked with the parsed response once its body has been read.
// filter:      picks and prepares the response body type once the header is known.
44 void
45 run(
47 std::promise<malloy::error_code> err_channel,
48 callback_t&& cb,
49 Filter&& filter
50 )
51 {
52 m_req_filter = std::move(filter);
53 m_req = std::move(req);
54 m_err_channel = std::move(err_channel);
55 m_cb.emplace(std::move(cb));
56
57 // Look up the domain name (host taken from the request's Host field, port from m_req.port())
58 m_resolver.async_resolve(
59 m_req.base()[malloy::http::field::host],
60 std::to_string(m_req.port()),
61 boost::beast::bind_front_handler(
62 &connection::on_resolve,
63 derived().shared_from_this()
64 )
65 );
66 }
67
68 protected:
69 std::shared_ptr<spdlog::logger> m_logger; // Logger for trace/error output; the constructor throws if it is null
70
71 void
72 send_request()
73 {
74 // Send the HTTP request to the remote host
75 boost::beast::http::async_write(
76 derived().stream(),
77 m_req,
78 boost::beast::bind_front_handler(
79 &connection::on_write,
80 derived().shared_from_this()
81 )
82 );
83 }
84
85 private:
86 boost::asio::ip::tcp::resolver m_resolver; // Resolves the request's host; bound to a strand of the constructor's io_context
87 boost::beast::flat_buffer m_buffer; // Read buffer shared by the header read and the body read (Must persist between reads)
88 boost::beast::http::response_parser<boost::beast::http::empty_body> m_parser; // Header-stage parser; moved into a body-specific parser in on_read_header()
90 std::promise<malloy::error_code> m_err_channel; // Outcome channel: error of a failed step, or the shutdown result from on_read()
91 std::optional<callback_t> m_cb; // Response callback; empty until run() emplaces it
92 Filter m_req_filter; // Set by run(); chooses and sets up the response body type from the header
93
94 [[nodiscard]]
95 Derived&
96 derived()
97 {
98 return static_cast<Derived&>(*this);
99 }
100
101 void
102 on_resolve(const boost::beast::error_code& ec, boost::asio::ip::tcp::resolver::results_type results)
103 {
104 m_logger->trace("on_resolve()");
105
106 if (ec) {
107 m_logger->error("on_resolve: {}", ec.message());
108 m_err_channel.set_value(ec);
109 return;
110 }
111
112 // Set a timeout on the operation
113 boost::beast::get_lowest_layer(derived().stream()).expires_after(std::chrono::seconds(30));
114
115 // Make the connection on the IP address we get from a lookup
116 boost::beast::get_lowest_layer(derived().stream()).async_connect(
117 results,
118 boost::beast::bind_front_handler(
119 &connection::on_connect,
120 derived().shared_from_this()
121 )
122 );
123 }
124
125 void
126 on_connect(const boost::beast::error_code& ec, boost::asio::ip::tcp::resolver::results_type::endpoint_type)
127 {
128 m_logger->trace("on_connect()");
129
130 if (ec) {
131 m_logger->error("on_connect: {}", ec.message());
132 m_err_channel.set_value(ec);
133 return;
134 }
135
136 // Set a timeout on the operation
137 boost::beast::get_lowest_layer(derived().stream()).expires_after(std::chrono::seconds(30));
138
139 // Call hook
140 derived().hook_connected();
141 }
142
143 void
144 on_write(const boost::beast::error_code& ec, [[maybe_unused]] const std::size_t bytes_transferred)
145 {
146 if (ec) {
147 m_logger->error("on_write: {}", ec.message());
148 m_err_channel.set_value(ec);
149 return;
150 }
151
152 // Receive the HTTP response
153 boost::beast::http::async_read_header(
154 derived().stream(),
155 m_buffer,
156 m_parser,
157 malloy::bind_front_handler(
158 &connection::on_read_header,
159 derived().shared_from_this()
160 )
161 );
162 }
163
164 void
165 on_read_header(malloy::error_code ec, std::size_t)
166 {
167 if (ec) {
168 m_logger->error("on_read_header: '{}'", ec.message());
169 m_err_channel.set_value(ec);
170 return;
171 }
172
173 // Pick a body and parse it from the stream
174 auto bodies = m_req_filter.body_for(m_parser.get().base());
175 std::visit([this](auto&& body) {
176 using body_t = std::decay_t<decltype(body)>;
177
178 auto parser = std::make_shared<boost::beast::http::response_parser<body_t>>(std::move(m_parser));
179 m_req_filter.setup_body(parser->get().base(), parser->get().body());
180
181 boost::beast::http::async_read(
182 derived().stream(),
183 m_buffer,
184 *parser,
185 [this, parser, me = derived().shared_from_this()](auto ec, auto) {
186 if (ec) {
187 m_logger->error("on_read(): {}", ec.message());
188 m_err_channel.set_value(ec);
189 return;
190 }
191
192 // Notify via callback
193 (*m_cb)(malloy::http::response<body_t>{parser->release()});
194 on_read();
195 });
196 },
197 std::move(bodies)
198 );
199 }
200
201 void
202 on_read()
203 {
204 // Gracefully close the socket
206 boost::beast::get_lowest_layer(derived().stream()).socket().shutdown(boost::asio::ip::tcp::socket::shutdown_both, ec);
207
208 // not_connected happens sometimes so don't bother reporting it.
209 if (ec && ec != boost::beast::errc::not_connected)
210 m_logger->error("shutdown: {}", ec.message());
211
212 m_err_channel.set_value(ec); // Set it even if its not an error, to signify that we are done
213 }
214 };
215
216}
Definition: connection.hpp:26
Definition: request.hpp:19
constexpr std::uint16_t port() const noexcept
Definition: request.hpp:118
Definition: response.hpp:22
boost::beast::error_code error_code
Error code used to signify errors without throwing. Truthy means it holds an error.
Definition: error.hpp:9