Skip to content

Commit

Permalink
NetworkClientSecure made copyable
Browse files Browse the repository at this point in the history
  • Loading branch information
JAndrassy committed May 10, 2024
1 parent e8e251a commit 6782090
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 30 deletions.
49 changes: 24 additions & 25 deletions libraries/NetworkClientSecure/src/NetworkClientSecure.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,12 @@ NetworkClientSecure::NetworkClientSecure() {
_connected = false;
_timeout = 30000; // Same default as ssl_client

sslclient = new sslclient_context;
ssl_init(sslclient);
sslclient.reset(new sslclient_context, [](struct sslclient_context *sslclient) {
stop_ssl_socket(sslclient);
delete sslclient;

});
ssl_init(sslclient.get());
sslclient->socket = -1;
sslclient->handshake_timeout = 120000;
_use_insecure = false;
Expand All @@ -53,8 +57,12 @@ NetworkClientSecure::NetworkClientSecure(int sock) {
_lastReadTimeout = 0;
_lastWriteTimeout = 0;

sslclient = new sslclient_context;
ssl_init(sslclient);
sslclient.reset(new sslclient_context, [](struct sslclient_context *sslclient) {
stop_ssl_socket(sslclient);
delete sslclient;

});
ssl_init(sslclient.get());
sslclient->socket = sock;
sslclient->handshake_timeout = 120000;

Expand All @@ -72,19 +80,10 @@ NetworkClientSecure::NetworkClientSecure(int sock) {
}

NetworkClientSecure::~NetworkClientSecure() {
stop();
delete sslclient;
}

NetworkClientSecure &NetworkClientSecure::operator=(const NetworkClientSecure &other) {
stop();
sslclient->socket = other.sslclient->socket;
_connected = other._connected;
return *this;
}

void NetworkClientSecure::stop() {
stop_ssl_socket(sslclient, _CA_cert, _cert, _private_key);
stop_ssl_socket(sslclient.get());

_connected = false;
_peek = -1;
Expand Down Expand Up @@ -130,10 +129,10 @@ int NetworkClientSecure::connect(const char *host, uint16_t port, const char *CA
}

int NetworkClientSecure::connect(IPAddress ip, uint16_t port, const char *host, const char *CA_cert, const char *cert, const char *private_key) {
int ret = start_ssl_client(sslclient, ip, port, host, _timeout, CA_cert, _use_ca_bundle, cert, private_key, NULL, NULL, _use_insecure, _alpn_protos);
int ret = start_ssl_client(sslclient.get(), ip, port, host, _timeout, CA_cert, _use_ca_bundle, cert, private_key, NULL, NULL, _use_insecure, _alpn_protos);

if (ret >= 0 && !_stillinPlainStart) {
ret = ssl_starttls_handshake(sslclient);
ret = ssl_starttls_handshake(sslclient.get());
} else {
log_i("Actual TLS start postponed.");
}
Expand All @@ -153,7 +152,7 @@ int NetworkClientSecure::startTLS() {
int ret = 1;
if (_stillinPlainStart) {
log_i("startTLS: starting TLS/SSL on this dplain connection");
ret = ssl_starttls_handshake(sslclient);
ret = ssl_starttls_handshake(sslclient.get());
if (ret < 0) {
log_e("startTLS: %d", ret);
stop();
Expand All @@ -178,7 +177,7 @@ int NetworkClientSecure::connect(const char *host, uint16_t port, const char *ps
return 0;
}

int ret = start_ssl_client(sslclient, address, port, host, _timeout, NULL, false, NULL, NULL, pskIdent, psKey, _use_insecure, _alpn_protos);
int ret = start_ssl_client(sslclient.get(), address, port, host, _timeout, NULL, false, NULL, NULL, pskIdent, psKey, _use_insecure, _alpn_protos);
_lastError = ret;
if (ret < 0) {
log_e("start_ssl_client: connect failed %d", ret);
Expand Down Expand Up @@ -213,7 +212,7 @@ size_t NetworkClientSecure::write(const uint8_t *buf, size_t size) {
}

if (_stillinPlainStart) {
return send_net_data(sslclient, buf, size);
return send_net_data(sslclient.get(), buf, size);
}

if (_lastWriteTimeout != _timeout) {
Expand All @@ -224,7 +223,7 @@ size_t NetworkClientSecure::write(const uint8_t *buf, size_t size) {
_lastWriteTimeout = _timeout;
}
}
int res = send_ssl_data(sslclient, buf, size);
int res = send_ssl_data(sslclient.get(), buf, size);
if (res < 0) {
log_e("Closing connection on failed write");
stop();
Expand All @@ -235,7 +234,7 @@ size_t NetworkClientSecure::write(const uint8_t *buf, size_t size) {

int NetworkClientSecure::read(uint8_t *buf, size_t size) {
if (_stillinPlainStart) {
return get_net_receive(sslclient, buf, size);
return get_net_receive(sslclient.get(), buf, size);
}

if (_lastReadTimeout != _timeout) {
Expand Down Expand Up @@ -268,7 +267,7 @@ int NetworkClientSecure::read(uint8_t *buf, size_t size) {
buf++;
peeked = 1;
}
res = get_ssl_receive(sslclient, buf, size);
res = get_ssl_receive(sslclient.get(), buf, size);

if (res < 0) {
log_e("Closing connection on failed read");
Expand All @@ -280,14 +279,14 @@ int NetworkClientSecure::read(uint8_t *buf, size_t size) {

int NetworkClientSecure::available() {
if (_stillinPlainStart) {
return peek_net_receive(sslclient, 0);
return peek_net_receive(sslclient.get(), 0);
}

int peeked = (_peek >= 0), res = -1;
if (!_connected) {
return peeked;
}
res = data_to_read(sslclient);
res = data_to_read(sslclient.get());

if (res < 0 && !_stillinPlainStart) {
log_e("Closing connection on failed available check");
Expand Down Expand Up @@ -346,7 +345,7 @@ bool NetworkClientSecure::verify(const char *fp, const char *domain_name) {
return false;
}

return verify_ssl_fingerprint(sslclient, fp, domain_name);
return verify_ssl_fingerprint(sslclient.get(), fp, domain_name);
}

char *NetworkClientSecure::_streamLoad(Stream &stream, size_t size) {
Expand Down
7 changes: 4 additions & 3 deletions libraries/NetworkClientSecure/src/NetworkClientSecure.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,11 @@
#include "IPAddress.h"
#include "Network.h"
#include "ssl_client.h"
#include <memory>

class NetworkClientSecure : public NetworkClient {
protected:
sslclient_context *sslclient;
std::shared_ptr<sslclient_context> sslclient;

int _lastError = 0;
int _peek = -1;
Expand Down Expand Up @@ -97,14 +98,14 @@ class NetworkClientSecure : public NetworkClient {
return mbedtls_ssl_get_peer_cert(&sslclient->ssl_ctx);
};
bool getFingerprintSHA256(uint8_t sha256_result[32]) {
return get_peer_fingerprint(sslclient, sha256_result);
return get_peer_fingerprint(sslclient.get(), sha256_result);
};
int fd() const;

operator bool() {
return connected();
}
NetworkClientSecure &operator=(const NetworkClientSecure &other);

bool operator==(const bool value) {
return bool() == value;
}
Expand Down
2 changes: 1 addition & 1 deletion libraries/NetworkClientSecure/src/ssl_client.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -344,7 +344,7 @@ int ssl_starttls_handshake(sslclient_context *ssl_client) {
return ssl_client->socket;
}

void stop_ssl_socket(sslclient_context *ssl_client, const char *rootCABuff, const char *cli_cert, const char *cli_key) {
void stop_ssl_socket(sslclient_context *ssl_client) {
log_v("Cleaning SSL connection.");

if (ssl_client->socket >= 0) {
Expand Down
2 changes: 1 addition & 1 deletion libraries/NetworkClientSecure/src/ssl_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ int start_ssl_client(
const char *cli_cert, const char *cli_key, const char *pskIdent, const char *psKey, bool insecure, const char **alpn_protos
);
int ssl_starttls_handshake(sslclient_context *ssl_client);
void stop_ssl_socket(sslclient_context *ssl_client, const char *rootCABuff, const char *cli_cert, const char *cli_key);
void stop_ssl_socket(sslclient_context *ssl_client);
int data_to_read(sslclient_context *ssl_client);
int send_ssl_data(sslclient_context *ssl_client, const uint8_t *data, size_t len);
int get_ssl_receive(sslclient_context *ssl_client, uint8_t *data, int length);
Expand Down

0 comments on commit 6782090

Please sign in to comment.