diff --git a/api/net/botan/tls_server.hpp b/api/net/botan/tls_server.hpp index b187243069..7c7aa283f2 100644 --- a/api/net/botan/tls_server.hpp +++ b/api/net/botan/tls_server.hpp @@ -55,6 +55,18 @@ class Server : public Botan::TLS::Callbacks, public net::Stream m_transport->on_read(bs, {this, &Server::tls_read}); this->m_on_read = cb; } + void on_data(DataCallback cb) override { + // FIXME + throw std::runtime_error("on_data not implemented on botan::server"); + } + size_t next_size() override { + // FIXME + throw std::runtime_error("next_size not implemented on botan::server"); + } + buffer_t read_next() override { + // FIXME + throw std::runtime_error("read_next not implemented on botan::server"); + } void on_write(WriteCallback cb) override { this->m_on_write = cb; } diff --git a/api/net/openssl/tls_stream.hpp b/api/net/openssl/tls_stream.hpp index 6ff53ada1a..ea60df6d46 100644 --- a/api/net/openssl/tls_stream.hpp +++ b/api/net/openssl/tls_stream.hpp @@ -2,18 +2,18 @@ #include #include #include -#include +#include -//#define VERBOSE_OPENSSL +//#define VERBOSE_OPENSSL 0 #ifdef VERBOSE_OPENSSL -#define TLS_PRINT(fmt, ...) printf(fmt, ##__VA_ARGS__) +#define TLS_PRINT(fmt, ...) printf("TLS_Stream");printf(fmt, ##__VA_ARGS__) #else #define TLS_PRINT(fmt, ...) /* fmt */ #endif namespace openssl { - struct TLS_stream : public net::Stream + struct TLS_stream : public net::StreamBuffer { using Stream_ptr = net::Stream_ptr; @@ -25,7 +25,6 @@ namespace openssl void write(const std::string&) override; void write(const void* buf, size_t n) override; void close() override; - void reset_callbacks() override; net::Socket local() const override { return m_transport->local(); @@ -37,24 +36,11 @@ namespace openssl return m_transport->to_string(); } - void on_connect(ConnectCallback cb) override { - m_on_connect = std::move(cb); - } - void on_read(size_t, ReadCallback cb) override { - m_on_read = std::move(cb); - } - void on_close(CloseCallback cb) override { - m_on_close = std::move(cb); - } - void on_write(WriteCallback cb) override { - m_on_write = std::move(cb); - } - bool is_connected() const noexcept override { return handshake_completed() && m_transport->is_connected(); } bool is_writable() const noexcept override { - return is_connected() && m_transport->is_writable(); + return (not write_congested()) && is_connected() && m_transport->is_writable(); } bool is_readable() const noexcept override { return m_transport->is_readable(); @@ -76,7 +62,12 @@ namespace openssl size_t serialize_to(void*) const override; + void handle_read_congestion() override; + void handle_write_congestion() override; private: + void handle_data(); + int decrypt(const void *data,int size); + int send_decrypted(); void tls_read(buffer_t); int tls_perform_stream_write(); int tls_perform_handshake(); @@ -89,271 +80,12 @@ namespace openssl STATUS_FAIL }; status_t status(int n) const noexcept; - Stream_ptr m_transport = nullptr; SSL* m_ssl = nullptr; BIO* m_bio_rd = nullptr; BIO* m_bio_wr = nullptr; bool m_busy = false; bool m_deferred_close = false; - ConnectCallback m_on_connect = nullptr; - ReadCallback m_on_read = nullptr; - WriteCallback m_on_write = nullptr; - CloseCallback m_on_close = nullptr; }; - inline TLS_stream::TLS_stream(SSL_CTX* ctx, Stream_ptr t, bool outgoing) - : m_transport(std::move(t)) - { - ERR_clear_error(); // prevent old errors from mucking things up - this->m_bio_rd = BIO_new(BIO_s_mem()); - this->m_bio_wr = BIO_new(BIO_s_mem()); - assert(ERR_get_error() == 0 && "Initializing BIOs"); - this->m_ssl = SSL_new(ctx); - assert(this->m_ssl != nullptr); - assert(ERR_get_error() == 0 && "Initializing SSL"); - // TLS server-mode - if (outgoing == false) - SSL_set_accept_state(this->m_ssl); - else - SSL_set_connect_state(this->m_ssl); - - SSL_set_bio(this->m_ssl, this->m_bio_rd, this->m_bio_wr); - // always-on callbacks - m_transport->on_read(8192, {this, &TLS_stream::tls_read}); - m_transport->on_close({this, &TLS_stream::close_callback_once}); - - // start TLS handshake process - if (outgoing == true) - { - if (this->tls_perform_handshake() < 0) return; - } - } - inline TLS_stream::TLS_stream(Stream_ptr t, SSL* ssl, BIO* rd, BIO* wr) - : m_transport(std::move(t)), m_ssl(ssl), m_bio_rd(rd), m_bio_wr(wr) - { - // always-on callbacks - m_transport->on_read(8192, {this, &TLS_stream::tls_read}); - m_transport->on_close({this, &TLS_stream::close_callback_once}); - } - inline TLS_stream::~TLS_stream() - { - assert(m_busy == false && "Cannot delete stream while in its call stack"); - SSL_free(this->m_ssl); - } - - inline void TLS_stream::write(buffer_t buffer) - { - if (UNLIKELY(this->is_connected() == false)) { - TLS_PRINT("TLS_stream::write() called on closed stream\n"); - return; - } - - int n = SSL_write(this->m_ssl, buffer->data(), buffer->size()); - auto status = this->status(n); - if (status == STATUS_FAIL) { - this->close(); - return; - } - - do { - n = tls_perform_stream_write(); - } while (n > 0); - } - inline void TLS_stream::write(const std::string& str) - { - write(net::Stream::construct_buffer(str.data(), str.data() + str.size())); - } - inline void TLS_stream::write(const void* data, const size_t len) - { - auto* buf = static_cast (data); - write(net::Stream::construct_buffer(buf, buf + len)); - } - - inline void TLS_stream::tls_read(buffer_t buffer) - { - ERR_clear_error(); - uint8_t* buf = buffer->data(); - int len = buffer->size(); - - while (len > 0) - { - int n = BIO_write(this->m_bio_rd, buf, len); - if (UNLIKELY(n < 0)) { - this->close(); - return; - } - buf += n; - len -= n; - - // if we aren't finished initializing session - if (UNLIKELY(!handshake_completed())) - { - int num = SSL_do_handshake(this->m_ssl); - auto status = this->status(num); - - // OpenSSL wants to write - if (status == STATUS_WANT_IO) - { - tls_perform_stream_write(); - } - else if (status == STATUS_FAIL) - { - if (num < 0) { - TLS_PRINT("TLS_stream::SSL_do_handshake() returned %d\n", num); - #ifdef VERBOSE_OPENSSL - ERR_print_errors_fp(stdout); - #endif - } - this->close(); - return; - } - // nothing more to do if still not finished - if (handshake_completed() == false) return; - // handshake success - if (m_on_connect) m_on_connect(*this); - } - - // read decrypted data - do { - char temp[8192]; - n = SSL_read(this->m_ssl, temp, sizeof(temp)); - if (n > 0) { - auto buf = net::Stream::construct_buffer(temp, temp + n); - if (m_on_read) { - this->m_busy = true; - m_on_read(std::move(buf)); - this->m_busy = false; - } - } - } while (n > 0); - // this goes here? - if (UNLIKELY(this->is_closing() || this->is_closed())) { - TLS_PRINT("TLS_stream::SSL_read closed during read\n"); - return; - } - if (this->m_deferred_close) { - this->close(); return; - } - - auto status = this->status(n); - // did peer request stream renegotiation? - if (status == STATUS_WANT_IO) - { - do { - n = tls_perform_stream_write(); - } while (n > 0); - } - else if (status == STATUS_FAIL) - { - this->close(); - return; - } - // check deferred closing - if (this->m_deferred_close) { - this->close(); return; - } - - } // while it < end - } // tls_read() - - inline int TLS_stream::tls_perform_stream_write() - { - ERR_clear_error(); - int pending = BIO_ctrl_pending(this->m_bio_wr); - //printf("pending: %d\n", pending); - if (pending > 0) - { - auto buffer = net::Stream::construct_buffer(pending); - int n = BIO_read(this->m_bio_wr, buffer->data(), buffer->size()); - assert(n == pending); - m_transport->write(buffer); - if (m_on_write) { - this->m_busy = true; - m_on_write(n); - this->m_busy = false; - } - return n; - } - else { - BIO_read(this->m_bio_wr, nullptr, 0); - } - if (!BIO_should_retry(this->m_bio_wr)) - { - this->close(); - return -1; - } - return 0; - } - inline int TLS_stream::tls_perform_handshake() - { - ERR_clear_error(); // prevent old errors from mucking things up - // will return -1:SSL_ERROR_WANT_WRITE - int ret = SSL_do_handshake(this->m_ssl); - int n = this->status(ret); - ERR_print_errors_fp(stderr); - if (n == STATUS_WANT_IO) - { - do { - n = tls_perform_stream_write(); - if (n < 0) { - TLS_PRINT("TLS_stream::tls_perform_handshake() stream write failed\n"); - } - } while (n > 0); - return n; - } - else { - TLS_PRINT("TLS_stream::tls_perform_handshake() returned %d\n", ret); - this->close(); - return -1; - } - } - - inline void TLS_stream::close() - { - //ERR_clear_error(); - if (this->m_busy) { - this->m_deferred_close = true; return; - } - CloseCallback func = std::move(this->m_on_close); - this->reset_callbacks(); - if (m_transport->is_connected()) - m_transport->close(); - if (func) func(); - } - inline void TLS_stream::close_callback_once() - { - if (this->m_busy) { - this->m_deferred_close = true; return; - } - CloseCallback func = std::move(this->m_on_close); - this->reset_callbacks(); - if (func) func(); - } - inline void TLS_stream::reset_callbacks() - { - this->m_on_close = nullptr; - this->m_on_connect = nullptr; - this->m_on_read = nullptr; - this->m_on_write = nullptr; - } - - inline bool TLS_stream::handshake_completed() const noexcept - { - return SSL_is_init_finished(this->m_ssl); - } - inline TLS_stream::status_t TLS_stream::status(int n) const noexcept - { - int error = SSL_get_error(this->m_ssl, n); - switch (error) - { - case SSL_ERROR_NONE: - return STATUS_OK; - case SSL_ERROR_WANT_WRITE: - case SSL_ERROR_WANT_READ: - return STATUS_WANT_IO; - default: - return STATUS_FAIL; - } - } } // openssl diff --git a/api/net/stream.hpp b/api/net/stream.hpp index 9eec91b4c5..66283975c3 100644 --- a/api/net/stream.hpp +++ b/api/net/stream.hpp @@ -56,13 +56,32 @@ namespace net { /** Called with a shared buffer and the length of the data when received. */ using ReadCallback = delegate; /** - * @brief Event when data is received. + * @brief Event when data is received. Pushes data to the callback. * * @param[in] n The size of the receive buffer * @param[in] cb The read callback */ virtual void on_read(size_t n, ReadCallback cb) = 0; + using DataCallback = delegate; + /** + * @brief Event when data is received. + * Does not push data, just signals its presence. + * + * @param[in] cb The callback + */ + virtual void on_data(DataCallback cb) = 0; + + /** + * @return The size of the next available chunk of data if any. + */ + virtual size_t next_size() = 0; + + /** + * @return The next available chunk of data if any. + */ + virtual buffer_t read_next() = 0; + /** Called with nothing ¯\_(ツ)_/¯ */ using CloseCallback = delegate; /** diff --git a/api/net/stream_buffer.hpp b/api/net/stream_buffer.hpp new file mode 100644 index 0000000000..700b5694ab --- /dev/null +++ b/api/net/stream_buffer.hpp @@ -0,0 +1,218 @@ +#ifndef STREAMBUFFERR_HPP +#define STREAMBUFFERR_HPP +#include +#include +#include + +namespace net { + class StreamBuffer : public net::Stream + { + public: + StreamBuffer(Timers::duration_t timeout=std::chrono::microseconds(10)) + : timer({this,&StreamBuffer::congested}),congestion_timeout(timeout) {} + using buffer_t = os::mem::buf_ptr; + using Ready_queue = std::deque; + virtual ~StreamBuffer() { + timer.stop(); + } + + void on_connect(ConnectCallback cb) override { + m_on_connect = std::move(cb); + } + + void on_read(size_t, ReadCallback cb) override { + m_on_read = std::move(cb); + signal_data(); + } + void on_data(DataCallback cb) override { + m_on_data = std::move(cb); + signal_data(); + } + size_t next_size() override; + + buffer_t read_next() override; + + void on_close(CloseCallback cb) override { + m_on_close = std::move(cb); + } + void on_write(WriteCallback cb) override { + m_on_write = std::move(cb); + } + + void signal_data(); + + bool read_congested() const noexcept + { return m_read_congested; } + + bool write_congested() const noexcept + { return m_write_congested; } + + /** + * @brief Construct a shared read vector used by streams + * If allocation failed congestion flag is set + * + * @param construction parameters + * + * @return nullptr on failure, shared_ptr to buffer on success + */ + template + buffer_t construct_read_buffer(Args&&... args) + { + return construct_buffer_with_flag(m_read_congested,std::forward (args)...); + } + + /** + * @brief Construct a shared write vector used by streams + * If allocation failed congestion flag is set + * + * @param construction parameters + * + * @return nullptr on failure, shared_ptr to buffer on success + */ + template + buffer_t construct_write_buffer(Args&&... args) + { + return construct_buffer_with_flag(m_write_congested,std::forward (args)...); + } + + virtual void handle_read_congestion() = 0; + virtual void handle_write_congestion() = 0; + protected: + void closed() + { if (m_on_close) m_on_close(); } + void connected() + { if (m_on_connect) m_on_connect(*this); } + void stream_on_write(int n) + { if (m_on_write) m_on_write(n); } + void enqueue_data(buffer_t data) + { m_send_buffers.push_back(data); } + + void congested(); + + CloseCallback getCloseCallback() { return std::move(this->m_on_close); } + + void reset_callbacks() override + { + //remove queue and reset congestion flags and busy flag ?? + this->m_on_close = nullptr; + this->m_on_connect = nullptr; + this->m_on_read = nullptr; + this->m_on_write = nullptr; + this->m_on_data = nullptr; + } + Timer timer; + + private: + Timer::duration_t congestion_timeout; + bool m_write_congested= false; + bool m_read_congested = false; + + ConnectCallback m_on_connect = nullptr; + ReadCallback m_on_read = nullptr; + DataCallback m_on_data = nullptr; + WriteCallback m_on_write = nullptr; + CloseCallback m_on_close = nullptr; + Ready_queue m_send_buffers; + + /** + * @brief Construct a shared vector and set congestion flag if allocation fails + * + * @param flag the flag to set true or false on allocation failure + * @param args arguments to constructing the buffer + * @return nullptr on failure , shared pointer to buffer on success + */ + + template + buffer_t construct_buffer_with_flag(bool &flag,Args&&... args) + { + static buffer_t buffer; + try + { + buffer = std::make_shared(std::forward (args)...); + flag = false; + } + catch (std::bad_alloc &e) + { + flag = true; + timer.start(congestion_timeout); + return nullptr; + } + return buffer; + } + + + }; // < class StreamBuffer + + inline size_t StreamBuffer::next_size() + { + if (not m_send_buffers.empty()) { + return m_send_buffers.front()->size(); + } + return 0; + } + + inline StreamBuffer::buffer_t StreamBuffer::read_next() + { + + if (not m_send_buffers.empty()) { + auto buf = m_send_buffers.front(); + m_send_buffers.pop_front(); + return buf; + } + return nullptr; + } + + inline void StreamBuffer::congested() + { + if (m_read_congested) + { + handle_read_congestion(); + } + if (m_write_congested) + { + handle_write_congestion(); + } + //if any of the congestion states are still active make sure the timer is running + if(m_read_congested or m_write_congested) + { + if (!timer.is_running()) + { + timer.start(congestion_timeout); + } + } + else + { + if (timer.is_running()) + { + timer.stop(); + } + } + } + + inline void StreamBuffer::signal_data() + { + if (not m_send_buffers.empty()) + { + if (m_on_data != nullptr){ + //on_data_callback(); + m_on_data(); + if (not m_send_buffers.empty()) { + m_read_congested=true; + timer.start(congestion_timeout); + } + } + else if (m_on_read != nullptr) + { + for (auto buf : m_send_buffers) { + // Pop each time, in case callback leads to another call here. + m_send_buffers.pop_front(); + m_on_read(buf); + if (m_on_read == nullptr) { + break; + } //if calling m_on_read reset the callbacks exit + } + } + } + } +} // namespace net +#endif // STREAMBUFFERR_HPP diff --git a/api/net/tcp/connection.hpp b/api/net/tcp/connection.hpp index 39ab8c9258..a66bc33d90 100644 --- a/api/net/tcp/connection.hpp +++ b/api/net/tcp/connection.hpp @@ -93,6 +93,35 @@ class Connection { */ inline Connection& on_read(size_t recv_bufsz, ReadCallback callback); + + using DataCallback = delegate; + /** + * @brief Event when incoming data is received by the connection. + * The callback is called when either 1) PSH is seen, or 2) the buffer is full + * + * The user is expected to fetch data by calling read_next, otherwise the + * event will be triggered again. Unread data will be buffered as long as + * there is capacity in the read queue. + * If an on_read callback is also registered, this event has no effect. + * + * @param[in] callback The callback + * + * @return This connection + */ + inline Connection& on_data(DataCallback callback); + + /** + * @brief Read the next fully acked chunk of received data if any. + * + * @return Pointer to buffer if any, otherwise nullptr. + */ + inline buffer_t read_next(); + + /** + * @return The size of the next fully acked chunk of received data. + */ + inline size_t next_size(); + /** Called with the connection itself and the reason wrapped in a Disconnect struct. */ using DisconnectCallback = delegate; /** @@ -607,6 +636,11 @@ class Connection { void set_recv_wnd_getter(Recv_window_getter func) { recv_wnd_getter = func; } + void release_memory() { + read_request = nullptr; + bufalloc.reset(); + } + private: /** "Parent" for Connection. */ TCP& host_; @@ -709,6 +743,14 @@ class Connection { */ void _on_read(size_t recv_bufsz, ReadCallback cb); + /** + * @brief Set the on_data handler + * + * @param[in] cb The callback + */ + void _on_data(DataCallback cb); + + // Retrieve the associated shared_ptr for a connection, if it exists // Throws out_of_range if it doesn't Connection_ptr retrieve_shared(); @@ -722,7 +764,7 @@ class Connection { * * @param Connection to be cleaned up */ - using CleanupCallback = delegate; + using CleanupCallback = delegate; CleanupCallback _on_cleanup_; inline Connection& _on_cleanup(CleanupCallback cb); @@ -861,6 +903,20 @@ class Connection { */ bool handle_ack(const Packet_view&); + void update_rcv_wnd() { + cb.RCV.WND = (recv_wnd_getter == nullptr) ? + calculate_rcv_wnd() : recv_wnd_getter(); + } + + uint32_t calculate_rcv_wnd() const; + + void send_window_update() { + update_rcv_wnd(); + send_ack(); + } + + void trigger_window_update(os::mem::Pmr_resource& res); + /** * @brief Receive data from an incoming packet containing data. * diff --git a/api/net/tcp/connection.inc b/api/net/tcp/connection.inc index 166d3ef0e2..cb60015695 100644 --- a/api/net/tcp/connection.inc +++ b/api/net/tcp/connection.inc @@ -14,6 +14,11 @@ inline Connection& Connection::on_read(size_t recv_bufsz, ReadCallback cb) return *this; } +inline Connection& Connection::on_data(DataCallback cb) { + _on_data(cb); + return *this; +} + inline Connection& Connection::on_disconnect(DisconnectCallback cb) { on_disconnect_ = cb; return *this; @@ -31,7 +36,7 @@ inline Connection& Connection::on_close(CloseCallback cb) { inline Connection& Connection::set_on_read_callback(ReadCallback cb) { Expects(read_request != nullptr && "on_read hasn't been called before."); - read_request->callback = cb; + read_request->on_read_callback = cb; return *this; } @@ -40,6 +45,21 @@ inline Connection& Connection::_on_cleanup(CleanupCallback cb) { return *this; } +inline buffer_t Connection::read_next() { + static buffer_t empty_buf{}; + if (UNLIKELY(read_request == nullptr)) { + return empty_buf; + } + return read_request->read_next(); +} + +inline size_t Connection::next_size() { + if (UNLIKELY(read_request == nullptr)) { + return 0; + } + return read_request->next_size(); +} + inline void Connection::write(const void* buf, size_t n) { this->write(tcp::construct_buffer((uint8_t*) buf, (uint8_t*) buf + n)); } diff --git a/api/net/tcp/connection_states.hpp b/api/net/tcp/connection_states.hpp index 869ae69de5..ca95f9acb2 100644 --- a/api/net/tcp/connection_states.hpp +++ b/api/net/tcp/connection_states.hpp @@ -352,11 +352,15 @@ class Connection::LastAck : public State { */ virtual Result handle(Connection&, Packet_view& in) override; - inline virtual std::string to_string() const override { + std::string to_string() const override { return "LAST-ACK"; }; - inline virtual bool is_closing() const override { + bool is_closing() const override { + return true; + } + + bool is_closed() const override { return true; } diff --git a/api/net/tcp/listener.hpp b/api/net/tcp/listener.hpp index 6f0af4df74..9a2202b56b 100644 --- a/api/net/tcp/listener.hpp +++ b/api/net/tcp/listener.hpp @@ -103,7 +103,7 @@ class Listener { void segment_arrived(Packet_view&); - void remove(Connection_ptr); + void remove(const Connection*); void connected(Connection_ptr); diff --git a/api/net/tcp/packet_view.hpp b/api/net/tcp/packet_view.hpp index 2f4a707c3c..1fadc3e43a 100644 --- a/api/net/tcp/packet_view.hpp +++ b/api/net/tcp/packet_view.hpp @@ -163,7 +163,10 @@ class Packet_v { const Option::opt_ts* ts_option() const noexcept { return ts_opt; } - inline const Option::opt_ts* parse_ts_option() noexcept; + inline const Option::opt_ts* parse_ts_option() const noexcept; + + void set_ts_option(const Option::opt_ts* opt) + { this->ts_opt = opt; } // Data // @@ -238,7 +241,7 @@ class Packet_v { private: - Option::opt_ts* ts_opt = nullptr; + const Option::opt_ts* ts_opt = nullptr; virtual void set_ip_src(const net::Addr& addr) noexcept = 0; virtual void set_ip_dst(const net::Addr& addr) noexcept = 0; @@ -312,19 +315,34 @@ inline void Packet_v::add_tcp_option_aligned(Args&&... args) { set_length(); // update } -// assumes the packet contains no other options. template -inline const Option::opt_ts* Packet_v::parse_ts_option() noexcept +inline const Option::opt_ts* Packet_v::parse_ts_option() const noexcept { auto* opt = this->tcp_options(); - // TODO: improve by iterate option instead of byte (see Connection::parse_options) - while(((Option*)opt)->kind == Option::NOP and opt < (uint8_t*)this->tcp_data()) - opt++; - - if(((Option*)opt)->kind == Option::TS) - this->ts_opt = (Option::opt_ts*)opt; + while(opt < (uint8_t*)this->tcp_data()) + { + auto* option = (Option*)opt; + switch(option->kind) + { + case Option::NOP: { + opt++; + break; + } + + case Option::TS: { + return reinterpret_cast(option); + } + + case Option::END: { + return nullptr; + } + + default: + opt += option->length; + } + } - return this->ts_opt; + return nullptr; } template diff --git a/api/net/tcp/read_buffer.hpp b/api/net/tcp/read_buffer.hpp index 7b657bee76..a7539ebd57 100644 --- a/api/net/tcp/read_buffer.hpp +++ b/api/net/tcp/read_buffer.hpp @@ -84,6 +84,14 @@ class Read_buffer { buffer_t buffer() { return buf; } + /** + * @brief Check if internal buffer has unhandled data + * + * @return True if the internal buffer is unique with data and doesnt contain hole + */ + bool has_unhandled_data() + { return (buf.unique() && (size() > 0) && (missing() == 0)); } + /** * @brief Exposes the internal buffer (read only) * diff --git a/api/net/tcp/read_request.hpp b/api/net/tcp/read_request.hpp index efe9e93e02..203dc329e3 100644 --- a/api/net/tcp/read_request.hpp +++ b/api/net/tcp/read_request.hpp @@ -30,12 +30,15 @@ class Read_request { public: using Buffer_ptr = std::unique_ptr; using Buffer_queue = std::deque; + using Ready_queue = std::deque; using ReadCallback = delegate; + using DataCallback = delegate; using Alloc = os::mem::buffer::allocator_type; static constexpr size_t buffer_limit = 2; - ReadCallback callback; + ReadCallback on_read_callback = nullptr; + DataCallback on_data_callback = nullptr; - Read_request(seq_t start, size_t min, size_t max, ReadCallback cb, Alloc&& alloc = Alloc()); + Read_request(seq_t start, size_t min, size_t max, Alloc&& alloc = Alloc()); size_t insert(seq_t seq, const uint8_t* data, size_t n, bool psh = false); @@ -47,6 +50,9 @@ class Read_request { void reset(const seq_t seq); + size_t next_size(); + buffer_t read_next(); + const Read_buffer& front() const { return *buffers.front(); } @@ -57,7 +63,10 @@ class Read_request { { return buffers; } private: + void signal_data(); + Buffer_queue buffers; + Ready_queue complete_buffers; Alloc alloc; Read_buffer* get_buffer(const seq_t seq); diff --git a/api/net/tcp/stream.hpp b/api/net/tcp/stream.hpp index 873cd79a22..908ceb078c 100644 --- a/api/net/tcp/stream.hpp +++ b/api/net/tcp/stream.hpp @@ -50,6 +50,31 @@ namespace net::tcp void on_read(size_t n, ReadCallback cb) override { m_tcp->on_read(n, cb); } + /** + * @brief Event when data is received. + * Does not push data, just signals its presence. + * + * @param[in] cb The callback + */ + void on_data(DataCallback cb) override { + m_tcp->on_data(cb); + }; + + /** + * @return The size of the next available chunk of data if any. + */ + size_t next_size() override { + return m_tcp->next_size(); + }; + + /** + * @return The next available chunk of data if any. + */ + buffer_t read_next() override { + return m_tcp->read_next(); + }; + + /** * @brief Event for when the Stream is being closed. * diff --git a/api/net/tcp/tcp.hpp b/api/net/tcp/tcp.hpp index 37ea2a4747..9e48eaa175 100644 --- a/api/net/tcp/tcp.hpp +++ b/api/net/tcp/tcp.hpp @@ -535,15 +535,6 @@ namespace net { return this->cpu_id; } - /** - * @brief Return a value that's supposed to describe how much - * a connection should announce as it's RCV WND, - * with regards to the whole system. - * - * @return A RCV WND value, maximum 1GB - */ - static uint32_t global_recv_wnd(); - private: IPStack& inet_; Listeners listeners_; @@ -717,8 +708,10 @@ namespace net { * @brief Adds a connection. * * @param[in] A ptr to the Connection + * + * @return True if the connection was added, false if rejected */ - void add_connection(tcp::Connection_ptr); + bool add_connection(tcp::Connection_ptr); /** * @brief Creates a connection. @@ -738,7 +731,7 @@ namespace net { * * @param[in] conn A ptr to a Connection */ - void close_connection(tcp::Connection_ptr conn) + void close_connection(const tcp::Connection* conn) { unbind(conn->local()); connections_.erase(conn->tuple()); diff --git a/api/util/alloc_pmr.hpp b/api/util/alloc_pmr.hpp index 4ce2e9cfca..e676dc2326 100644 --- a/api/util/alloc_pmr.hpp +++ b/api/util/alloc_pmr.hpp @@ -43,7 +43,8 @@ namespace os::mem { class Pmr_pool { public: - static constexpr size_t default_max_resources = 64; + static constexpr size_t default_max_resources = 0xffffff; + static constexpr size_t resource_division_offset = 2; using Resource = Pmr_resource; using Resource_ptr = std::unique_ptr>; @@ -73,6 +74,7 @@ namespace os::mem { class Pmr_resource : public std::pmr::memory_resource { public: using Pool_ptr = detail::Pool_ptr; + using Event = delegate; inline Pmr_resource(Pool_ptr p); inline Pool_ptr pool(); inline void* do_allocate(std::size_t size, std::size_t align) override; @@ -85,16 +87,29 @@ namespace os::mem { inline std::size_t dealloc_count(); inline bool full(); inline bool empty(); + + /** Fires when the resource has been full and is not full anymore **/ + void on_non_full(Event e){ non_full = e; } + + /** Fires on transition from < N bytes to >= N bytes allocatable **/ + void on_avail(std::size_t N, Event e) { avail_thresh = N; avail = e; } + private: Pool_ptr pool_; std::size_t used = 0; std::size_t allocs = 0; std::size_t deallocs = 0; + std::size_t avail_thresh = 0; + Event non_full{}; + Event avail{}; }; struct Default_pmr : public std::pmr::memory_resource { void* do_allocate(std::size_t size, std::size_t align) override { - return memalign(align, size); + auto* res = memalign(align, size); + if (res == nullptr) + throw std::bad_alloc(); + return res; } void do_deallocate (void* ptr, size_t, size_t) override { diff --git a/api/util/detail/alloc_pmr.hpp b/api/util/detail/alloc_pmr.hpp index bb02ec08c9..6a7e317622 100644 --- a/api/util/detail/alloc_pmr.hpp +++ b/api/util/detail/alloc_pmr.hpp @@ -33,6 +33,7 @@ namespace os::mem::detail { void* do_allocate(size_t size, size_t align) override { if (UNLIKELY(size + allocated_ > cap_total_)) { + //printf("pmr about to throw bad alloc: sz=%zu alloc=%zu cap=%zu\n", size, allocated_, cap_total_); throw std::bad_alloc(); } @@ -46,6 +47,7 @@ namespace os::mem::detail { void* buf = memalign(align, size); if (buf == nullptr) { + //printf("pmr memalign return nullptr, throw bad alloc\n"); throw std::bad_alloc(); } @@ -151,10 +153,10 @@ namespace os::mem::detail { } std::size_t resource_capacity() { - if (cap_suballoc_ == 0) { - if (used_resources_ == 0) - return cap_total_; - return cap_total_ / used_resources_; + if (cap_suballoc_ == 0) + { + auto div = cap_total_ / (used_resources_ + os::mem::Pmr_pool::resource_division_offset); + return std::min(div, allocatable()); } return cap_suballoc_; } @@ -247,7 +249,9 @@ namespace os::mem { // Pmr_resource implementation // Pmr_resource::Pmr_resource(Pool_ptr p) : pool_{p} {} - std::size_t Pmr_resource::capacity() { return pool_->resource_capacity(); } + std::size_t Pmr_resource::capacity() { + return pool_->resource_capacity(); + } std::size_t Pmr_resource::allocatable() { auto cap = capacity(); if (used > cap) @@ -266,24 +270,37 @@ namespace os::mem { void* Pmr_resource::do_allocate(std::size_t size, std::size_t align) { auto cap = capacity(); if (UNLIKELY(size + used > cap)) { - printf(" ERROR: Failed to alloc %zu - currently allocated %zu capacity %zu\n", - this, size, used, cap); throw std::bad_alloc(); } void* buf = pool_->allocate(size, align); - used += size; allocs++; - return buf; } void Pmr_resource::do_deallocate(void* ptr, std::size_t s, std::size_t a) { Expects(s != 0); // POSIX malloc will allow size 0, but return nullptr. + bool trigger_non_full = UNLIKELY(full() and non_full != nullptr); + bool trigger_avail_thresh = UNLIKELY(allocatable() < avail_thresh + and allocatable() + s >= avail_thresh + and avail != nullptr); + pool_->deallocate(ptr,s,a); deallocs++; used -= s; + + if (UNLIKELY(trigger_avail_thresh)) { + Ensures(allocatable() >= avail_thresh); + Ensures(avail != nullptr); + avail(*this); + } + + if (UNLIKELY(trigger_non_full)) { + Ensures(!full()); + Ensures(non_full != nullptr); + non_full(*this); + } } bool Pmr_resource::do_is_equal(const std::pmr::memory_resource& other) const noexcept { diff --git a/examples/transfer/CMakeLists.txt b/examples/transfer/CMakeLists.txt new file mode 100644 index 0000000000..1e3f8b72d4 --- /dev/null +++ b/examples/transfer/CMakeLists.txt @@ -0,0 +1,41 @@ +cmake_minimum_required(VERSION 2.8.9) + +# IncludeOS install location +if (NOT DEFINED ENV{INCLUDEOS_PREFIX}) + set(ENV{INCLUDEOS_PREFIX} /usr/local) +endif() +include($ENV{INCLUDEOS_PREFIX}/includeos/pre.service.cmake) +project (tcp) + +# Human-readable name of your service +set(SERVICE_NAME "TCP Example Service") + +# Name of your service binary +set(BINARY "tcp_example") + +# Source files to be linked with OS library parts to form bootable image +set(SOURCES + service.cpp # ...add more here + ) + +# To add your own include paths: +# set(LOCAL_INCLUDES ".") + +# Adding memdisk (expects my.disk to exist in current dir): +# set(MEMDISK ${CMAKE_SOURCE_DIR}/my.disk) + +# DRIVERS / PLUGINS: + +set(DRIVERS + virtionet # Virtio networking + # ... Others from IncludeOS/src/drivers + ) + +set(PLUGINS + # syslogd # Syslog over UDP + # ...others + ) + + +# include service build script +include($ENV{INCLUDEOS_PREFIX}/includeos/post.service.cmake) diff --git a/examples/transfer/config.json b/examples/transfer/config.json new file mode 100644 index 0000000000..26564e1325 --- /dev/null +++ b/examples/transfer/config.json @@ -0,0 +1,11 @@ +{ + "net" : [ + { + "iface": 0, + "config": "static", + "address": "10.0.0.42", + "netmask": "255.255.255.0", + "gateway": "10.0.0.1" + } + ] +} diff --git a/examples/transfer/linux/CMakeLists.txt b/examples/transfer/linux/CMakeLists.txt new file mode 100644 index 0000000000..5a6fca00dc --- /dev/null +++ b/examples/transfer/linux/CMakeLists.txt @@ -0,0 +1,18 @@ +cmake_minimum_required(VERSION 2.8.9) +if (NOT DEFINED ENV{INCLUDEOS_PREFIX}) + set(ENV{INCLUDEOS_PREFIX} /usr/local) +endif() +project (service C CXX) + +# Human-readable name of your service +set(SERVICE_NAME "TCP Transfer From Linux") + +# Name of your service binary +set(BINARY "tcp_linux") + +# Source files to be linked with OS library parts to form bootable image +set(SOURCES + ../service.cpp + ) + +include($ENV{INCLUDEOS_PREFIX}/includeos/linux.service.cmake) diff --git a/examples/transfer/send_file.sh b/examples/transfer/send_file.sh new file mode 100755 index 0000000000..66fd069b9c --- /dev/null +++ b/examples/transfer/send_file.sh @@ -0,0 +1,2 @@ +#!/bin/bash +dd if=/dev/zero bs=1280 count=1048576 > /dev/tcp/10.0.0.42/81 diff --git a/examples/transfer/server.py b/examples/transfer/server.py new file mode 100755 index 0000000000..8bdc9400aa --- /dev/null +++ b/examples/transfer/server.py @@ -0,0 +1,37 @@ +import socket +import sys + +# Create a TCP/IP socket +sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + +# Bind the socket to the port +server_address = ('10.0.0.1', 1337) +print 'starting up on %s port %s' % server_address +sock.bind(server_address) + +# Listen for incoming connections +sock.listen(5) + +while True: + # Wait for a connection + print 'waiting for a connection' + connection, client_address = sock.accept() + + try: + print 'connection from', client_address + bytes = 0 + + while True: + data = connection.recv(8192) + if data: + bytes += len(data) + #print 'received: %d' % len(data) + connection.sendall(data) + else: + print 'received %d bytes' % bytes + print 'closing', client_address + break + + finally: + # Clean up the connection + connection.close() diff --git a/examples/transfer/service.cpp b/examples/transfer/service.cpp new file mode 100644 index 0000000000..0da15c0011 --- /dev/null +++ b/examples/transfer/service.cpp @@ -0,0 +1,92 @@ +// This file is a part of the IncludeOS unikernel - www.includeos.org +// +// Copyright 2015-2016 Oslo and Akershus University College of Applied Sciences +// and Alfred Bratterud +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include + +/** + * An example to show incoming and outgoing TCP Connections. + * In this example, IncludeOS is listening on port 80. + * + * Data received on port 80 will be redirected to a + * outgoing connection to a (in this case) python server (server.py) + * + * Data received from the python server connection + * will be redirected back to the client. + * + * To try it out, use netcat to connect to this IncludeOS instance. +**/ + +using Connection_ptr = net::tcp::Connection_ptr; +using Disconnect = net::tcp::Connection::Disconnect; + +// Address to our python server: 10.0.2.2:1337 +// @note: This may have to be modified depending on network and server settings. +net::Socket python_server{ {10,0,0,1} , 1337}; + +void Service::start() +{ +#ifdef USERSPACE_LINUX + extern void create_network_device(int N, const char* route, const char* ip); + create_network_device(0, "10.0.0.0/24", "10.0.0.1"); +#endif + auto& inet = net::Super_stack::get(0); + inet.network_config( + { 10, 0, 0, 42 }, // IP + { 255,255,255, 0 }, // Netmask + { 10, 0, 0, 1 }, // Gateway + { 10, 0, 0, 1 }); // DNS + + // Set up a TCP server on port 81 + auto& server = inet.tcp().listen(81); + printf("Server listening: %s \n", server.local().to_string().c_str()); + + // When someone connects to our server + server.on_connect( + [&inet] (Connection_ptr client) { + printf("Connected [Client]: %s\n", client->to_string().c_str()); + // Make an outgoing connection to our python server + auto outgoing = inet.tcp().connect(python_server); + // When outgoing connection to python sever is established + outgoing->on_connect( + [client] (Connection_ptr python) { + if (!python) { + printf("Connection failed!\n"); + return; + } + printf("Connected [Python]: %s\n", python->to_string().c_str()); + + // Setup handlers for when data is received on client and python connection + // When client reads data + client->on_read(1024, [python](auto buf) { + python->write(buf); + }); + + // When client is disconnecting + client->on_disconnect([python](Connection_ptr, Disconnect reason) { + printf("Disconnected [Client]: %s\n", reason.to_string().c_str()); + python->close(); + }); + + // When python is disconnecting + python->on_disconnect([client](Connection_ptr, Disconnect reason) { + printf("Disconnected [Python]: %s\n", reason.to_string().c_str()); + client->close(); + }); + }); // << onConnect (outgoing (python)) + }); // << onConnect (client) +} diff --git a/examples/transfer/vm.json b/examples/transfer/vm.json new file mode 100644 index 0000000000..7d0b112a2f --- /dev/null +++ b/examples/transfer/vm.json @@ -0,0 +1,3 @@ +{ + "mem" : 128 +} diff --git a/lib/LiveUpdate/serialize_tcp.cpp b/lib/LiveUpdate/serialize_tcp.cpp index 3a4b62964c..47ab82692d 100644 --- a/lib/LiveUpdate/serialize_tcp.cpp +++ b/lib/LiveUpdate/serialize_tcp.cpp @@ -175,11 +175,14 @@ void Connection::deserialize_from(void* addr) slumbering_ip4.insert(&this->host_.stack()); } + // Assign new memory resource from TCP + this->bufalloc = host_.mempool_.get_resource(); + /// restore read queue auto* readq = (read_buffer*) &area->vla[writeq_len]; if (readq->capacity) { - read_request = std::make_unique(readq->seq, readq->capacity, host_.max_bufsize(), nullptr, bufalloc.get()); + read_request = std::make_unique(readq->seq, readq->capacity, host_.max_bufsize(), bufalloc.get()); read_request->front().deserialize_from(readq); } diff --git a/lib/microLB/micro_lb/autoconf.cpp b/lib/microLB/micro_lb/autoconf.cpp index 221b108791..7406b44644 100644 --- a/lib/microLB/micro_lb/autoconf.cpp +++ b/lib/microLB/micro_lb/autoconf.cpp @@ -44,9 +44,8 @@ namespace microLB { assert(clients.HasMember("key") && "TLS-enabled microLB must also have key"); // create TLS over TCP load balancer - balancer = new Balancer(netinc, CLIENT_PORT, netout, use_active_check, - clients["certificate"].GetString(), - clients["key"].GetString()); + balancer = new Balancer(netinc, CLIENT_PORT, netout, clients["certificate"].GetString(), + clients["key"].GetString(), use_active_check); } else { // create TCP load balancer diff --git a/lib/microLB/micro_lb/balancer.cpp b/lib/microLB/micro_lb/balancer.cpp index 0a20e081c8..f7c69c38e8 100644 --- a/lib/microLB/micro_lb/balancer.cpp +++ b/lib/microLB/micro_lb/balancer.cpp @@ -1,9 +1,6 @@ #include "balancer.hpp" #include -#define READQ_PER_CLIENT 4096 -#define MAX_READQ_PER_NODE 8192 -#define READQ_FOR_NODES 8192 #define MAX_OUTGOING_ATTEMPTS 100 // checking if nodes are dead or not #define ACTIVE_INITIAL_PERIOD 8s @@ -14,7 +11,7 @@ #define LB_VERBOSE 0 #if LB_VERBOSE -#define LBOUT(fmt, ...) printf(fmt, ##__VA_ARGS__) +#define LBOUT(fmt, ...) printf("MICROLB: "); printf(fmt, ##__VA_ARGS__) #else #define LBOUT(fmt, ...) /** **/ #endif @@ -74,16 +71,21 @@ namespace microLB auto& client = queue.front(); assert(client.conn != nullptr); if (client.conn->is_connected()) { - // NOTE: explicitly want to copy buffers - net::Stream_ptr rval = - nodes.assign(std::move(client.conn), client.readq); - if (rval == nullptr) { - // done with this queue item - queue.pop_front(); - } - else { - // put connection back in queue item - client.conn = std::move(rval); + try { + // NOTE: explicitly want to copy buffers + net::Stream_ptr rval = + nodes.assign(std::move(client.conn)); + if (rval == nullptr) { + // done with this queue item + queue.pop_front(); + } + else { + // put connection back in queue item + client.conn = std::move(rval); + } + } catch (...) { + queue.pop_front(); // we have no choice + throw; } } else { @@ -95,7 +97,7 @@ namespace microLB } void Balancer::handle_connections() { - LBOUT("Handle_connections. %i waiting \n", queue.size()); + LBOUT("Handle_connections. %lu waiting \n", queue.size()); // stop any rethrow timer since this is a de-facto retry if (this->throw_retry_timer != Timers::UNUSED_ID) { Timers::stop(this->throw_retry_timer); @@ -144,25 +146,11 @@ namespace microLB // Release connection if it closes before it's assigned to a node. this->conn->on_close([this](){ + printf("Waiting issuing close\n"); if (this->conn != nullptr) this->conn->reset_callbacks(); this->conn = nullptr; }); - - // queue incoming data from clients not yet - // assigned to a node - this->conn->on_read(READQ_PER_CLIENT, - [this] (auto buf) { - // prevent buffer bloat attack - this->total += buf->size(); - if (this->total > MAX_READQ_PER_NODE) { - this->conn->close(); - } - else { - LBOUT("*** Queued %lu bytes\n", buf->size()); - readq.push_back(buf); - } - }); } void Nodes::create_connections(int total) @@ -195,7 +183,7 @@ namespace microLB } } } - net::Stream_ptr Nodes::assign(net::Stream_ptr conn, queue_vector_t& readq) + net::Stream_ptr Nodes::assign(net::Stream_ptr conn) { for (size_t i = 0; i < nodes.size(); i++) { @@ -208,13 +196,10 @@ namespace microLB assert(outgoing->is_connected()); LBOUT("Assigning client to node %d (%s)\n", algo_iterator, outgoing->to_string().c_str()); + //Should we some way hold track of the session object ? auto& session = this->create_session( std::move(conn), std::move(outgoing)); - // flush readq to session.outgoing - for (auto buffer : readq) { - LBOUT("*** Flushing %lu bytes\n", buffer->size()); - session.outgoing->write(buffer); - } + return nullptr; } } @@ -271,16 +256,35 @@ namespace microLB assert(session.is_alive()); return session; } + + void Nodes::destroy_sessions() + { + for (auto& idx: closed_sessions) + { + auto &session=get_session(idx); + + // free session destroying potential unique ptr objects + session.incoming = nullptr; + auto out_tcp = dynamic_cast(session.outgoing->bottom_transport())->tcp(); + session.outgoing = nullptr; + // if we don't have anything to write to the backend, abort it. + if(not out_tcp->sendq_size()) + out_tcp->abort(); + free_sessions.push_back(session.self); + LBOUT("Session %d destroyed (total = %d)\n", session.self, session_cnt); + } + closed_sessions.clear(); + } void Nodes::close_session(int idx) { auto& session = get_session(idx); // remove connections session.incoming->reset_callbacks(); - session.incoming = nullptr; session.outgoing->reset_callbacks(); - session.outgoing = nullptr; - // free session - free_sessions.push_back(session.self); + closed_sessions.push_back(session.self); + + destroy_sessions(); + session_cnt--; LBOUT("Session %d closed (total = %d)\n", session.self, session_cnt); } @@ -319,7 +323,7 @@ namespace microLB this->restart_active_check(); } }); - } catch (std::exception& e) { + } catch (const std::exception&) { // do nothing, because might just be eph.ports used up } } @@ -360,13 +364,24 @@ namespace microLB } void Node::connect() { - auto outgoing = this->stack.tcp().connect(this->addr); + net::tcp::Connection_ptr outgoing; + try + { + outgoing = this->stack.tcp().connect(this->addr); + } + catch([[maybe_unused]]const net::TCP_error& err) + { + LBOUT("Got exception: %s\n", err.what()); + this->restart_active_check(); + return; + } // connecting to node atm. this->connecting++; // retry timer when connect takes too long int fail_timer = Timers::oneshot(CONNECT_TIMEOUT, [this, outgoing] (int) { + printf("Fail timer\n"); // close connection outgoing->abort(); // no longer connecting @@ -408,8 +423,14 @@ namespace microLB auto conn = std::move(pool.back()); assert(conn != nullptr); pool.pop_back(); - if (conn->is_connected()) return conn; - else conn->close(); + if (conn->is_connected()) { + return conn; + } + else + { + printf("CLOSING SINCE conn->connected is false\n"); + conn->close(); + } } return nullptr; } @@ -420,50 +441,37 @@ namespace microLB : parent(n), self(idx), incoming(std::move(inc)), outgoing(std::move(out)) { - incoming->on_read(READQ_PER_CLIENT, - [this] (auto buf) { - assert(this->is_alive()); - this->outgoing->write(buf); - }); + incoming->on_data({this, &Session::flush_incoming}); incoming->on_close( [&nodes = n, idx] () { nodes.close_session(idx); }); - outgoing->on_read(READQ_FOR_NODES, - [this] (auto buf) { - assert(this->is_alive()); - this->incoming->write(buf); - }); + + outgoing->on_data({this, &Session::flush_outgoing}); outgoing->on_close( [&nodes = n, idx] () { nodes.close_session(idx); }); - - // get the actual TCP connections - auto conn_in = dynamic_cast(incoming->bottom_transport())->tcp(); - assert(conn_in != nullptr); - auto conn_out = dynamic_cast(outgoing->bottom_transport())->tcp(); - assert(conn_out != nullptr); - - static const uint32_t sendq_max = 0x400000; - // set recv window handlers - conn_in->set_recv_wnd_getter( - [conn_out] () -> uint32_t { - auto sendq_size = conn_out->sendq_size(); - - if (sendq_size > sendq_max) - printf("WARNING: Incoming reports sendq size: %u\n", sendq_size); - return sendq_max - sendq_size; - }); - conn_out->set_recv_wnd_getter( - [conn_in] () -> uint32_t { - auto sendq_size = conn_in->sendq_size(); - if (sendq_size > sendq_max) - printf("WARNING: Outgoing reports sendq size: %u\n", sendq_size); - return sendq_max - sendq_size; - }); } bool Session::is_alive() const { return incoming != nullptr; } + + void Session::flush_incoming() + { + assert(this->is_alive()); + while((this->incoming->next_size() > 0) and this->outgoing->is_writable()) + { + this->outgoing->write(this->incoming->read_next()); + } + } + + void Session::flush_outgoing() + { + assert(this->is_alive()); + while((this->outgoing->next_size() > 0) and this->incoming->is_writable()) + { + this->incoming->write(this->outgoing->read_next()); + } + } } diff --git a/lib/microLB/micro_lb/balancer.hpp b/lib/microLB/micro_lb/balancer.hpp index 41a6cb503a..473e76b475 100644 --- a/lib/microLB/micro_lb/balancer.hpp +++ b/lib/microLB/micro_lb/balancer.hpp @@ -1,12 +1,12 @@ #pragma once #include #include +#include namespace microLB { typedef net::Inet netstack_t; typedef net::tcp::Connection_ptr tcp_ptr; - typedef std::vector queue_vector_t; typedef delegate pool_signal_t; struct Waiting { @@ -15,7 +15,6 @@ namespace microLB void serialize(liu::Storage&); net::Stream_ptr conn; - queue_vector_t readq; int total = 0; }; @@ -29,6 +28,9 @@ namespace microLB const int self; net::Stream_ptr incoming; net::Stream_ptr outgoing; + + void flush_incoming(); + void flush_outgoing(); }; struct Node { @@ -37,7 +39,7 @@ namespace microLB auto address() const noexcept { return this->addr; } int connection_attempts() const noexcept { return this->connecting; } int pool_size() const noexcept { return pool.size(); } - bool is_active() const noexcept { return active; }; + bool is_active() const noexcept { return active; } bool active_check() const noexcept { return do_active_check; } void restart_active_check(); @@ -77,9 +79,10 @@ namespace microLB void add_node(Args&&... args); void create_connections(int total); // returns the connection back if the operation fails - net::Stream_ptr assign(net::Stream_ptr, queue_vector_t&); + net::Stream_ptr assign(net::Stream_ptr); Session& create_session(net::Stream_ptr inc, net::Stream_ptr out); void close_session(int); + void destroy_sessions(); Session& get_session(int); void serialize(liu::Storage&); @@ -92,14 +95,16 @@ namespace microLB int conn_iterator = 0; int algo_iterator = 0; const bool do_active_check; + Timer cleanup_timer; std::deque sessions; std::deque free_sessions; + std::deque closed_sessions; }; struct Balancer { - Balancer(netstack_t& in, uint16_t port, netstack_t& out, bool do_ac); - Balancer(netstack_t& in, uint16_t port, netstack_t& out, bool do_ac, - const std::string& cert, const std::string& key); + Balancer(netstack_t& in, uint16_t port, netstack_t& out, bool do_ac = false); + Balancer(netstack_t& in, uint16_t port, netstack_t& out, + const std::string& cert, const std::string& key, bool do_ac = false); static Balancer* from_config(); int wait_queue() const; diff --git a/lib/microLB/micro_lb/openssl.cpp b/lib/microLB/micro_lb/openssl.cpp index 139f3fada0..455f26d417 100644 --- a/lib/microLB/micro_lb/openssl.cpp +++ b/lib/microLB/micro_lb/openssl.cpp @@ -10,9 +10,8 @@ namespace microLB netstack_t& in, uint16_t port, netstack_t& out, - const bool do_ac, const std::string& tls_cert, - const std::string& tls_key) + const std::string& tls_key, const bool do_ac) : nodes(do_ac), netin(in), netout(out), signal({this, &Balancer::handle_queue}) { fs::memdisk().init_fs( diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index a74cc1c9d2..cc1eddb3ba 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -22,6 +22,7 @@ include_directories(${OPENSSL_DIR}/include) if(${ARCH} STREQUAL "x86_64") set(OPENSSL_MODULES "net/openssl/init.cpp" "net/openssl/client.cpp" "net/openssl/server.cpp" + "net/openssl/tls_stream.cpp" "net/https/openssl_server.cpp" "net/http/client.cpp") set(OPENSSL_LIBS openssl_ssl openssl_crypto) endif() diff --git a/src/net/checksum.cpp b/src/net/checksum.cpp index e981f001fc..db7e1035f6 100644 --- a/src/net/checksum.cpp +++ b/src/net/checksum.cpp @@ -16,9 +16,16 @@ // limitations under the License. #include -#include -#include +#include + +#if defined(__AVX2__) + #include +#elif defined(__SSSE3__) + #include +#endif + #include +#include namespace net { @@ -26,14 +33,113 @@ uint16_t checksum(uint32_t tsum, const void* data, size_t length) noexcept { const char* buffer = (const char*) data; int64_t sum = tsum; + if (UNLIKELY(length == 0)) + return 0xffff; + + if (UNLIKELY(buffer == 0)) + return 0xffff; + + + +#if defined(__SSSE3__) + static __m128i swap16a = _mm_setr_epi16(0x0001, 0xffff, 0x0203, 0xffff, + 0x0405, 0xffff, 0x0607, 0xffff); + static __m128i swap16b = _mm_setr_epi16(0x0809, 0xffff, 0x0a0b, 0xffff, + 0x0c0d, 0xffff, 0x0e0f, 0xffff); + size_t count; + __m128i zero = _mm_setzero_si128(); + __m128i suma=zero; + __m128i sumb=zero; + __m128i oldsum; + + //according to godbolt its sligtly better to count index than incrementing pointer + for(count = 0; (count+64) < length; count+=64) + { + __m128i dblock1,dblock2; + dblock1 = _mm_loadu_si128((__m128i *) (&buffer[count + 0])); + dblock2 = _mm_loadu_si128((__m128i *) (&buffer[count + 16])); + + suma = _mm_add_epi32(suma,_mm_shuffle_epi8(dblock1,swap16a)); + sumb = _mm_add_epi32(sumb,_mm_shuffle_epi8(dblock1,swap16b)); + suma = _mm_add_epi32(suma,_mm_shuffle_epi8(dblock2,swap16a)); + sumb = _mm_add_epi32(sumb,_mm_shuffle_epi8(dblock2,swap16b)); + + dblock1 = _mm_loadu_si128((__m128i *) (&buffer[count + 32])); + dblock2 = _mm_loadu_si128((__m128i *) (&buffer[count + 48])); + + suma = _mm_add_epi32(suma,_mm_shuffle_epi8(dblock1,swap16a)); + sumb = _mm_add_epi32(sumb,_mm_shuffle_epi8(dblock1,swap16b)); + + suma = _mm_add_epi32(suma,_mm_shuffle_epi8(dblock2,swap16a)); + sumb = _mm_add_epi32(sumb,_mm_shuffle_epi8(dblock2,swap16b)); + } + //why are we not doing a 32 ? + while ((count+16) <= length) + { + __m128i dblock; + dblock= _mm_loadu_si128((__m128i *) (&buffer[count])); + suma = _mm_add_epi32(suma,_mm_shuffle_epi8(dblock,swap16a)); + sumb = _mm_add_epi32(sumb,_mm_shuffle_epi8(dblock,swap16b)); + count+=16; + } + + /*alignas(16) this can be unaligned as we most likely are accessing it unaligned anyays*/ + alignas(16) static const uint8_t shift_tab[48]={ + 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, + 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, + 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff + }; + /* we could extend this to fast convert 4 words to LE from BE */ + alignas(16) static const uint8_t swap32[16]{ + 0x03,0x02,0x01,0x00,0x80,0x80,0x80,0x80, + 0x80,0x80,0x80,0x80,0x80,0x80,0x80,0x80 + }; + int rest=16-(length-count); + //this over reads but who cares?! + if (LIKELY(rest != 16)) + { + __m128i dblock; + dblock = _mm_loadu_si128((__m128i *)(&buffer[count])); + //shifting the data up then down gives us leading zeroes clearing out any over read data + //the shift up shuffles the data bytewise into place + dblock = _mm_shuffle_epi8(dblock, _mm_loadu_si128((__m128i *)(&shift_tab[16 - rest]))); + dblock = _mm_shuffle_epi8(dblock, _mm_loadu_si128((__m128i *)(&shift_tab[16 + rest]))); + suma = _mm_add_epi32(suma,_mm_shuffle_epi8(dblock,swap16a)); + sumb = _mm_add_epi32(sumb,_mm_shuffle_epi8(dblock,swap16b)); + } + + suma = _mm_add_epi32(suma, sumb); + //add 0 and 1 to 0 and 2 and 3 to 1 + suma = _mm_hadd_epi32(suma, _mm_setzero_si128()); + //add 0 and 1 to 1 .. + suma = _mm_hadd_epi32(suma, _mm_setzero_si128()); + + oldsum = _mm_shuffle_epi8(_mm_cvtsi32_si128(tsum), _mm_loadu_si128((__m128i *)&swap32[0])); + suma = _mm_add_epi32(suma,oldsum); //adds the old csum to this + //fix endianess + + //extract the 32 bit sum from vector + uint32_t vsum; + vsum=(uint32_t) _mm_cvtsi128_si32(suma); + + //printf("Vsum + swapped(tsum) %08x\n",vsum); + //maybe this only works if tsum is byteswapped ? + while (vsum >>16) + { + vsum=(vsum & 0xFFFF)+(vsum>>16); + } + //allways right in this case as its allways little endian x86 + return ~ntohs((uint16_t)(vsum)); +#elif defined(__AVX2__) // VEX-align buffer while (((uintptr_t) buffer & 15) && length >= 4) { sum += *(uint32_t*) buffer; length -= 4; buffer += 4; } - -#ifdef __AVX2__ // run 4 32-bit adds in parallell union vec4 { __m256i mm; diff --git a/src/net/conntrack.cpp b/src/net/conntrack.cpp index 3621f3ef1b..42cf65f343 100644 --- a/src/net/conntrack.cpp +++ b/src/net/conntrack.cpp @@ -352,23 +352,22 @@ int Conntrack::deserialize_from(void* addr) const auto size = *reinterpret_cast(buffer); buffer += sizeof(size_t); - + size_t dupes = 0; for(auto i = size; i > 0; i--) { // create the entry auto entry = std::make_shared(); buffer += entry->deserialize_from(buffer); - entries.emplace(std::piecewise_construct, - std::forward_as_tuple(entry->first, entry->proto), - std::forward_as_tuple(entry)); - - entries.emplace(std::piecewise_construct, - std::forward_as_tuple(entry->second, entry->proto), - std::forward_as_tuple(entry)); + bool insert = false; + insert = entries.insert_or_assign({entry->first, entry->proto}, entry).second; + if(not insert) + dupes++; + insert = entries.insert_or_assign({entry->second, entry->proto}, entry).second; + if(not insert) + dupes++; } - - Ensures(entries.size() - prev_size == size * 2); + Ensures(entries.size() - (prev_size-dupes) == size * 2); return buffer - reinterpret_cast(addr); } diff --git a/src/net/openssl/tls_stream.cpp b/src/net/openssl/tls_stream.cpp new file mode 100644 index 0000000000..16722a15f2 --- /dev/null +++ b/src/net/openssl/tls_stream.cpp @@ -0,0 +1,365 @@ +#include + +using namespace openssl; + +TLS_stream::TLS_stream(SSL_CTX* ctx, Stream_ptr t, bool outgoing) + : m_transport(std::move(t)) +{ + ERR_clear_error(); // prevent old errors from mucking things up + this->m_bio_rd = BIO_new(BIO_s_mem()); + this->m_bio_wr = BIO_new(BIO_s_mem()); + assert(ERR_get_error() == 0 && "Initializing BIOs"); + this->m_ssl = SSL_new(ctx); + assert(this->m_ssl != nullptr); + assert(ERR_get_error() == 0 && "Initializing SSL"); + // TLS server-mode + if (outgoing == false) + SSL_set_accept_state(this->m_ssl); + else + SSL_set_connect_state(this->m_ssl); + + SSL_set_bio(this->m_ssl, this->m_bio_rd, this->m_bio_wr); + + // always-on callbacks + m_transport->on_data({this,&TLS_stream::handle_data}); + m_transport->on_close({this, &TLS_stream::close_callback_once}); + + // start TLS handshake process + if (outgoing == true) + { + if (this->tls_perform_handshake() < 0) return; + } +} +TLS_stream::TLS_stream(Stream_ptr t, SSL* ssl, BIO* rd, BIO* wr) + : m_transport(std::move(t)), m_ssl(ssl), m_bio_rd(rd), m_bio_wr(wr) +{ + // always-on callbacks + m_transport->on_data({this, &TLS_stream::handle_data}); + m_transport->on_close({this, &TLS_stream::close_callback_once}); +} +TLS_stream::~TLS_stream() +{ + assert(m_busy == false && "Cannot delete stream while in its call stack"); + SSL_free(this->m_ssl); +} + +void TLS_stream::write(buffer_t buffer) +{ + + if (UNLIKELY(this->is_connected() == false)) { + TLS_PRINT("::write() called on closed stream\n"); + return; + } + int n = SSL_write(this->m_ssl, buffer->data(), buffer->size()); + auto status = this->status(n); + if (status == STATUS_FAIL) { + TLS_PRINT("::write() Fail status %d\n",n); + this->close(); + return; + } + + do { + n = tls_perform_stream_write(); + } while (n > 0); +} + +void TLS_stream::write(const std::string& str) +{ + //TODO handle failed alloc + write(net::StreamBuffer::construct_write_buffer(str.data(),str.data()+str.size())); +} + +void TLS_stream::write(const void* data, const size_t len) +{ + //TODO handle failed alloc + auto* buf = static_cast (data); + write(net::StreamBuffer::construct_write_buffer(buf, buf + len)); +} + +int TLS_stream::decrypt(const void *indata, int size) +{ + int n = BIO_write(this->m_bio_rd, indata, size); + if (UNLIKELY(n < 0)) { + //TODO can we handle this more gracefully? + TLS_PRINT("BIO_write failed\n"); + this->close(); + return 0; + } + + // if we aren't finished initializing session + if (UNLIKELY(!handshake_completed())) + { + int num = SSL_do_handshake(this->m_ssl); + auto status = this->status(num); + + // OpenSSL wants to write + if (status == STATUS_WANT_IO) + { + tls_perform_stream_write(); + } + else if (status == STATUS_FAIL) + { + if (num < 0) { + TLS_PRINT("TLS_stream::SSL_do_handshake() returned %d\n", num); + #ifdef VERBOSE_OPENSSL + ERR_print_errors_fp(stdout); + #endif + } + this->close(); + return 0; + } + // nothing more to do if still not finished + if (handshake_completed() == false) return 0; + // handshake success + this->m_busy=true; + connected(); + this->m_busy=false; + + if (this->m_deferred_close) { + TLS_PRINT("::read() close on m_deferred_close after tls_perform_stream_write\n"); + this->close(); + return 0; + } + } + return n; +} + +int TLS_stream::send_decrypted() +{ + int n; + // read decrypted data + do { + //TODO "increase the size or constructor based ??") + auto buffer=StreamBuffer::construct_read_buffer(8192); + if (!buffer) return 0; + n = SSL_read(this->m_ssl,buffer->data(),buffer->size()); + if (n > 0) { + buffer->resize(n); + enqueue_data(buffer); + } + } while (n > 0); + return n; +} + +void TLS_stream::handle_read_congestion() +{ + //Ordering could be different + send_decrypted(); //decrypt any incomplete + this->m_busy=true; + signal_data(); //send any pending + this->m_busy=false; + + if (this->m_deferred_close) { + TLS_PRINT("::read() close on m_deferred_close after tls_perform_stream_write\n"); + this->close(); + return; + } +} + +void TLS_stream::handle_write_congestion() +{ + //this should resolve the potential malloc congestion + //might be missing some TLS signalling but without malloc we cant do that either + while(tls_perform_stream_write() > 0); +} +void TLS_stream::handle_data() +{ + while ( m_transport->next_size() > 0) + { + if (UNLIKELY(read_congested())){ + break; + } + tls_read(m_transport->read_next()); + //bail + if (m_transport == nullptr) + { + printf("m_transport \n"); + break; + } + } +} + +void TLS_stream::tls_read(buffer_t buffer) +{ + if (buffer == nullptr ) { + return; + } + ERR_clear_error(); + uint8_t* buf_ptr = buffer->data(); + int len = buffer->size(); + + while (len > 0) + { + if (this->m_deferred_close) { + TLS_PRINT("::read() close on m_deferred_close"); + this->close(); + return; + } + + int decrypted_bytes=decrypt(buf_ptr,len); + if (UNLIKELY(decrypted_bytes==0)) return; + buf_ptr += decrypted_bytes; + len -= decrypted_bytes; + + //enqueues decrypted data + int ret=send_decrypted(); + + // this goes here? + if (UNLIKELY(this->is_closing() || this->is_closed())) { + TLS_PRINT("TLS_stream::SSL_read closed during read\n"); + return; + } + if (this->m_deferred_close) { + TLS_PRINT("::read() close on m_deferred_close"); + this->close(); + return; + } + + auto status = this->status(ret); + // did peer request stream renegotiation? + if (status == STATUS_WANT_IO) + { + TLS_PRINT("::read() STATUS_WANT_IO\n"); + int ret; + do { + ret = tls_perform_stream_write(); + } while (ret > 0); + } + else if (status == STATUS_FAIL) + { + TLS_PRINT("::read() close on STATUS_FAIL after tls_perform_stream_write\n"); + this->close(); + return; + } + + } // while it < end + + //forward data + this->m_busy=true; + signal_data(); + this->m_busy=false; + + // check deferred closing + if (this->m_deferred_close) { + TLS_PRINT("::read() close on m_deferred_close after tls_perform_stream_write\n"); + this->close(); return; + } +} // tls_read() + +int TLS_stream::tls_perform_stream_write() +{ + ERR_clear_error(); + int pending = BIO_ctrl_pending(this->m_bio_wr); + if (pending > 0) + { + TLS_PRINT("::tls_perform_stream_write() pending=%d bytes\n",pending); + auto buffer = net::StreamBuffer::construct_write_buffer(pending); + if (buffer == nullptr) { + return 0; + } + int n = BIO_read(this->m_bio_wr, buffer->data(), buffer->size()); + assert(n == pending); + //What if we cant write.. + if (m_transport->is_writable()) + { + m_transport->write(buffer); + + this->m_busy = true; + stream_on_write(n); + this->m_busy = false; + + if (this->m_deferred_close) { + TLS_PRINT("::read() close on m_deferred_close after tls_perform_stream_write\n"); + this->close(); return 0; + } + } + + if (UNLIKELY((pending = BIO_ctrl_pending(this->m_bio_wr)) > 0)) + { + return pending; + } + return 0; + } + + BIO_read(this->m_bio_wr, nullptr, 0); + + if (!BIO_should_retry(this->m_bio_wr)) + { + TLS_PRINT("::tls_perform_stream_write() close on !BIO_should_retry\n"); + this->close(); + return -1; + } + return 0; +} + +int TLS_stream::tls_perform_handshake() +{ + ERR_clear_error(); // prevent old errors from mucking things up + // will return -1:SSL_ERROR_WANT_WRITE + int ret = SSL_do_handshake(this->m_ssl); + int n = this->status(ret); + ERR_print_errors_fp(stderr); + if (n == STATUS_WANT_IO) + { + do { + n = tls_perform_stream_write(); + if (n < 0) { + TLS_PRINT("TLS_stream::tls_perform_handshake() stream write failed\n"); + } + } while (n > 0); + return n; + } + else { + TLS_PRINT("TLS_stream::tls_perform_handshake() returned %d\n", ret); + this->close(); + return -1; + } +} + +void TLS_stream::close() +{ + TLS_PRINT("TLS_stream::close()\n"); + //ERR_clear_error(); + if (this->m_busy) { + TLS_PRINT("TLS_stream::close() deferred\n"); + this->m_deferred_close = true; return; + } + CloseCallback func = getCloseCallback(); + this->reset_callbacks(); + if (m_transport->is_connected()) + { + m_transport->close(); + m_transport->reset_callbacks(); // ??? + } + if (func) func(); +} +void TLS_stream::close_callback_once() +{ + TLS_PRINT("TLS_stream::close_callback_once() \n"); + if (this->m_busy) { + TLS_PRINT("TLS_stream::close_callback_once() deferred\n"); + this->m_deferred_close = true; return; + } + CloseCallback func = getCloseCallback(); + this->reset_callbacks(); + if (func) func(); +} + +bool TLS_stream::handshake_completed() const noexcept +{ + return SSL_is_init_finished(this->m_ssl); +} +TLS_stream::status_t TLS_stream::status(int n) const noexcept +{ + int error = SSL_get_error(this->m_ssl, n); + switch (error) + { + case SSL_ERROR_NONE: + return STATUS_OK; + case SSL_ERROR_WANT_WRITE: + case SSL_ERROR_WANT_READ: + return STATUS_WANT_IO; + default: + return STATUS_FAIL; + } +} diff --git a/src/net/tcp/connection.cpp b/src/net/tcp/connection.cpp index fa4f3dd918..38735407b4 100644 --- a/src/net/tcp/connection.cpp +++ b/src/net/tcp/connection.cpp @@ -36,7 +36,7 @@ Connection::Connection(TCP& host, Socket local, Socket remote, ConnectCallback c cb{host_.window_size()}, read_request(nullptr), writeq(), - recv_wnd_getter{TCP::global_recv_wnd}, + recv_wnd_getter{nullptr}, on_connect_{std::move(callback)}, on_disconnect_({this, &Connection::default_on_disconnect}), rtx_timer({this, &Connection::rtx_timeout}), @@ -53,8 +53,9 @@ Connection::Connection(TCP& host, Socket local, Socket remote, ConnectCallback c Connection::~Connection() { - //printf(" Deleted %p %s ACTIVE: %u\n", this, + //printf(" Deleted %p %s ACTIVE: %zu\n", this, // to_string().c_str(), host_.active_connections()); + rtx_clear(); } @@ -65,13 +66,16 @@ void Connection::_on_read(size_t recv_bufsz, ReadCallback cb) { Expects(bufalloc != nullptr); read_request.reset( - new Read_request(this->cb.RCV.NXT, host_.min_bufsize(), host_.max_bufsize(), cb, bufalloc.get())); + new Read_request(this->cb.RCV.NXT, host_.min_bufsize(), host_.max_bufsize(), bufalloc.get())); + read_request->on_read_callback = cb; + const size_t avail_thres = host_.max_bufsize() * Read_request::buffer_limit; + bufalloc->on_avail(avail_thres, {this, &Connection::trigger_window_update}); } // read request is already set, only reset if new size. else { //printf("on_read already set\n"); - read_request->callback = cb; + read_request->on_read_callback = cb; // this will flush the current data to the user (if any) read_request->reset(this->cb.RCV.NXT); @@ -82,6 +86,32 @@ void Connection::_on_read(size_t recv_bufsz, ReadCallback cb) } } +void Connection::_on_data(DataCallback cb) { + if(read_request == nullptr) + { + Expects(bufalloc != nullptr); + read_request.reset( + new Read_request(this->cb.RCV.NXT, host_.min_bufsize(), host_.max_bufsize(), bufalloc.get())); + read_request->on_data_callback = cb; + const size_t avail_thres = host_.max_bufsize() * Read_request::buffer_limit; + bufalloc->on_avail(avail_thres, {this, &Connection::trigger_window_update}); + } + // read request is already set, only reset if new size. + else + { + //printf("on_read already set\n"); + read_request->on_data_callback = cb; + + read_request->reset(this->cb.RCV.NXT); + + // due to throwing away buffers (and all data) we also + // need to clear the sack list if anything is stored here. + if(sack_list) + sack_list->clear(); + } +} + + Connection_ptr Connection::retrieve_shared() { return host_.retrieve_shared(this); } @@ -118,8 +148,10 @@ void Connection::reset_callbacks() writeq.on_write(nullptr); on_close_.reset(); recv_wnd_getter.reset(); - if(read_request) - read_request->callback.reset(); + if(read_request) { + read_request->on_read_callback.reset(); + read_request->on_data_callback.reset(); + } } uint16_t Connection::MSDS() const noexcept { @@ -275,7 +307,7 @@ void Connection::close() { void Connection::receive_disconnect() { Expects(read_request and read_request->size()); - if(read_request->callback) { + if(read_request->on_read_callback) { // TODO: consider adding back when SACK is complete //auto& buf = read_request->buffer; //if (buf.size() > 0 && buf.missing() == 0) @@ -283,6 +315,7 @@ void Connection::receive_disconnect() { } } + void Connection::segment_arrived(Packet_view& incoming) { //const uint32_t FMASK = (~(0x0000000F | htons(0x08))); @@ -320,6 +353,7 @@ int Connection::serialize_to(void*) const { return 0; } Packet_view_ptr Connection::create_outgoing_packet() { + update_rcv_wnd(); auto packet = (is_ipv6_) ? host_.create_outgoing_packet6() : host_.create_outgoing_packet(); // Set Source (local == the current connection) @@ -368,7 +402,7 @@ void Connection::transmit(Packet_view_ptr packet) { if(packet->isset(ACK)) last_ack_sent_ = cb.RCV.NXT; - //if(packet->has_tcp_data()) printf(" TX %s - NXT:%u\n", packet->to_string().c_str(), cb.SND.NXT); + //printf(" TX %s\n%s\n", packet->to_string().c_str(), to_string().c_str()); host_.transmit(std::move(packet)); } @@ -401,6 +435,8 @@ bool Connection::handle_ack(const Packet_view& in) if(is_win_update(in, true_win)) { + //if(cb.SND.WND < SMSS()*2) + // printf("Win update: %u => %u\n", cb.SND.WND, true_win); cb.SND.WND = true_win; cb.SND.WL1 = in.seq(); cb.SND.WL2 = in.ack(); @@ -415,6 +451,10 @@ bool Connection::handle_ack(const Packet_view& in) if(cb.SND.TS_OK) { const auto* ts = in.ts_option(); + // reparse to avoid case when stored ts suddenly get lost + if(ts == nullptr) + ts = in.parse_ts_option(); + if(ts != nullptr) // TODO: not sure the packet is valid if TS missing last_acked_ts_ = ts->ecr; } @@ -423,6 +463,8 @@ bool Connection::handle_ack(const Packet_view& in) rtx_ack(in.ack()); + update_rcv_wnd(); + take_rtt_measure(in); // do either congctrl or fastrecov according to New Reno @@ -463,7 +505,7 @@ void Connection::congestion_control(const Packet_view& in) } // < congestion avoidance // try to write - if(can_send() and !in.has_tcp_data()) + if(can_send() and (!in.has_tcp_data() or cb.RCV.WND < in.tcp_data_length())) { debug2(" Can send UW: %u SMSS: %u\n", usable_window(), SMSS()); send_much(); @@ -574,7 +616,7 @@ void Connection::on_dup_ack(const Packet_view& in) // 3 dup acks else if(dup_acks_ == 3) { - debug(" Dup ACK == 3 - %u\n", cb.SND.UNA); + //printf(" Dup ACK == 3 - UNA=%u recover=%u\n", cb.SND.UNA, cb.recover); if(cb.SND.UNA - 1 > cb.recover) goto fast_rtx; @@ -583,9 +625,14 @@ void Connection::on_dup_ack(const Packet_view& in) if(cb.SND.TS_OK) { const auto* ts = in.ts_option(); - if(ts != nullptr and last_acked_ts_ == ts->ecr) + // reparse to avoid case when stored ts suddenly get lost + if(ts == nullptr) + ts = in.parse_ts_option(); + + if(ts != nullptr) { - goto fast_rtx; + if(last_acked_ts_ == ts->ecr) + goto fast_rtx; } } // 4.1. ACK Heuristic @@ -593,13 +640,13 @@ void Connection::on_dup_ack(const Packet_view& in) { goto fast_rtx; } - return; fast_rtx: { cb.recover = cb.SND.NXT; - debug(" Enter Recovery - Flight Size: %u\n", flight_size()); + debug(" Enter Recovery %u - Flight Size: %u\n", + cb.recover, flight_size()); fast_retransmit(); } } @@ -649,6 +696,50 @@ void Connection::rtx_ack(const seq_t ack) { // x-rtx_q.size(), rtx_q.size()); } +void Connection::trigger_window_update(os::mem::Pmr_resource& res) +{ + const auto reserve = (host_.max_bufsize() * Read_request::buffer_limit); + if(res.allocatable() >= reserve and cb.RCV.WND == 0) { + //printf("allocatable=%zu cur_win=%u\n", res.allocatable(), cb.RCV.WND); + send_window_update(); + } +} + +uint32_t Connection::calculate_rcv_wnd() const +{ + // PRECISE REPORTING + if(UNLIKELY(read_request == nullptr)) + return 0xffff; + + const auto& rbuf = read_request->front(); + auto remaining = rbuf.capacity() - rbuf.size(); + + auto buf_avail = bufalloc->allocatable() + remaining; + auto reserve = (host_.max_bufsize() * Read_request::buffer_limit); + auto win = buf_avail > reserve ? buf_avail - reserve : 0; + + return (win < SMSS()) ? 0 : win; // Avoid small silly windows + + // REPORT CHUNKWISE + /* + //auto allocatable = bufalloc->allocatable(); + const auto& rbuf = read_request->front(); + + auto win = cb.RCV.WND; + if (bufalloc->allocatable() < rbuf.capacity()) { + printf("[connection] Allocatable data is less than capacity. Win 0. \n"); + win = 0; + } else { + win = bufalloc->allocatable() - rbuf.capacity(); + } + + return win; + */ + + // REPORT CHUNKWISE FROM ALLOCATOR + //return bufalloc->allocatable(); +} + /* 7. Process the segment text @@ -688,21 +779,20 @@ void Connection::recv_data(const Packet_view& in) // this shouldn't be necessary with well behaved connections. // I also think we shouldn't reach this point due to State::check_seq checking // if we're inside the window. if packet is out of order tho we can change the RCV wnd (i think). - if(UNLIKELY(bufalloc->allocatable() < host_.max_bufsize())) { + /*if(UNLIKELY(bufalloc->allocatable() < host_.max_bufsize())) { drop(in, Drop_reason::RCV_WND_ZERO); return; - } + }*/ size_t length = in.tcp_data_length(); - /* if(UNLIKELY(cb.RCV.WND < length)) { - printf("DROP: Receive window too small - my window is now: %u \n", cb.RCV.WND); drop(in, Drop_reason::RCV_WND_ZERO); + update_rcv_wnd(); + send_ack(); return; } - */ // Keep track if a packet is being sent during the async read callback @@ -740,52 +830,17 @@ void Connection::recv_data(const Packet_view& in) const auto recv = read_request->insert(in.seq(), in.tcp_data(), length, in.isset(PSH)); // this ensures that the data we ACK is actually put in our buffer. Ensures(recv == length); - // adjust the rcv wnd to (maybe) new value - - // LET APPLICATION REPORT - // cb.RCV.WND = recv_wnd_getter(); - - // PRECISE REPORTING - /* - const auto& rbuf = read_request->front(); - auto remaining = rbuf.capacity() - rbuf.size(); - auto win = (bufalloc->allocatable() + remaining) - rbuf.capacity(); - //auto max = read_request->front().capacity(); - //win = (win < max) ? (rbuf.capacity() - rbuf.size()) : win - max; - cb.RCV.WND = win; - */ - - // REPORT CHUNKWISE - /* - //auto allocatable = bufalloc->allocatable(); - const auto& rbuf = read_request->front(); - - auto win = cb.RCV.WND; - if (bufalloc->allocatable() < rbuf.capacity()) { - printf("[connection] Allocatable data is less than capacity. Win 0. \n"); - win = 0; - } else { - win = bufalloc->allocatable() - rbuf.capacity(); - } - - cb.RCV.WND = win; - */ - - - // REPORT CONSTANT - cb.RCV.WND = bufalloc->allocatable(); - //cb.RCV.WND = 64_MiB; - //cb.RCV.WND = std::max(bufalloc->allocatable(), 4_MiB); } } // Packet out of order - else if((in.seq() - cb.RCV.NXT) < cb.RCV.WND) + else if(( (in.seq() + in.tcp_data_length()) - cb.RCV.NXT) < cb.RCV.WND) { // only accept the data if we have a read request if(read_request != nullptr) recv_out_of_order(in); } + // User callback didnt result in transmitting an ACK if(cb.SND.NXT == snd_nxt) ack_data(); @@ -901,6 +956,9 @@ void Connection::take_rtt_measure(const Packet_view& packet) if(cb.SND.TS_OK) { const auto* ts = packet.ts_option(); + // reparse to avoid case when stored ts suddenly get lost + if(ts == nullptr) + ts = packet.parse_ts_option(); if(ts) { rttm.rtt_measurement(RTTM::milliseconds{host_.get_ts_value() - ntohl(ts->ecr)}); @@ -936,14 +994,17 @@ void Connection::retransmit() { syn_rtx_++; } // If not, check if there is data and retransmit - else if(writeq.size()) { + else if(writeq.size()) + { auto& buf = writeq.una(); - debug2(" With data (wq.sz=%u) buf.unacked=%u\n", - writeq.size(), buf.length() - buf.acknowledged); + + // TODO: Finish to send window zero probe, but only on rtx timeout + + //printf(" With data (wq.sz=%zu) buf.size=%zu buf.unacked=%zu SND.WND=%u CWND=%u\n", + // writeq.size(), buf->size(), buf->size() - writeq.acked(), cb.SND.WND, cb.cwnd); fill_packet(*packet, buf->data() + writeq.acked(), buf->size() - writeq.acked()); - packet->set_flag(PSH); + packet->set_flag(PSH); } - rtx_attempt_++; packet->set_seq(cb.SND.UNA); /* @@ -1012,19 +1073,21 @@ void Connection::rtx_clear() { begins (i.e., after the three-way handshake completes). */ void Connection::rtx_timeout() { - debug(" Timed out (RTO %lld ms). FS: %u\n", - rttm.rto_ms().count(), flight_size()); + //printf(" Timed out (RTO %lld ms). FS: %u usable=%u\n", + // rttm.rto_ms().count(), flight_size(), usable_window()); signal_rtx_timeout(); // experimental if(rto_limit_reached()) { - debug(" RTX attempt limit reached, closing.\n"); + debug(" RTX attempt limit reached, closing. rtx=%u syn_rtx=%u\n", + rtx_attempt_, syn_rtx_); abort(); return; } // retransmit SND.UNA - retransmit(); // increases rtx_attempt + retransmit(); + rtx_attempt_++; // "back off" timer rttm.RTO *= 2.0; @@ -1119,13 +1182,20 @@ void Connection::start_dack() void Connection::signal_connect(const bool success) { - // if on read was set before we got a seq number, + // if read request was set before we got a seq number, // update the starting sequence number for the read buffer if(read_request and success) read_request->set_start(cb.RCV.NXT); if(on_connect_) (success) ? on_connect_(retrieve_shared()) : on_connect_(nullptr); + + // If no data event was registered we still want to start buffering here, + // in case the user is not yet ready to subscribe to data. + if (read_request == nullptr and success) { + read_request.reset( + new Read_request(this->cb.RCV.NXT, host_.min_bufsize(), host_.max_bufsize(), bufalloc.get())); + } } void Connection::signal_close() @@ -1149,22 +1219,28 @@ void Connection::clean_up() { if(timewait_dack_timer.is_running()) timewait_dack_timer.stop(); - // necessary to keep the shared_ptr alive during the whole function after _on_cleanup_ is called - // avoids connection being destructed before function is done - auto shared = retrieve_shared(); - // clean up all other copies - // either in TCP::listeners_ (open) or Listener::syn_queue_ (half-open) - if(_on_cleanup_) _on_cleanup_(shared); - + // make sure all our delegates are cleaned up (to avoid circular dependencies) on_connect_.reset(); on_disconnect_.reset(); on_close_.reset(); recv_wnd_getter.reset(); - if(read_request) - read_request->callback.reset(); - _on_cleanup_.reset(); + if(read_request) { + read_request->on_read_callback.reset(); + read_request->on_data_callback.reset(); + } + + + debug2(" Call clean_up delg on %s\n", to_string().c_str()); + // clean up all other copies + // either in TCP::listeners_ (open) or Listener::syn_queue_ (half-open) + if(_on_cleanup_) + _on_cleanup_(this); + + + // if someone put a copy in this delg its their problem.. + //_on_cleanup_.reset(); - debug(" Succesfully cleaned up %s\n", to_string().c_str()); + debug2(" Succesfully cleaned up\n"); } std::string Connection::TCB::to_string() const { @@ -1329,6 +1405,8 @@ bool Connection::uses_SACK() const noexcept void Connection::drop(const Packet_view& packet, [[maybe_unused]]Drop_reason reason) { + /*printf("Drop %s %#.x RCV.WND: %u RCV.NXT %u alloc free: %zu flight size: %u SND.WND: %u \n", + packet.to_string().c_str(), reason, cb.RCV.WND, cb.RCV.NXT, bufalloc->allocatable(), flight_size(), cb.SND.WND);*/ host_.drop(packet); } @@ -1346,12 +1424,12 @@ void Connection::reduce_ssthresh() { fs = (fs >= two_seg) ? fs - two_seg : 0; cb.ssthresh = std::max( (fs / 2), two_seg ); - debug2(" Slow start threshold reduced: %u\n", - cb.ssthresh); + //printf(" Slow start threshold reduced: %u\n", + // cb.ssthresh); } void Connection::fast_retransmit() { - debug(" Fast retransmit initiated.\n"); + //printf(" Fast retransmit initiated.\n"); // reduce sshtresh reduce_ssthresh(); // retransmit segment starting SND.UNA @@ -1366,5 +1444,5 @@ void Connection::finish_fast_recovery() { fast_recovery_ = false; //cb.cwnd = std::min(cb.ssthresh, std::max(flight_size(), (uint32_t)SMSS()) + SMSS()); cb.cwnd = cb.ssthresh; - debug(" Finished Fast Recovery - Cwnd: %u\n", cb.cwnd); + //printf(" Finished Fast Recovery - Cwnd: %u\n", cb.cwnd); } diff --git a/src/net/tcp/connection_states.cpp b/src/net/tcp/connection_states.cpp index 90c7973a4d..acce2b2ba9 100644 --- a/src/net/tcp/connection_states.cpp +++ b/src/net/tcp/connection_states.cpp @@ -94,16 +94,16 @@ using namespace std; bool Connection::State::check_seq(Connection& tcp, Packet_view& in) { auto& tcb = tcp.tcb(); - uint32_t packet_end = static_cast(in.seq() + in.tcp_data_length()-1); // RFC 7323 static constexpr uint8_t HEADER_WITH_TS{sizeof(Header) + 12}; if(tcb.SND.TS_OK and in.tcp_header_length() == HEADER_WITH_TS) { const auto* ts = in.parse_ts_option(); + in.set_ts_option(ts); // PAWS - if(UNLIKELY(ts != nullptr and (ntohl(ts->val) < tcb.TS_recent and !in.isset(RST)))) + if(UNLIKELY(ts != nullptr and (ts->get_val() < tcb.TS_recent and !in.isset(RST)))) { /* If the connection has been idle more than 24 days, @@ -117,7 +117,8 @@ bool Connection::State::check_seq(Connection& tcp, Packet_view& in) debug2(" TCB: %s \n",tcb.to_string().c_str()); // #1 - The packet we expect - if( in.seq() == tcb.RCV.NXT ) { + if( in.seq() == tcb.RCV.NXT ) + { goto acceptable; } /// if SACK isn't permitted there is no point handling out-of-order packets @@ -125,16 +126,7 @@ bool Connection::State::check_seq(Connection& tcp, Packet_view& in) goto unacceptable; // #2 - Packet is ahead of what we expect to receive, but inside our window - if( tcb.RCV.NXT <= in.seq() and in.seq() < tcb.RCV.NXT + tcb.RCV.WND ) { - goto acceptable; - } - // #3 (INVALID) - Packet is outside the right edge of the recv window - else if( packet_end > tcb.RCV.NXT+tcb.RCV.WND ) { - goto unacceptable; - } - // #4 - Packet with payload is what we expect or bigger, but inside our window - else if( tcb.RCV.NXT <= packet_end - and packet_end < tcb.RCV.NXT+tcb.RCV.WND ) { + if( (in.seq() - tcb.RCV.NXT) < tcb.RCV.WND ) { goto acceptable; } /* @@ -149,6 +141,8 @@ bool Connection::State::check_seq(Connection& tcp, Packet_view& in) */ unacceptable: + tcp.update_rcv_wnd(); + if(!in.isset(RST)) tcp.send_ack(); @@ -157,10 +151,13 @@ bool Connection::State::check_seq(Connection& tcp, Packet_view& in) acceptable: const auto* ts = in.ts_option(); + if(tcb.SND.TS_OK) + ts = in.parse_ts_option(); + if(ts != nullptr and - (ntohl(ts->val) >= tcb.TS_recent and in.seq() <= tcp.last_ack_sent_)) + (ts->get_val() >= tcb.TS_recent and in.seq() <= tcp.last_ack_sent_)) { - tcb.TS_recent = ntohl(ts->val); + tcb.TS_recent = ts->get_val(); } debug2(" Acceptable SEQ: %u \n", in.seq()); // is acceptable. @@ -200,8 +197,8 @@ bool Connection::State::check_seq(Connection& tcp, Packet_view& in) void Connection::State::unallowed_syn_reset_connection(Connection& tcp, const Packet_view& in) { assert(in.isset(SYN)); - debug(" Unallowed SYN for STATE: %s, reseting connection.\n", - tcp.state().to_string().c_str()); + debug(" Unallowed SYN for STATE: %s, reseting connection. %s\n", + tcp.state().to_string().c_str(), in.to_string().c_str()); // Not sure if this is the correct way to send a "reset response" auto packet = tcp.outgoing_packet(); packet->set_seq(in.ack()).set_flag(RST); @@ -922,7 +919,8 @@ State::Result Connection::SynReceived::handle(Connection& tcp, Packet_view& in) */ if(tcb.SND.UNA <= in.ack() and in.ack() <= tcb.SND.NXT) { - debug2(" SND.UNA =< SEG.ACK =< SND.NXT, continue in ESTABLISHED. \n"); + debug2(" %s SND.UNA =< SEG.ACK =< SND.NXT, continue in ESTABLISHED.\n", + tcp.to_string().c_str()); tcp.set_state(Connection::Established::instance()); @@ -1061,6 +1059,7 @@ State::Result Connection::FinWait1::handle(Connection& tcp, Packet_view& in) { if(in.ack() == tcp.tcb().SND.NXT) { // TODO: I guess or FIN is ACK'ed..? tcp.set_state(TimeWait::instance()); + tcp.release_memory(); if(tcp.rtx_timer.is_running()) tcp.rtx_stop(); tcp.timewait_start(); @@ -1108,6 +1107,7 @@ State::Result Connection::FinWait2::handle(Connection& tcp, Packet_view& in) { Start the time-wait timer, turn off the other timers. */ tcp.set_state(Connection::TimeWait::instance()); + tcp.release_memory(); if(tcp.rtx_timer.is_running()) tcp.rtx_stop(); tcp.timewait_start(); @@ -1182,6 +1182,7 @@ State::Result Connection::Closing::handle(Connection& tcp, Packet_view& in) { if(in.ack() == tcp.tcb().SND.NXT) { // TODO: I guess or FIN is ACK'ed..? tcp.set_state(TimeWait::instance()); + tcp.release_memory(); tcp.timewait_start(); } diff --git a/src/net/tcp/listener.cpp b/src/net/tcp/listener.cpp index 8b6d65dfb3..8fbbd0bbeb 100644 --- a/src/net/tcp/listener.cpp +++ b/src/net/tcp/listener.cpp @@ -115,12 +115,12 @@ void Listener::segment_arrived(Packet_view& packet) { TCPL_PRINT2(" No receipent\n"); } -void Listener::remove(Connection_ptr conn) { +void Listener::remove(const Connection* conn) { TCPL_PRINT2(" Try remove %s\n", conn->to_string().c_str()); auto it = syn_queue_.begin(); while(it != syn_queue_.end()) { - if((*it) == conn) + if(it->get() == conn) { syn_queue_.erase(it); debug(" %s removed.\n", conn->to_string().c_str()); @@ -132,9 +132,10 @@ void Listener::remove(Connection_ptr conn) { void Listener::connected(Connection_ptr conn) { debug(" %s connected\n", conn->to_string().c_str()); - remove(conn); + remove(conn.get()); Expects(conn->is_connected()); - host_.add_connection(conn); + if (UNLIKELY(! host_.add_connection(conn))) + return; if(on_connect_ != nullptr) on_connect_(conn); diff --git a/src/net/tcp/read_request.cpp b/src/net/tcp/read_request.cpp index 1fba23ba71..66f0c44233 100644 --- a/src/net/tcp/read_request.cpp +++ b/src/net/tcp/read_request.cpp @@ -20,9 +20,8 @@ namespace net { namespace tcp { - Read_request::Read_request(seq_t start, size_t min, size_t max, - ReadCallback cb, Alloc&& alloc) - : callback{cb}, alloc{alloc} + Read_request::Read_request(seq_t start, size_t min, size_t max, Alloc&& alloc) + : alloc{alloc} { buffers.push_back(std::make_unique(start, min, max, alloc)); } @@ -57,7 +56,13 @@ namespace tcp { { const auto rem = buf->capacity() - buf->size(); const auto end_seq = buf->end_seq(); // store end_seq if reseted in callback - if (callback) callback(buf->buffer()); + + if (on_read_callback != nullptr) { + on_read_callback(buf->buffer()); + } else { + // Ready buffer for read_next + complete_buffers.push_back(buf->buffer()); + } // this is the only one, so we can reuse it if(buffers.size() == 1) @@ -99,6 +104,8 @@ namespace tcp { } // < while(n) + signal_data(); + Ensures(not buffers.empty()); return recv; } @@ -185,6 +192,42 @@ namespace tcp { } } + void Read_request::signal_data() { + + if (not complete_buffers.empty()) { + if (on_data_callback != nullptr){ + on_data_callback(); + if (not complete_buffers.empty()) { + // FIXME: Make sure this event gets re-triggered + // For now the user will have to make sure to re-read later if they couldn't + } + } else if (on_read_callback != nullptr) { + for (auto buf : complete_buffers) { + // Pop each time, in case callback leads to another call here. + complete_buffers.pop_front(); + on_read_callback(buf); + } + } + } + } + + size_t Read_request::next_size() { + if (not complete_buffers.empty()) { + return complete_buffers.front()->size(); + } + return 0; + } + + buffer_t Read_request::read_next() { + static const buffer_t empty_buf {}; + if (not complete_buffers.empty()) { + auto buf = complete_buffers.front(); + complete_buffers.pop_front(); + return buf; + } + return empty_buf; + } + void Read_request::reset(const seq_t seq) { Expects(not buffers.empty()); @@ -202,10 +245,13 @@ namespace tcp { // if noone is using the buffer right now, (stupid yes) // AND it contains data without any holes, // return it to the user - if(buf->buffer().unique() and buf->size() > 0 and buf->missing() == 0) + if (buf->has_unhandled_data()) { - callback(buf->buffer()); + complete_buffers.push_back(buf->buffer()); } + + signal_data(); + // reset the first buffer buf->reset(seq); // throw the others away diff --git a/src/net/tcp/tcp.cpp b/src/net/tcp/tcp.cpp index ba97ca5635..a9392761a9 100644 --- a/src/net/tcp/tcp.cpp +++ b/src/net/tcp/tcp.cpp @@ -378,23 +378,6 @@ void TCP::reset_pmtu(Socket dest, IP4::PMTU pmtu) { } } -uint32_t TCP::global_recv_wnd() -{ - using namespace util; - - auto max_use = OS::heap_max() / 4; // TODO: make proportion into variable - auto in_use = OS::heap_usage(); - - if (in_use >= max_use) { - printf("global_recv_wnd: Receive window empty. Heap use: %zu \n", in_use); - return 0; - } - - ssize_t buf_avail = max_use - in_use; - - return std::min(buf_avail, 4_MiB); -} - void TCP::transmit(tcp::Packet_view_ptr packet) { // Generate checksum. @@ -509,19 +492,43 @@ bool TCP::unbind(const Socket& socket) return false; } -void TCP::add_connection(tcp::Connection_ptr conn) { +bool TCP::add_connection(tcp::Connection_ptr conn) +{ + const size_t alloc_thres = max_bufsize() * Read_request::buffer_limit; // Stat increment number of incoming connections (*incoming_connections_)++; debug(" Connection added %s \n", conn->to_string().c_str()); - conn->_on_cleanup({this, &TCP::close_connection}); - conn->bufalloc = mempool_.get_resource(); + auto resource = mempool_.get_resource(); + + // Reject connection if we can't allocate memory + if(UNLIKELY(resource == nullptr or resource->allocatable() < alloc_thres)) + { + conn->_on_cleanup_ = nullptr; + conn->abort(); + return false; + } + + conn->bufalloc = std::move(resource); + + //printf("New inc conn %s allocatable=%zu\n", conn->to_string().c_str(), conn->bufalloc->allocatable()); + Expects(conn->bufalloc != nullptr); - connections_.emplace(conn->tuple(), conn); + conn->_on_cleanup({this, &TCP::close_connection}); + return connections_.emplace(conn->tuple(), conn).second; } Connection_ptr TCP::create_connection(Socket local, Socket remote, ConnectCallback cb) { + const size_t alloc_thres = max_bufsize() * Read_request::buffer_limit; + + auto resource = mempool_.get_resource(); + // Don't create connection if we can't allocate memory + if(UNLIKELY(resource == nullptr or resource->allocatable() < alloc_thres)) + { + throw TCP_error{"Unable to create new connection: Not enough allocatable memory"}; + } + // Stat increment number of outgoing connections (*outgoing_connections_)++; @@ -531,7 +538,10 @@ Connection_ptr TCP::create_connection(Socket local, Socket remote, ConnectCallba ) ).first->second; conn->_on_cleanup({this, &TCP::close_connection}); - conn->bufalloc = mempool_.get_resource(); + conn->bufalloc = std::move(resource); + + //printf("New out conn %s allocatable=%zu\n", conn->to_string().c_str(), conn->bufalloc->allocatable()); + Expects(conn->bufalloc != nullptr); return conn; } diff --git a/src/platform/x86_pc/idt.cpp b/src/platform/x86_pc/idt.cpp index 4b979eb1c7..50904bf724 100644 --- a/src/platform/x86_pc/idt.cpp +++ b/src/platform/x86_pc/idt.cpp @@ -304,7 +304,7 @@ void __page_fault(uintptr_t* regs, uint32_t code) { auto& range = OS::memory_map().at(key); printf("Violated address is in mapped range \"%s\" \n", range.name()); } else { - printf("Violated ddress is outside mapped memory\n"); + printf("Violated address is outside mapped memory\n"); } } diff --git a/src/posix/tcp_fd.cpp b/src/posix/tcp_fd.cpp index a2cf6115bb..ca0d3f2d6d 100644 --- a/src/posix/tcp_fd.cpp +++ b/src/posix/tcp_fd.cpp @@ -311,8 +311,10 @@ ssize_t TCP_FD_Conn::recv(void* dest, size_t len, int) bytes = buffer->size(); }); - // BLOCK HERE - while (!done || !conn->is_readable()) { + // BLOCK HERE: + // 1. if we havent read the data we asked for + // 2. or we aren't readable but not closed (not 100% sure here hehe..) + while (!done || (!conn->is_readable() and !conn->is_closed())) { OS::block(); } // restore diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 75f593b346..2aa96aaa84 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -47,7 +47,7 @@ if (APPLE) string(STRIP ${BREW_LLVM} BREW_LLVM) set(BREW_LIBCXX_INC "-L${BREW_LLVM}/lib -I${BREW_LLVM}/include/c++/v1") message(STATUS "Brew libc++ location: " ${BREW_LIBCXX_INC}) - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${BREW_LIBCXX_INC} -stdlib=libc++ -nostdinc++ -Wno-unused-command-line-argument") + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${BREW_LIBCXX_INC} -stdlib=libc++ -nostdinc++ -lc++experimental -Wno-unused-command-line-argument") else() set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -mmacosx-version-min=10.12") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -mmacosx-version-min=10.12") diff --git a/test/lest_util/os_mock.cpp b/test/lest_util/os_mock.cpp index d92fc11216..91a80883c8 100644 --- a/test/lest_util/os_mock.cpp +++ b/test/lest_util/os_mock.cpp @@ -20,14 +20,25 @@ #include #include #include -void* memalign(size_t alignment, size_t size) { +#include + +void* memalign(size_t align, size_t size) { void* ptr {nullptr}; - int res = posix_memalign(&ptr, alignment, size); - Ensures(res == 0); + + if (align < sizeof(void*)) + align = sizeof(void*); + if (size < sizeof(void*)) + size = sizeof(void*); + + int res = posix_memalign(&ptr, align, size); + if (res == EINVAL) + printf("Error %i: posix_memalign got invalid alignment param %zu \n", res, align); + if (res == ENOMEM) + printf("Error %i: posix_memalign failed, not enough memory %zu \n", res); return ptr; } -void* aligned_alloc(size_t alignment, size_t size) { - return memalign(alignment, size); +void* aligned_alloc(size_t align, size_t size) { + return memalign(align, size); } #endif diff --git a/test/net/integration/microLB/server.js b/test/net/integration/microLB/server.js index b6ea1bd9fb..bd9e5e589f 100644 --- a/test/net/integration/microLB/server.js +++ b/test/net/integration/microLB/server.js @@ -1,10 +1,14 @@ var http = require('http'); +var url = require('url') -var dataString = function() { - var len = 1024*1024 * 50; +var dataString = function(len) { return '#'.repeat(len); } +function randomData(len) { + return Array.from({length:len}, () => Math.floor(Math.random() * 40)); +} + var stringToColour = function(str) { var hash = 0; for (var i = 0; i < str.length; i++) { @@ -18,13 +22,56 @@ var stringToColour = function(str) { return colour; } -//We need a function which handles requests and send response -function handleRequest(request, response){ +function handleDigest(path, request, response) { response.setTimeout(500); var addr = request.connection.localPort; response.end(addr.toString() + dataString()); } +function handleFile(path,request, response) { + response.setTimeout(500); + var addr = request.connection.localPort; + var size = parseInt(path.replace("/",""),10); + + if (size == 0) {  + size=1024*64; + } + response.end(addr.toString() + dataString(size)); +} + +function defaultHandler(path,request,response) { + response.setTimeout(500); + var addr = request.connection.localPort; + response.end(addr.toString() + dataString(1024*1024*50)); +} + +var routes = new Map([ + ['/digest' , handleDigest], + ['/file' , handleFile] + ]); + +function findHandler(path) +{ + for (const [key,value] of routes.entries()) { + if (path.startsWith(key)) + { + return { pattern: key, func: value}; + } + } + return { pattern :'',func : defaultHandler}; +} + +function handleRequest(request, response){ + var parts = url.parse(request.url); + + var route = findHandler(parts.pathname); + if (route.func) + { + var path = parts.pathname.replace(route.pattern,''); + route.func(path,request,response); + } +} + http.createServer(handleRequest).listen(6001, '10.0.0.1'); http.createServer(handleRequest).listen(6002, '10.0.0.1'); http.createServer(handleRequest).listen(6003, '10.0.0.1'); diff --git a/test/net/integration/microLB/service.cpp b/test/net/integration/microLB/service.cpp index 0cc210aad2..77a59137b0 100644 --- a/test/net/integration/microLB/service.cpp +++ b/test/net/integration/microLB/service.cpp @@ -15,20 +15,108 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include +#include #include #include -#include +#include +#include #include +#include + +using namespace util; + +microLB::Balancer* balancer = nullptr; + +void print_nic_stats() { + printf("eth0.sendq_max: %zu, eth0.sendq_now: %zu " + "eth0.stat_rx_total_packets: %zu, eth0.stat_tx_total_packets: %zu, " + "eth0.stat_rx_total_bytes: %zu, eth0.stat_tx_total_bytes: %zu, " + "eth0.sendq_dropped: %zu, eth0.rx_refill_dropped: %zu \n", + Statman::get().get_by_name("eth0.sendq_max").get_uint64(), + Statman::get().get_by_name("eth0.sendq_now").get_uint64(), + Statman::get().get_by_name("eth0.stat_rx_total_packets").get_uint64(), + Statman::get().get_by_name("eth0.stat_tx_total_packets").get_uint64(), + Statman::get().get_by_name("eth0.stat_rx_total_bytes").get_uint64(), + Statman::get().get_by_name("eth0.stat_tx_total_bytes").get_uint64(), + Statman::get().get_by_name("eth0.sendq_dropped").get_uint64(), + Statman::get().get_by_name("eth0.rx_refill_dropped").get_uint64() + ); + + printf("eth1.sendq_max: %zu, eth1.sendq_now: %zu " + "eth1.stat_rx_total_packets: %zu, eth1.stat_tx_total_packets: %zu, " + "eth1.stat_rx_total_bytes: %zu, eth1.stat_tx_total_bytes: %zu, " + "eth1.sendq_dropped: %zu, eth1.rx_refill_dropped: %zu \n", + Statman::get().get_by_name("eth1.sendq_max").get_uint64(), + Statman::get().get_by_name("eth1.sendq_now").get_uint64(), + Statman::get().get_by_name("eth1.stat_rx_total_packets").get_uint64(), + Statman::get().get_by_name("eth1.stat_tx_total_packets").get_uint64(), + Statman::get().get_by_name("eth1.stat_rx_total_bytes").get_uint64(), + Statman::get().get_by_name("eth1.stat_tx_total_bytes").get_uint64(), + Statman::get().get_by_name("eth1.sendq_dropped").get_uint64(), + Statman::get().get_by_name("eth1.rx_refill_dropped").get_uint64() + ); + printf("\n\n"); +} + +void print_mempool_stats() { + auto& inet1 = net::Super_stack::get(0); + auto& inet2 = net::Super_stack::get(1); + printf("\n\nHeap used: %s\n", util::Byte_r(OS::heap_usage()).to_string().c_str()); + auto pool1 = inet1.tcp().mempool(); + auto pool2 = inet2.tcp().mempool(); + + // Hack to get the implementation details (e.g. the detail::pool ptr) for some stats + auto res1 = pool1.get_resource(); + auto res2 = pool2.get_resource(); + + auto pool_ptr1 = res1->pool(); + auto pool_ptr2 = res2->pool(); + + res1.reset(); + res2.reset(); + + printf("\n*** TCP0 ***\n%s\n pool: %zu / %zu allocs: %zu resources: %zu (used: %zu free: %zu)\n\n", + inet1.tcp().to_string().c_str(), pool1.allocated(), pool1.total_capacity(), pool1.alloc_count(), + pool1.resource_count(), pool_ptr1->used_resources(), + pool_ptr1->free_resources()); + printf("*** TCP1 ***\n%s\npool: %zu / %zu allocs: %zu resources: %zu (used: %zu free: %zu)\n", + inet2.tcp().to_string().c_str(), pool2.allocated(), pool2.total_capacity(), pool2.alloc_count(), + pool2.resource_count(), pool_ptr2->used_resources(), + pool_ptr2->free_resources()); +} + +void print_lb_stats() { + FILLINE('-'); + CENTER("LB-Stats"); + auto& nodes = balancer->nodes; + printf("Wait queue: %i nodes: %zu tot_sess: %i open_sess: %i timeout_sess: %i pool_size: %i \n", + balancer->wait_queue(), nodes.size(), nodes.total_sessions(), nodes.open_sessions(), nodes.timed_out_sessions(), nodes.pool_size()); + printf("\n\n"); +} void Service::start() { - static auto* balancer = microLB::Balancer::from_config(); + balancer = microLB::Balancer::from_config(); printf("MicroLB ready for test\n"); - auto& inet = net::Super_stack::get(0); - inet.tcp().set_MSL(std::chrono::seconds(2)); + auto& inet1 = net::Super_stack::get(0); + auto& inet2 = net::Super_stack::get(1); + inet1.tcp().set_MSL(std::chrono::seconds(2)); + + // Increasing TCP buffer size may increase throughput + //inet1.tcp().set_total_bufsize(256_MiB); + //inet2.tcp().set_total_bufsize(256_MiB); Timers::oneshot(std::chrono::seconds(5), [] (int) { printf("TCP MSL ended (4 seconds)\n"); }); + //StackSampler::begin(); + + Timers::periodic(2s, 5s, [](auto) { + //StackSampler::print(10); + print_nic_stats(); + print_mempool_stats(); + print_lb_stats(); + }); } diff --git a/test/net/integration/microLB/test.py b/test/net/integration/microLB/test.py index 24777b3f30..3fe5734756 100755 --- a/test/net/integration/microLB/test.py +++ b/test/net/integration/microLB/test.py @@ -16,8 +16,7 @@ expected_string = "#" * 1024 * 1024 * 50 def validateRequest(addr): - response = requests.get('https://10.0.0.68:443', verify=False) - #print (response.content) + response = requests.get('https://10.0.0.68:443', verify=False, timeout=5) return (response.content) == str(addr) + expected_string # start nodeJS diff --git a/test/net/integration/microLB/vm.json b/test/net/integration/microLB/vm.json index 85f68a14cf..8c0a4549e9 100644 --- a/test/net/integration/microLB/vm.json +++ b/test/net/integration/microLB/vm.json @@ -6,5 +6,5 @@ {"device" : "virtio"}, {"device" : "virtio"} ], - "mem" : 64 + "mem" : 256 } diff --git a/test/net/integration/tcp/service.cpp b/test/net/integration/tcp/service.cpp index 2b7ae102a8..f7f350e59a 100644 --- a/test/net/integration/tcp/service.cpp +++ b/test/net/integration/tcp/service.cpp @@ -24,6 +24,7 @@ using namespace net; using namespace std::chrono; // For timers and MSL +using namespace util; // For KiB/MiB/GiB literals tcp::Connection_ptr client; static Inet& stack() @@ -33,7 +34,7 @@ static Inet& stack() TEST VARIABLES */ tcp::port_t -TEST1{8081}, TEST2{8082}, TEST3{8083}, TEST4{8084}, TEST5{8085}; +TEST0{8080},TEST1{8081}, TEST2{8082}, TEST3{8083}, TEST4{8084}, TEST5{8085}; using HostAddress = std::pair; HostAddress @@ -132,6 +133,8 @@ struct Buffer { std::string str() { return {data, size};} }; +size_t recv = 0; +size_t chunks = 0; void Service::start() { #ifdef USERSPACE_LINUX @@ -166,6 +169,9 @@ void Service::start() // reduce test duration tcp.set_MSL(MSL_TEST); + // Modify total buffers assigned to TCP here + tcp.set_total_bufsize(64_MiB); + /* TEST: Send and receive small string. */ @@ -177,6 +183,24 @@ void Service::start() CHECK(tcp.listening_ports() == 0, "No (0) open ports (listening connections)"); CHECK(tcp.active_connections() == 0, "No (0) active connections"); + // Trigger with e.g.: + // dd if=/dev/zero bs=9000 count=1000000 | nc 10.0.0.44 8080 | grep Received -a + tcp.listen(TEST0).on_connect([](tcp::Connection_ptr conn) { + INFO("Test 0", "Circle of Evil"); + conn->on_read(424242, [conn](tcp::buffer_t buffer) { + recv += buffer->size(); + chunks++; + if (chunks % 100 == 0) { + std::string res = std::string("Received ") + util::Byte_r(recv).to_string() + "\n"; + printf("%s", res.c_str()); + auto new_buf = std::make_shared>(res.begin(), res.end()); + conn->write(new_buf); + } + conn->write(buffer); + }); + }); + + tcp.listen(TEST1).on_connect([](tcp::Connection_ptr conn) { INFO("Test 1", "SMALL string (%u)", small.size()); conn->on_read(small.size(), [conn](tcp::buffer_t buffer) { diff --git a/test/net/unit/tcp_read_request_test.cpp b/test/net/unit/tcp_read_request_test.cpp index 4ff0e4f2a6..a157c024ad 100644 --- a/test/net/unit/tcp_read_request_test.cpp +++ b/test/net/unit/tcp_read_request_test.cpp @@ -34,7 +34,8 @@ CASE("Operating with out of order data") no_reads++; }; - auto req = std::make_unique(seq, BUFSZ, BUFSZ, read_cb); + auto req = std::make_unique(seq, BUFSZ, BUFSZ); + req->on_read_callback = read_cb; no_reads = 0; // Insert hole, first missing diff --git a/test/posix/integration/tcp/test.py b/test/posix/integration/tcp/test.py index 7ce8aab968..0c5b27c547 100755 --- a/test/posix/integration/tcp/test.py +++ b/test/posix/integration/tcp/test.py @@ -50,6 +50,7 @@ def TCP_connect(): sock.connect((HOST, PORT)) MESSAGE = "POSIX is for hipsters" sock.send(MESSAGE) + sock.close() def TCP_recv(trigger_line): server.listen(1) diff --git a/test/stress/test.py b/test/stress/test.py index 4b35dc3144..92cae05a37 100755 --- a/test/stress/test.py +++ b/test/stress/test.py @@ -124,7 +124,7 @@ def httperf(burst_size = BURST_SIZE, burst_interval = BURST_INTERVAL): # Fire a single burst of ARP requests def ARP_burst(burst_size = BURST_SIZE, burst_interval = BURST_INTERVAL): # Note: Arping requires sudo, and we expect the bridge 'bridge43' to be present - command = ["sudo", "arping", "-q","-w", str(100), "-I", "bridge43", "-c", str(burst_size * 10), HOST] + command = ["sudo", "arping", "-q","-W", str(0.0001), "-I", "bridge43", "-c", str(burst_size * 10), HOST] print color.DATA(" ".join(command)) time.sleep(0.5) res = subprocess32.check_call(command, timeout=thread_timeout); diff --git a/test/util/unit/buddy_alloc_test.cpp b/test/util/unit/buddy_alloc_test.cpp index 888d5a80ad..01d45f079d 100644 --- a/test/util/unit/buddy_alloc_test.cpp +++ b/test/util/unit/buddy_alloc_test.cpp @@ -279,7 +279,7 @@ CASE("mem::buddy random chaos with data verification"){ std::vector allocs; for (auto rnd : test::random_1k) { - auto sz = std::max(rnd % alloc.pool_size_ / 1024, alloc.min_size); + auto sz = std::max(rnd % alloc.pool_size_ / 1024, alloc.min_size); EXPECT(sz); if (not alloc.full()) { diff --git a/test/util/unit/pmr_alloc_test.cpp b/test/util/unit/pmr_alloc_test.cpp index 8816e62297..1f537bc0b3 100644 --- a/test/util/unit/pmr_alloc_test.cpp +++ b/test/util/unit/pmr_alloc_test.cpp @@ -14,11 +14,17 @@ // See the License for the specific language governing permissions and // limitations under the License. -//#define DEBUG_UNIT +#define DEBUG_UNIT #include #include #include + +#if __has_include() +#include +#endif + +#include #include CASE("pmr::default_pmr_resource") { @@ -37,7 +43,7 @@ CASE("pmr::Pmr_pool usage") { constexpr auto pool_cap = 40_MiB; // Using default resource capacity, which is pool_cap / allocator count - os::mem::Pmr_pool pool{pool_cap}; + os::mem::Pmr_pool pool{pool_cap, pool_cap}; EXPECT(pool.total_capacity() == pool_cap); @@ -49,7 +55,6 @@ CASE("pmr::Pmr_pool usage") { std::pmr::polymorphic_allocator alloc{res.get()}; std::pmr::vector numbers{alloc}; - EXPECT(numbers.capacity() < 1000); numbers.reserve(1000); EXPECT(numbers.capacity() == 1000); @@ -79,7 +84,6 @@ CASE("pmr::Pmr_pool usage") { my_strings.push_back("Still works"); EXPECT(my_strings.back() == "Still works"); - // Using small res capacity constexpr auto alloc_cap = 4_KiB; @@ -115,6 +119,7 @@ CASE("pmr::Pmr_pool usage") { EXPECT(numbers2.capacity() < 1000); EXPECT(res2->allocatable() < alloc_cap); EXPECT(res2->allocatable() > alloc_cap - 1000); + } @@ -161,19 +166,46 @@ CASE("pmr::resource usage") { // Drain all the resources for (auto& res : resources) { + auto exp_alloc = resource_cap; + EXPECT(not res->full()); + EXPECT(pool.allocatable() >= exp_alloc); + EXPECT(res->allocatable() == exp_alloc); + EXPECT(res->allocated() == 0); + auto* p1 = res->allocate(1_KiB); + exp_alloc -= 1_KiB; + EXPECT(res->allocated() == 1_KiB); + EXPECT(res->capacity() == resource_cap); + EXPECT(pool.allocatable() >= exp_alloc); + EXPECT(res->allocatable() == exp_alloc); + auto* p2 = res->allocate(1_KiB); + exp_alloc -= 1_KiB; + EXPECT(res->allocated() == 2_KiB); + EXPECT(pool.allocatable() >= exp_alloc); + EXPECT(res->allocatable() == exp_alloc); + auto* p3 = res->allocate(1_KiB); + exp_alloc -= 1_KiB; + EXPECT(res->allocated() == 3_KiB); + EXPECT(pool.allocatable() >= exp_alloc); + EXPECT(res->allocatable() == exp_alloc); + auto* p4 = res->allocate(1_KiB); + exp_alloc -= 1_KiB; + EXPECT(res->allocated() == 4_KiB); + EXPECT(pool.allocatable() >= exp_alloc); + EXPECT(res->allocatable() == exp_alloc); + EXPECT(p1 != nullptr); EXPECT(p2 != nullptr); EXPECT(p3 != nullptr); EXPECT(p4 != nullptr); - allocations.at(res.get()).push_back(p1); - allocations.at(res.get()).push_back(p2); - allocations.at(res.get()).push_back(p3); - allocations.at(res.get()).push_back(p4); + allocations[res.get()].push_back(p1); + allocations[res.get()].push_back(p2); + allocations[res.get()].push_back(p3); + allocations[res.get()].push_back(p4); EXPECT(res->full()); EXPECT_THROWS(res->allocate(1_KiB)); @@ -215,21 +247,20 @@ CASE("pmr::resource usage") { for (auto alloc : vec) pool->deallocate(alloc, 1_KiB); + EXPECT(pool.empty()); EXPECT(not pool.full()); EXPECT(pool.allocatable() == pool_cap); - // Each resource's state is remembered as it's passed back and forth. - // ...There's now no way of fetching any non-full resources - auto res_tricked = pool.get_resource(); - EXPECT(pool.resource_count() == resource_count); - EXPECT(res_tricked->full()); - EXPECT(res_tricked->allocatable() == 0); - EXPECT_THROWS(res_tricked->allocate(1_KiB)); + auto res_reused = pool.get_resource(); + EXPECT(pool.resource_count() == resource_count); - res_tricked.reset(); + EXPECT(res_reused->empty()); + EXPECT(res_reused->allocatable() == resource_cap); + EXPECT(pool_ptr->free_resources() == resource_count - 1); + EXPECT(pool_ptr->used_resources() == 1); - pool_ptr->clear_free_resources(); + res_reused.reset(); auto res2 = pool.get_resource(); @@ -242,6 +273,127 @@ CASE("pmr::resource usage") { } -CASE("pmr::Resource performance") { +CASE("pmr::on_non_full event") { using namespace util; + constexpr auto pool_cap = 400_KiB; + + // Using default resource capacity, which is pool_cap / allocator count + os::mem::Pmr_pool pool{pool_cap, pool_cap}; + auto res = pool.get_resource(); + bool event_fired = false; + + res->on_non_full([&](auto& r){ + EXPECT(&r == res.get()); + EXPECT(not r.full()); + event_fired = true; + }); + + std::pmr::polymorphic_allocator alloc{res.get()}; + std::pmr::vector numbers{alloc}; + auto reserved = pool_cap - 2; + numbers.reserve(reserved); + EXPECT(numbers.capacity() == reserved); + EXPECT(res->allocated() == reserved); + EXPECT(not event_fired); + + numbers.push_back(0); + numbers.push_back(1); + + // In order to shrink, it needs to allocate new space for 2 chars then copy. + numbers.shrink_to_fit(); + EXPECT(res->allocated() < reserved); + EXPECT(event_fired); + event_fired = false; + EXPECT(not event_fired); + + for (int i = 2; i < pool_cap / 2; i++) { + numbers.push_back(i); + } + + EXPECT(not event_fired); + EXPECT(not res->full()); + + // Reduce capacity, making the resource full right now + pool.set_resource_capacity(pool_cap / 3); + numbers.clear(); + numbers.shrink_to_fit(); + EXPECT(event_fired); + +} + +CASE("pmr::on_avail event") { + using namespace util; + constexpr auto pool_cap = 400_KiB; + + // Using default resource capacity, which is pool_cap / allocator count + os::mem::Pmr_pool pool{pool_cap, pool_cap}; + auto res = pool.get_resource(); + bool event_fired = false; + + res->on_avail(200_KiB, [&](auto& r){ + EXPECT(&r == res.get()); + EXPECT(not r.full()); + EXPECT(r.allocatable() >= 200_KiB); + event_fired = true; + }); + + std::pmr::polymorphic_allocator alloc{res.get()}; + std::pmr::vector numbers{alloc}; + + numbers.push_back(0); + numbers.push_back(1); + EXPECT(not event_fired); + + auto reserved = 201_KiB; + numbers.reserve(reserved); + EXPECT(numbers.capacity() == reserved); + EXPECT(res->allocated() == reserved); + EXPECT(not event_fired); + + // In order to shrink, it needs to allocate new space for 2 chars then copy. + numbers.shrink_to_fit(); + EXPECT(res->allocated() < reserved); + EXPECT(event_fired); + event_fired = false; + EXPECT(not event_fired); + + for (int i = 2; i < 40_KiB; i++) { + numbers.push_back(i); + } + + EXPECT(not event_fired); + EXPECT(not res->full()); + + numbers.clear(); + numbers.shrink_to_fit(); + EXPECT(not event_fired); + +} + + +CASE("pmr::default resource cap") { + // Not providing a resource cap will give each resource a proportion of max + + using namespace util; + constexpr auto pool_cap = 400_KiB; + + // Using default resource capacity, which is pool_cap / allocator count + os::mem::Pmr_pool pool{pool_cap}; + auto res1 = pool.get_resource(); + auto expected = pool_cap / (1 + os::mem::Pmr_pool::resource_division_offset); + EXPECT(res1->allocatable() == expected); + + auto res2 = pool.get_resource(); + expected = pool_cap / (2 + os::mem::Pmr_pool::resource_division_offset); + EXPECT(res2->allocatable() == expected); + + auto res3 = pool.get_resource(); + expected = pool_cap / (3 + os::mem::Pmr_pool::resource_division_offset); + EXPECT(res3->allocatable() == expected); + + auto res4 = pool.get_resource(); + expected = pool_cap / (4 + os::mem::Pmr_pool::resource_division_offset); + EXPECT(res4->allocatable() == expected); + + }