UsenetSearch/src/TcpConnection.cpp

212 lines
6.3 KiB
C++

/*
Copyright© 2021 John Sennesael
UsenetSearch is Free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
UsenetSearch is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with UsenetSearch. If not, see <https://www.gnu.org/licenses/>.
*/
#include "usenetsearch/Logger.h"
#include "usenetsearch/TcpConnection.h"
#include "usenetsearch/Dns.h"
#include <netinet/in.h> // sockaddr_in
#include <sys/socket.h> // AF_INET etc...
#include <unistd.h> // close(), read(), write()
#include <cerrno>
#include <chrono>
#include <cstdint>
#include <cstring>
#include <string>
#include <thread>
namespace usenetsearch {
/**
 * Destroys the connection, releasing the socket via Disconnect().
 */
TcpConnection::~TcpConnection()
{
    this->Disconnect();
}
/**
 * Opens a TCP connection to host:port, closing any socket currently held.
 *
 * The host is resolved with DnsResolve() and each resolved address is tried
 * in order:
 *  - A transient connect() result (EINPROGRESS / EWOULDBLOCK / EINTR) closes
 *    the socket, waits one second, and retries the same address.
 *  - Any other connect() failure moves on to the next resolved address.
 *  - m_connectionTimeout bounds the whole attempt across all addresses.
 *
 * Errors are reported through Logger::Get().Fatal<ExceptionType>(), which is
 * expected to throw the named exception (NOTE(review): confirm in Logger.h):
 *  - DnsResolveException when the host resolves to no address.
 *  - SocketException on socket() failure, on timeout, or when every resolved
 *    address failed to connect (the last connect() error is reported).
 *
 * @param host Host name or address to connect to.
 * @param port TCP port.
 */
void TcpConnection::Connect(const std::string& host, std::uint16_t port)
{
    // Resolve host (may resolve to multiple ip's)
    const std::vector<struct addrinfo> addresses = DnsResolve(host, port);
    if (addresses.empty())
    {
        Logger::Get().Fatal<DnsResolveException>(
            LOGID("TcpConnection"),
            "The provided host (" + host + ") did not resolve to an address."
        );
        return; // Only reached if Fatal() does not throw.
    }
    // If we have an open socket close it.
    Disconnect();
    // steady_clock so the timeout is immune to wall-clock adjustments.
    const auto startTime = std::chrono::steady_clock::now();
    int lastError{0};
    // Try each resolved IP in sequence until it works.
    for (const auto& addr: addresses)
    {
        while (true)
        {
            const auto timeDelta =
                std::chrono::duration_cast<std::chrono::milliseconds>(
                    std::chrono::steady_clock::now() - startTime
                );
            if (timeDelta > m_connectionTimeout)
            {
                Logger::Get().Fatal<SocketException>(
                    LOGID("TcpConnection"),
                    "Timed out while trying to connect to " + host + ":"
                    + std::to_string(port) + "."
                );
                return; // Only reached if Fatal() does not throw.
            }
            const int fd = socket(addr.ai_family, SOCK_STREAM, 0);
            if (fd < 0)
            {
                // Capture errno before any other call can overwrite it.
                const int socketError = errno;
                Logger::Get().Fatal<SocketException>(
                    LOGID("TcpConnection"),
                    "Failed to create socket - Error ("
                    + std::to_string(socketError) + ") - "
                    + std::strerror(socketError)
                );
                return; // Only reached if Fatal() does not throw.
            }
            if (connect(fd, addr.ai_addr, addr.ai_addrlen) == 0)
            {
                m_fd = fd;
                return;
            }
            // Capture errno now: close() below may modify it even on success.
            const int connectError = errno;
            if (connectError == EALREADY)
            {
                m_fd = fd;
                return;
            }
            close(fd);
            if ((connectError == EINPROGRESS) || (connectError == EWOULDBLOCK)
                || (connectError == EINTR))
            {
                // Transient: back off briefly, then retry this same address.
                std::this_thread::sleep_for(std::chrono::seconds{1});
                continue;
            }
            // Hard failure for this address: remember it and try the next one.
            lastError = connectError;
            break;
        }
    }
    Logger::Get().Fatal<SocketException>(
        LOGID("TcpConnection"),
        "Failed to connect to " + host + ":"
        + std::to_string(port) + " - Error ("
        + std::to_string(lastError) + ") - " + std::strerror(lastError)
    );
}
/**
 * Closes the held socket, if any, and resets the descriptor to 0.
 *
 * A descriptor value of 0 is used as the "no socket held" marker, so calling
 * this repeatedly is harmless.
 */
void TcpConnection::Disconnect()
{
    if (m_fd != 0)
    {
        close(m_fd);
        m_fd = 0;
    }
}
/**
 * Returns the raw socket descriptor held by this connection.
 *
 * After Disconnect() this is 0.
 */
int TcpConnection::FileDescriptor() const
{
    const int descriptor = m_fd;
    return descriptor;
}
/**
 * Reads up to `amount` bytes from the socket.
 *
 * Each read() only asks for the bytes still missing, so the result never
 * exceeds `amount` and no bytes beyond it are consumed from the socket.
 *
 * When read() returns 0, reading stops and the partial result is returned if:
 *  - no I/O timeout is configured (m_ioTimeout == 0), or
 *  - more than m_ioTimeout has elapsed since the call started.
 * Otherwise the read is retried. A read() interrupted by a signal (EINTR) is
 * also retried.
 *
 * Any other read() error is reported via Logger::Get().Fatal<SocketException>()
 * (expected to throw — NOTE(review): confirm).
 *
 * @param amount Maximum number of bytes to read.
 * @return The bytes read; between 0 and `amount` bytes long.
 */
std::string TcpConnection::Read(size_t amount)
{
    // steady_clock so the timeout is immune to wall-clock adjustments.
    const auto startTime = std::chrono::steady_clock::now();
    // Read straight into the result; it is trimmed to the bytes received.
    std::string result(amount, '\0');
    std::size_t received{0};
    while (received < amount)
    {
        const auto bytesRead = read(m_fd, &result[received], amount - received);
        if (bytesRead > 0)
        {
            received += static_cast<std::size_t>(bytesRead);
        }
        else if (bytesRead == 0)
        {
            if (m_ioTimeout == std::chrono::milliseconds{0}) break;
            const auto timeDelta =
                std::chrono::duration_cast<std::chrono::milliseconds>(
                    std::chrono::steady_clock::now() - startTime
                );
            if (timeDelta > m_ioTimeout) break;
        }
        else
        {
            // Capture errno before building the message can overwrite it.
            const int readError = errno;
            if (readError == EINTR) continue; // Interrupted by a signal: retry.
            Logger::Get().Fatal<SocketException>(
                LOGID("TcpConnection"),
                "Error while reading from TCP socket ("
                + std::to_string(readError) + ") - "
                + std::strerror(readError)
            );
            break; // Only reached if Fatal() does not throw.
        }
    }
    result.resize(received);
    return result;
}
/**
 * Writes all of `data` to the socket, handling partial writes.
 *
 * An empty `data` is a no-op. A write() interrupted by a signal (EINTR) is
 * retried.
 *
 * When write() reports 0 bytes written:
 *  - if no I/O timeout is configured (m_ioTimeout == 0), or m_ioTimeout has
 *    elapsed since the call started, a "Timed out" SocketException is raised;
 *  - otherwise the write is retried.
 *
 * Any other write() error is reported via Logger::Get().Fatal<SocketException>()
 * (expected to throw — NOTE(review): confirm).
 *
 * @param data Bytes to send.
 */
void TcpConnection::Write(const std::string& data)
{
    // steady_clock so the timeout is immune to wall-clock adjustments.
    const auto startTime = std::chrono::steady_clock::now();
    // Track progress with an offset instead of erasing from a copy, which
    // would copy the input and make many partial writes quadratic.
    std::size_t offset{0};
    while (offset < data.size())
    {
        const std::size_t remaining = data.size() - offset;
        const auto bytesWritten = write(m_fd, data.data() + offset, remaining);
        if (bytesWritten > 0)
        {
            const auto written = static_cast<std::size_t>(bytesWritten);
            // Clamp defensively; write() should never report more than asked.
            offset += (written > remaining) ? remaining : written;
        }
        else if (bytesWritten == 0)
        {
            const auto timeDelta =
                std::chrono::duration_cast<std::chrono::milliseconds>(
                    std::chrono::steady_clock::now() - startTime
                );
            if ((timeDelta > m_ioTimeout)
                || m_ioTimeout == std::chrono::milliseconds{0})
            {
                Logger::Get().Fatal<SocketException>(
                    LOGID("TcpConnection"),
                    "Timed out writing to TCP socket."
                );
                return; // Only reached if Fatal() does not throw.
            }
        }
        else
        {
            // Capture errno before building the message can overwrite it.
            const int writeError = errno;
            if (writeError == EINTR) continue; // Interrupted by a signal: retry.
            Logger::Get().Fatal<SocketException>(
                LOGID("TcpConnection"),
                "Error writing to tcp socket (" + std::to_string(writeError)
                + ") - " + std::strerror(writeError)
            );
            return; // Only reached if Fatal() does not throw.
        }
    }
}
} // namespace usenetsearch