/*
    Copyright© 2021 John Sennesael

    UsenetSearch is Free software: you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation, either version 3 of the License, or
    (at your option) any later version.

    UsenetSearch is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with UsenetSearch. If not, see <https://www.gnu.org/licenses/>.
*/

#include "usenetsearch/Logger.h"
#include "usenetsearch/TcpConnection.h"
#include "usenetsearch/Dns.h"

#include <netinet/in.h> // sockaddr_in
#include <sys/socket.h> // AF_INET etc...
#include <unistd.h> // close(), read(), write()

#include <cerrno>
#include <chrono>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <string>
#include <thread>
#include <vector>

namespace usenetsearch {

namespace {

/// Back-off before retrying an I/O call that made no progress
/// (EAGAIN / EWOULDBLOCK / zero-byte write), so retry loops don't spin the CPU.
constexpr std::chrono::milliseconds IoRetryInterval{10};

/// Time elapsed since @p start, measured on the monotonic steady clock so that
/// wall-clock adjustments cannot shorten or stretch a timeout.
std::chrono::milliseconds ElapsedSince(std::chrono::steady_clock::time_point start)
{
    return std::chrono::duration_cast<std::chrono::milliseconds>(
        std::chrono::steady_clock::now() - start
    );
}

} // anonymous namespace

TcpConnection::~TcpConnection()
{
    Disconnect();
}

/**
 * Opens a TCP connection to @p host : @p port, replacing any open socket.
 *
 * The host is resolved with DnsResolve(); every resolved address is tried in
 * order. Transient connect() results (EINPROGRESS, EWOULDBLOCK, EINTR) cause a
 * fresh socket to be retried on the same address after a one second pause.
 * Any other connect() error moves on to the next address. The overall attempt
 * is bounded by m_connectionTimeout.
 *
 * Failures are reported through Logger::Fatal: no resolved address, socket()
 * failure, timeout, or every address failing (the message carries the last
 * connect() error).
 *
 * @param host Host name or address to connect to.
 * @param port TCP port.
 */
void TcpConnection::Connect(const std::string& host, std::uint16_t port)
{
    // Resolve host (may resolve to multiple ip's).
    const auto addresses = DnsResolve(host, port);
    if (addresses.empty())
    {
        Logger::Get().Fatal(
            LOGID("TcpConnection"),
            "The provided host (" + host + ") did not resolve to an address."
        );
    }

    // If we have an open socket close it.
    Disconnect();

    // Try each resolved IP in sequence until one of them works.
    const auto startTime = std::chrono::steady_clock::now();
    int lastError{0};
    for (const auto& addr: addresses)
    {
        while (true)
        {
            if (ElapsedSince(startTime) > m_connectionTimeout)
            {
                Logger::Get().Fatal(
                    LOGID("TcpConnection"),
                    "Timed out while trying to connect to " + host + ":"
                        + std::to_string(port) + "."
                );
            }

            const int fd = socket(addr.ai_family, SOCK_STREAM, 0);
            if (fd < 0)
            {
                // Capture errno before any other call can overwrite it.
                const int socketError = errno;
                Logger::Get().Fatal(
                    LOGID("TcpConnection"),
                    "Failed to create socket - Error ("
                        + std::to_string(socketError) + ") - "
                        + std::strerror(socketError)
                );
            }

            if (connect(fd, addr.ai_addr, addr.ai_addrlen) == 0)
            {
                m_fd = fd;
                return;
            }

            // Capture errno *before* close(), which may overwrite it.
            const int connectError = errno;
            if (connectError == EALREADY)
            {
                // A connection attempt on this socket is already underway;
                // keep the socket.
                m_fd = fd;
                return;
            }

            close(fd);
            if ((connectError == EINPROGRESS)
                || (connectError == EWOULDBLOCK)
                || (connectError == EINTR))
            {
                // Transient: retry this address with a fresh socket.
                std::this_thread::sleep_for(std::chrono::seconds{1});
                continue;
            }

            // Hard failure for this address: remember why, try the next one.
            lastError = connectError;
            break;
        }
    }

    Logger::Get().Fatal(
        LOGID("TcpConnection"),
        "Failed to connect to " + host + ":" + std::to_string(port)
            + " - Error (" + std::to_string(lastError) + ") - "
            + std::strerror(lastError)
    );
}

/// Closes the socket, if one is open, and resets m_fd to 0.
void TcpConnection::Disconnect()
{
    // NOTE(review): 0 is used as the "not connected" sentinel, but 0 is also a
    // valid descriptor (stdin). If socket() ever returns 0 it would never be
    // closed. Moving to -1 requires changing m_fd's initializer in
    // TcpConnection.h as well.
    if (m_fd <= 0) return;
    close(m_fd);
    m_fd = 0;
}

/// Returns the raw socket descriptor (0 when not connected).
int TcpConnection::FileDescriptor() const
{
    return m_fd;
}

/**
 * Reads up to @p amount bytes from the connection.
 *
 * Reading stops when one of these happens:
 *  - @p amount bytes have been received;
 *  - the peer closes the connection (read() returns 0);
 *  - no data is available (EAGAIN/EWOULDBLOCK) and m_ioTimeout has elapsed.
 *    A zero m_ioTimeout means "don't wait".
 *
 * Interrupted reads (EINTR) are retried. Any other error is reported through
 * Logger::Fatal.
 *
 * @param amount Maximum number of bytes to read.
 * @return The bytes read; shorter than @p amount on EOF or timeout.
 */
std::string TcpConnection::Read(size_t amount)
{
    std::string result(amount, '\0');
    std::size_t received{0};
    const auto startTime = std::chrono::steady_clock::now();
    while (received < amount)
    {
        // Only request what is still missing so the total never overshoots
        // `amount`.
        const auto bytesRead = read(m_fd, &result[received], amount - received);
        if (bytesRead > 0)
        {
            received += static_cast<std::size_t>(bytesRead);
            continue;
        }
        if (bytesRead == 0)
        {
            // End of stream: the peer shut down, no further data will arrive.
            break;
        }

        const int readError = errno;
        if (readError == EINTR) continue;
        if ((readError == EAGAIN) || (readError == EWOULDBLOCK))
        {
            // Non-blocking socket with no data yet: wait, bounded by the
            // I/O timeout.
            if ((m_ioTimeout == std::chrono::milliseconds{0})
                || (ElapsedSince(startTime) > m_ioTimeout))
            {
                break;
            }
            std::this_thread::sleep_for(IoRetryInterval);
            continue;
        }

        Logger::Get().Fatal(
            LOGID("TcpConnection"),
            "Error while reading from TCP socket ("
                + std::to_string(readError) + ") - "
                + std::strerror(readError)
        );
        break; // Don't keep retrying a failed descriptor should Fatal() return.
    }
    result.resize(received);
    return result;
}

/**
 * Writes all of @p data to the connection, handling partial writes.
 *
 * Interrupted writes (EINTR) are retried. Writes that make no progress
 * (a zero-byte result, or EAGAIN/EWOULDBLOCK) are retried until m_ioTimeout
 * elapses. A zero m_ioTimeout makes them fail immediately. An empty @p data
 * is a no-op.
 *
 * Timeouts and other write errors are reported through Logger::Fatal.
 *
 * NOTE(review): write() on a socket whose peer has closed raises SIGPIPE,
 * which terminates the process unless it is ignored/handled. Consider
 * send(..., MSG_NOSIGNAL) where available.
 *
 * @param data Bytes to send.
 */
void TcpConnection::Write(const std::string& data)
{
    const auto startTime = std::chrono::steady_clock::now();
    std::size_t sent{0};
    while (sent < data.size())
    {
        const auto bytesWritten = write(m_fd, data.data() + sent, data.size() - sent);
        if (bytesWritten > 0)
        {
            sent += static_cast<std::size_t>(bytesWritten);
            continue;
        }

        if (bytesWritten < 0)
        {
            const int writeError = errno;
            if (writeError == EINTR) continue;
            if ((writeError != EAGAIN) && (writeError != EWOULDBLOCK))
            {
                Logger::Get().Fatal(
                    LOGID("TcpConnection"),
                    "Error writing to tcp socket ("
                        + std::to_string(writeError) + ") - "
                        + std::strerror(writeError)
                );
                return;
            }
        }

        // No progress was made: wait for the socket, bounded by the I/O timeout.
        if ((m_ioTimeout == std::chrono::milliseconds{0})
            || (ElapsedSince(startTime) > m_ioTimeout))
        {
            Logger::Get().Fatal(
                LOGID("TcpConnection"),
                "Timed out writing to TCP socket."
            );
            return;
        }
        std::this_thread::sleep_for(IoRetryInterval);
    }
}

} // namespace usenetsearch