Skip to content

Commit

Permalink
Merge pull request #1999 from alfred-bratterud/v0.13.x
Browse files Browse the repository at this point in the history
V0.13.x merged as it builds on vaskemaskin
  • Loading branch information
KristianJerpetjon authored Jan 22, 2019
2 parents 461343d + 39a3790 commit cc87322
Show file tree
Hide file tree
Showing 50 changed files with 1,893 additions and 597 deletions.
12 changes: 12 additions & 0 deletions api/net/botan/tls_server.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,18 @@ class Server : public Botan::TLS::Callbacks, public net::Stream
m_transport->on_read(bs, {this, &Server::tls_read});
this->m_on_read = cb;
}
void on_data(DataCallback cb) override {
// FIXME
throw std::runtime_error("on_data not implemented on botan::server");
}
size_t next_size() override {
// FIXME
throw std::runtime_error("next_size not implemented on botan::server");
}
buffer_t read_next() override {
// FIXME
throw std::runtime_error("read_next not implemented on botan::server");
}
void on_write(WriteCallback cb) override {
this->m_on_write = cb;
}
Expand Down
288 changes: 10 additions & 278 deletions api/net/openssl/tls_stream.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,18 @@
#include <openssl/err.h>
#include <openssl/pem.h>
#include <openssl/ssl.h>
#include <net/stream.hpp>
#include <net/stream_buffer.hpp>

//#define VERBOSE_OPENSSL
//#define VERBOSE_OPENSSL 0
#ifdef VERBOSE_OPENSSL
#define TLS_PRINT(fmt, ...) printf(fmt, ##__VA_ARGS__)
#define TLS_PRINT(fmt, ...) printf("TLS_Stream");printf(fmt, ##__VA_ARGS__)
#else
#define TLS_PRINT(fmt, ...) /* fmt */
#endif

namespace openssl
{
struct TLS_stream : public net::Stream
struct TLS_stream : public net::StreamBuffer
{
using Stream_ptr = net::Stream_ptr;

Expand All @@ -25,7 +25,6 @@ namespace openssl
void write(const std::string&) override;
void write(const void* buf, size_t n) override;
void close() override;
void reset_callbacks() override;

net::Socket local() const override {
return m_transport->local();
Expand All @@ -37,24 +36,11 @@ namespace openssl
return m_transport->to_string();
}

void on_connect(ConnectCallback cb) override {
m_on_connect = std::move(cb);
}
void on_read(size_t, ReadCallback cb) override {
m_on_read = std::move(cb);
}
void on_close(CloseCallback cb) override {
m_on_close = std::move(cb);
}
void on_write(WriteCallback cb) override {
m_on_write = std::move(cb);
}

bool is_connected() const noexcept override {
return handshake_completed() && m_transport->is_connected();
}
bool is_writable() const noexcept override {
return is_connected() && m_transport->is_writable();
return (not write_congested()) && is_connected() && m_transport->is_writable();
}
bool is_readable() const noexcept override {
return m_transport->is_readable();
Expand All @@ -76,7 +62,12 @@ namespace openssl

size_t serialize_to(void*) const override;

void handle_read_congestion() override;
void handle_write_congestion() override;
private:
void handle_data();
int decrypt(const void *data,int size);
int send_decrypted();
void tls_read(buffer_t);
int tls_perform_stream_write();
int tls_perform_handshake();
Expand All @@ -89,271 +80,12 @@ namespace openssl
STATUS_FAIL
};
status_t status(int n) const noexcept;

Stream_ptr m_transport = nullptr;
SSL* m_ssl = nullptr;
BIO* m_bio_rd = nullptr;
BIO* m_bio_wr = nullptr;
bool m_busy = false;
bool m_deferred_close = false;
ConnectCallback m_on_connect = nullptr;
ReadCallback m_on_read = nullptr;
WriteCallback m_on_write = nullptr;
CloseCallback m_on_close = nullptr;
};

inline TLS_stream::TLS_stream(SSL_CTX* ctx, Stream_ptr t, bool outgoing)
: m_transport(std::move(t))
{
ERR_clear_error(); // prevent old errors from mucking things up
this->m_bio_rd = BIO_new(BIO_s_mem());
this->m_bio_wr = BIO_new(BIO_s_mem());
assert(ERR_get_error() == 0 && "Initializing BIOs");
this->m_ssl = SSL_new(ctx);
assert(this->m_ssl != nullptr);
assert(ERR_get_error() == 0 && "Initializing SSL");
// TLS server-mode
if (outgoing == false)
SSL_set_accept_state(this->m_ssl);
else
SSL_set_connect_state(this->m_ssl);

SSL_set_bio(this->m_ssl, this->m_bio_rd, this->m_bio_wr);
// always-on callbacks
m_transport->on_read(8192, {this, &TLS_stream::tls_read});
m_transport->on_close({this, &TLS_stream::close_callback_once});

// start TLS handshake process
if (outgoing == true)
{
if (this->tls_perform_handshake() < 0) return;
}
}
inline TLS_stream::TLS_stream(Stream_ptr t, SSL* ssl, BIO* rd, BIO* wr)
: m_transport(std::move(t)), m_ssl(ssl), m_bio_rd(rd), m_bio_wr(wr)
{
// always-on callbacks
m_transport->on_read(8192, {this, &TLS_stream::tls_read});
m_transport->on_close({this, &TLS_stream::close_callback_once});
}
inline TLS_stream::~TLS_stream()
{
assert(m_busy == false && "Cannot delete stream while in its call stack");
SSL_free(this->m_ssl);
}

inline void TLS_stream::write(buffer_t buffer)
{
if (UNLIKELY(this->is_connected() == false)) {
TLS_PRINT("TLS_stream::write() called on closed stream\n");
return;
}

int n = SSL_write(this->m_ssl, buffer->data(), buffer->size());
auto status = this->status(n);
if (status == STATUS_FAIL) {
this->close();
return;
}

do {
n = tls_perform_stream_write();
} while (n > 0);
}
inline void TLS_stream::write(const std::string& str)
{
write(net::Stream::construct_buffer(str.data(), str.data() + str.size()));
}
inline void TLS_stream::write(const void* data, const size_t len)
{
auto* buf = static_cast<const uint8_t*> (data);
write(net::Stream::construct_buffer(buf, buf + len));
}

inline void TLS_stream::tls_read(buffer_t buffer)
{
ERR_clear_error();
uint8_t* buf = buffer->data();
int len = buffer->size();

while (len > 0)
{
int n = BIO_write(this->m_bio_rd, buf, len);
if (UNLIKELY(n < 0)) {
this->close();
return;
}
buf += n;
len -= n;

// if we aren't finished initializing session
if (UNLIKELY(!handshake_completed()))
{
int num = SSL_do_handshake(this->m_ssl);
auto status = this->status(num);

// OpenSSL wants to write
if (status == STATUS_WANT_IO)
{
tls_perform_stream_write();
}
else if (status == STATUS_FAIL)
{
if (num < 0) {
TLS_PRINT("TLS_stream::SSL_do_handshake() returned %d\n", num);
#ifdef VERBOSE_OPENSSL
ERR_print_errors_fp(stdout);
#endif
}
this->close();
return;
}
// nothing more to do if still not finished
if (handshake_completed() == false) return;
// handshake success
if (m_on_connect) m_on_connect(*this);
}

// read decrypted data
do {
char temp[8192];
n = SSL_read(this->m_ssl, temp, sizeof(temp));
if (n > 0) {
auto buf = net::Stream::construct_buffer(temp, temp + n);
if (m_on_read) {
this->m_busy = true;
m_on_read(std::move(buf));
this->m_busy = false;
}
}
} while (n > 0);
// this goes here?
if (UNLIKELY(this->is_closing() || this->is_closed())) {
TLS_PRINT("TLS_stream::SSL_read closed during read\n");
return;
}
if (this->m_deferred_close) {
this->close(); return;
}

auto status = this->status(n);
// did peer request stream renegotiation?
if (status == STATUS_WANT_IO)
{
do {
n = tls_perform_stream_write();
} while (n > 0);
}
else if (status == STATUS_FAIL)
{
this->close();
return;
}
// check deferred closing
if (this->m_deferred_close) {
this->close(); return;
}

} // while it < end
} // tls_read()

inline int TLS_stream::tls_perform_stream_write()
{
ERR_clear_error();
int pending = BIO_ctrl_pending(this->m_bio_wr);
//printf("pending: %d\n", pending);
if (pending > 0)
{
auto buffer = net::Stream::construct_buffer(pending);
int n = BIO_read(this->m_bio_wr, buffer->data(), buffer->size());
assert(n == pending);
m_transport->write(buffer);
if (m_on_write) {
this->m_busy = true;
m_on_write(n);
this->m_busy = false;
}
return n;
}
else {
BIO_read(this->m_bio_wr, nullptr, 0);
}
if (!BIO_should_retry(this->m_bio_wr))
{
this->close();
return -1;
}
return 0;
}
inline int TLS_stream::tls_perform_handshake()
{
ERR_clear_error(); // prevent old errors from mucking things up
// will return -1:SSL_ERROR_WANT_WRITE
int ret = SSL_do_handshake(this->m_ssl);
int n = this->status(ret);
ERR_print_errors_fp(stderr);
if (n == STATUS_WANT_IO)
{
do {
n = tls_perform_stream_write();
if (n < 0) {
TLS_PRINT("TLS_stream::tls_perform_handshake() stream write failed\n");
}
} while (n > 0);
return n;
}
else {
TLS_PRINT("TLS_stream::tls_perform_handshake() returned %d\n", ret);
this->close();
return -1;
}
}

inline void TLS_stream::close()
{
//ERR_clear_error();
if (this->m_busy) {
this->m_deferred_close = true; return;
}
CloseCallback func = std::move(this->m_on_close);
this->reset_callbacks();
if (m_transport->is_connected())
m_transport->close();
if (func) func();
}
inline void TLS_stream::close_callback_once()
{
if (this->m_busy) {
this->m_deferred_close = true; return;
}
CloseCallback func = std::move(this->m_on_close);
this->reset_callbacks();
if (func) func();
}
inline void TLS_stream::reset_callbacks()
{
this->m_on_close = nullptr;
this->m_on_connect = nullptr;
this->m_on_read = nullptr;
this->m_on_write = nullptr;
}

inline bool TLS_stream::handshake_completed() const noexcept
{
return SSL_is_init_finished(this->m_ssl);
}
inline TLS_stream::status_t TLS_stream::status(int n) const noexcept
{
int error = SSL_get_error(this->m_ssl, n);
switch (error)
{
case SSL_ERROR_NONE:
return STATUS_OK;
case SSL_ERROR_WANT_WRITE:
case SSL_ERROR_WANT_READ:
return STATUS_WANT_IO;
default:
return STATUS_FAIL;
}
}
} // openssl
21 changes: 20 additions & 1 deletion api/net/stream.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,13 +56,32 @@ namespace net {
/** Called with a shared buffer and the length of the data when received. */
using ReadCallback = delegate<void(buffer_t)>;
/**
* @brief Event when data is received.
* @brief Event when data is received. Pushes data to the callback.
*
* @param[in] n The size of the receive buffer
* @param[in] cb The read callback
*/
virtual void on_read(size_t n, ReadCallback cb) = 0;

using DataCallback = delegate<void()>;
/**
* @brief Event when data is received.
* Does not push data, just signals its presence.
*
* @param[in] cb The callback
*/
virtual void on_data(DataCallback cb) = 0;

/**
* @return The size of the next available chunk of data if any.
*/
virtual size_t next_size() = 0;

/**
* @return The next available chunk of data if any.
*/
virtual buffer_t read_next() = 0;

/** Called with nothing ¯\_(ツ)_/¯ */
using CloseCallback = delegate<void()>;
/**
Expand Down
Loading

0 comments on commit cc87322

Please sign in to comment.