/*
    Copyright© 2021 John Sennesael

    UsenetSearch is Free software: you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation, either version 3 of the License, or
    (at your option) any later version.

    UsenetSearch is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with UsenetSearch. If not, see <https://www.gnu.org/licenses/>.
*/

#include "usenetsearch/SSLConnection.h"

#include "usenetsearch/Logger.h"

// NOTE(review): the targets of the system includes were lost from the text
// this file was recovered from; the list below is reconstructed from the
// symbols used in this translation unit. Confirm against the build.
#include <openssl/err.h>
#include <openssl/ssl.h>

#include <algorithm>
#include <cerrno>
#include <chrono>
#include <cstring>
#include <limits>
#include <memory>
#include <string>

namespace usenetsearch {

/**
 * Takes ownership of an already established TCP connection over which the
 * SSL session will later be negotiated by Connect().
 *
 * NOTE(review): the element type of the unique_ptr (TcpConnection) is
 * reconstructed from usage - confirm against SSLConnection.h.
 *
 * @param connection The underlying transport; may be checked for null only
 *                   when Connect() is called.
 */
SSLConnection::SSLConnection(std::unique_ptr<TcpConnection> connection)
{
    m_tcpConnection = std::move(connection);
}

/**
 * Translates the return value of an OpenSSL I/O call into a coarse state.
 *
 * Must be called directly after the OpenSSL call whose result is passed in,
 * because it relies on the thread's OpenSSL error queue and on errno.
 *
 * NOTE(review): Logger::Fatal is assumed not to return normally (the code in
 * this file dereferences pointers right after reporting them as null) -
 * confirm against Logger.h.
 *
 * @param ret Return value of SSL_connect / SSL_read / SSL_write.
 * @return SUCCESS when OpenSSL reports no error (or a syscall "error" with
 *         errno == 0), RETRY for every other non-fatal condition.
 */
SSLConnection::SSLReturnState SSLConnection::CheckSSLReturn(int ret)
{
    const int result = SSL_get_error(m_ssl.get(), ret);
    // Capture errno right away: any later call (allocation, logging) is
    // allowed to overwrite it.
    const int savedErrno = errno;

    if (result == SSL_ERROR_NONE) return SSLReturnState::SUCCESS;

    if (result == SSL_ERROR_SYSCALL)
    {
        if (savedErrno == 0) return SSLReturnState::SUCCESS;
        Logger::Get().Fatal(
            LOGID("SSLConnection"),
            std::string{"SSL error: "} + std::strerror(savedErrno)
        );
    }

    if (result == SSL_ERROR_SSL)
    {
        // ERR_error_string(code, nullptr) writes into a static buffer that
        // is shared between threads; format into a local buffer instead.
        char errorText[256] = {0};
        ERR_error_string_n(ERR_get_error(), errorText, sizeof(errorText));
        Logger::Get().Fatal(
            LOGID("SSLConnection"),
            std::string{"SSL error: "} + errorText
        );
    }

    // Everything else (WANT_READ, WANT_WRITE, ZERO_RETURN, ...) is left to
    // the caller's retry loop.
    return SSLReturnState::RETRY;
}

/**
 * Negotiates a TLS session on top of the TCP connection handed to the
 * constructor. Any previous session is torn down first.
 *
 * Failures (missing transport, bad descriptor, OpenSSL allocation failure,
 * handshake error, exceeding m_connectionTimeout) are reported through
 * Logger::Fatal.
 */
void SSLConnection::Connect()
{
    if (m_tcpConnection == nullptr)
    {
        Logger::Get().Fatal(
            LOGID("SSLConnection"),
            "Null tcp connection when attempting ssl connect."
        );
    }

    const int fd = m_tcpConnection->FileDescriptor();
    // 0 was already rejected by the original check and is kept that way;
    // negative values are never valid descriptors and are now rejected too.
    if (fd <= 0)
    {
        Logger::Get().Fatal(
            LOGID("SSLConnection"),
            "Bad file descriptor (" + std::to_string(fd)
                + ") when attempting to ssl connect."
        );
    }

    Disconnect();

    SSL_CTX* const rawContext = SSL_CTX_new(TLS_client_method());
    if (rawContext == nullptr)
    {
        Logger::Get().Fatal(
            LOGID("SSLConnection"),
            "Failed to create SSL context."
        );
    }
    m_sslContext = std::shared_ptr<SSL_CTX>(
        rawContext,
        [](SSL_CTX* p){ SSL_CTX_free(p); }
    );

    SSL* const rawSsl = SSL_new(m_sslContext.get());
    if (rawSsl == nullptr)
    {
        Logger::Get().Fatal(
            LOGID("SSLConnection"),
            "Failed to create SSL session object."
        );
    }
    m_ssl = std::shared_ptr<SSL>(
        rawSsl,
        [](SSL* p){ SSL_free(p); }
    );

    // SSL_set_fd returns 1 on success and 0 on failure.
    if (SSL_set_fd(m_ssl.get(), fd) != 1)
    {
        Logger::Get().Fatal(
            LOGID("SSLConnection"),
            "Failed to attach file descriptor (" + std::to_string(fd)
                + ") to SSL session."
        );
    }

    // A monotonic clock keeps the timeout immune to wall-clock adjustments.
    const auto startTime = std::chrono::steady_clock::now();
    while (true)
    {
        const auto elapsed =
            std::chrono::duration_cast<std::chrono::milliseconds>(
                std::chrono::steady_clock::now() - startTime
            );
        if (elapsed > m_connectionTimeout)
        {
            Logger::Get().Fatal(
                LOGID("SSLConnection"),
                "Timed out while trying to establish SSL connection."
            );
        }
        // SSL_get_error() requires an empty error queue before the call.
        ERR_clear_error();
        const int status = SSL_connect(m_ssl.get());
        if (status == 1) break;
        if (CheckSSLReturn(status) == SSLReturnState::SUCCESS) break;
    }
}

/**
 * Releases the SSL session and its context (in that order). Safe to call
 * when no session exists. The underlying TCP connection is not touched.
 */
void SSLConnection::Disconnect()
{
    m_ssl.reset();
    m_sslContext.reset();
}

/**
 * Reads up to `amount` bytes from the SSL session.
 *
 * @param amount Number of bytes wanted.
 * @return The bytes received. Shorter than `amount` (possibly empty) when
 *         SSL_read reports 0 bytes and either m_ioTimeout is zero or
 *         m_ioTimeout has elapsed since the call started.
 *
 * NOTE(review): when SSL_read returns a negative value and CheckSSLReturn
 * classifies it as retryable, the loop retries without consulting
 * m_ioTimeout. That matches the previous behaviour and is left unchanged.
 */
std::string SSLConnection::Read(size_t amount)
{
    if (m_ssl == nullptr || m_sslContext == nullptr)
    {
        Logger::Get().Fatal(
            LOGID("SSLConnection"),
            "Attempted to read over SSL socket without SSL context."
        );
    }

    // Nothing to do; also avoids calling SSL_read with a zero-length buffer.
    if (amount == 0) return std::string{};

    // Read straight into the result instead of a per-iteration scratch
    // buffer; `received` tracks how much of it is valid.
    std::string result(amount, '\0');
    size_t received = 0;
    const size_t maxChunk =
        static_cast<size_t>(std::numeric_limits<int>::max());

    const auto startTime = std::chrono::steady_clock::now();
    while (received < amount)
    {
        // SSL_read takes an int length; clamp rather than narrow silently.
        const int chunk =
            static_cast<int>(std::min(amount - received, maxChunk));

        ERR_clear_error();
        const int bytesRead = SSL_read(m_ssl.get(), &result[received], chunk);

        if (bytesRead > 0)
        {
            received += static_cast<size_t>(bytesRead);
            continue;
        }

        if (bytesRead == 0)
        {
            // No data: give up immediately when no timeout is configured,
            // otherwise once the configured timeout has elapsed.
            if (m_ioTimeout == std::chrono::seconds{0}) break;
            const auto elapsed =
                std::chrono::duration_cast<std::chrono::milliseconds>(
                    std::chrono::steady_clock::now() - startTime
                );
            if (elapsed > m_ioTimeout) break;
        }

        // bytesRead <= 0: let OpenSSL tell us whether this is fatal.
        CheckSSLReturn(bytesRead);
    }

    result.resize(received);
    return result;
}

/**
 * Writes all of `data` to the SSL session, retrying partial writes.
 *
 * A write that makes no progress (SSL_write returns 0) is reported through
 * Logger::Fatal once m_ioTimeout has elapsed, or immediately when
 * m_ioTimeout is zero. Negative return values are classified by
 * CheckSSLReturn.
 *
 * @param data Bytes to send; an empty string is a no-op.
 */
void SSLConnection::Write(const std::string& data)
{
    if (m_ssl == nullptr || m_sslContext == nullptr)
    {
        Logger::Get().Fatal(
            LOGID("SSLConnection"),
            "Attempted to write over SSL socket without SSL context."
        );
    }

    // OpenSSL documents SSL_write with a length of 0 as undefined behaviour.
    if (data.empty()) return;

    size_t sent = 0;
    const size_t maxChunk =
        static_cast<size_t>(std::numeric_limits<int>::max());

    const auto startTime = std::chrono::steady_clock::now();
    while (sent < data.size())
    {
        // SSL_write takes an int length; clamp rather than narrow silently.
        // On a retry `sent` is unchanged, so pointer and length are repeated
        // exactly, as OpenSSL requires for a retried write.
        const int chunk =
            static_cast<int>(std::min(data.size() - sent, maxChunk));

        ERR_clear_error();
        const int bytesWritten =
            SSL_write(m_ssl.get(), data.data() + sent, chunk);

        if (bytesWritten > 0)
        {
            // Advance past what was accepted (never past the chunk we
            // offered) and go around again for the remainder, if any.
            sent += std::min(
                static_cast<size_t>(bytesWritten),
                static_cast<size_t>(chunk)
            );
            continue;
        }

        if (bytesWritten == 0)
        {
            const auto elapsed =
                std::chrono::duration_cast<std::chrono::milliseconds>(
                    std::chrono::steady_clock::now() - startTime
                );
            if (
                (elapsed > m_ioTimeout)
                || (m_ioTimeout == std::chrono::milliseconds{0})
            )
            {
                Logger::Get().Fatal(
                    LOGID("SSLConnection"),
                    "Timed out while trying to write to SSL connection."
                );
            }
            continue;
        }

        // bytesWritten < 0 indicates an error or a retryable condition.
        CheckSSLReturn(bytesWritten);
    }
}

} // namespace usenetsearch