212 lines
6.3 KiB
C++
212 lines
6.3 KiB
C++
/*
|
|
Copyright© 2021 John Sennesael
|
|
|
|
UsenetSearch is Free software: you can redistribute it and/or modify
|
|
it under the terms of the GNU General Public License as published by
|
|
the Free Software Foundation, either version 3 of the License, or
|
|
(at your option) any later version.
|
|
|
|
UsenetSearch is distributed in the hope that it will be useful,
|
|
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
|
GNU General Public License for more details.
|
|
|
|
You should have received a copy of the GNU General Public License
|
|
along with UsenetSearch. If not, see <https://www.gnu.org/licenses/>.
|
|
*/
|
|
|
|
#include "usenetsearch/Logger.h"
|
|
#include "usenetsearch/TcpConnection.h"
|
|
|
|
#include "usenetsearch/Dns.h"
|
|
|
|
#include <netinet/in.h> // sockaddr_in
|
|
#include <sys/socket.h> // AF_INET etc...
|
|
#include <unistd.h> // close(), read(), write()
|
|
|
|
#include <cerrno>
|
|
#include <chrono>
|
|
#include <cstdint>
|
|
#include <cstring>
|
|
#include <string>
|
|
#include <thread>
|
|
|
|
namespace usenetsearch {
|
|
|
|
// Tears the connection down on destruction by delegating to Disconnect(),
// which closes the held socket descriptor (if any) so it is not leaked.
TcpConnection::~TcpConnection()
{
    this->Disconnect();
}
|
|
|
|
void TcpConnection::Connect(const std::string& host, std::uint16_t port)
|
|
{
|
|
int fd{0};
|
|
struct sockaddr_in serv_addr{};
|
|
serv_addr.sin_family = AF_INET;
|
|
|
|
// Resolve host (may resolve to multiple ip's)
|
|
const std::vector<struct addrinfo> addresses = DnsResolve(host, port);
|
|
if (addresses.empty())
|
|
{
|
|
Logger::Get().Fatal<DnsResolveException>(
|
|
LOGID("TcpConnection"),
|
|
"The provided host (" + host + ") did not resolve to an address."
|
|
);
|
|
}
|
|
|
|
// If we have an open socket close it.
|
|
Disconnect();
|
|
|
|
// Try each resolved IP in sequence until it works.
|
|
const auto startTime = std::chrono::system_clock::now();
|
|
for (auto& addr: addresses)
|
|
{
|
|
while(true)
|
|
{
|
|
const auto currentTime = std::chrono::system_clock::now();
|
|
const auto timeDelta =
|
|
std::chrono::duration_cast<std::chrono::milliseconds>(
|
|
currentTime - startTime
|
|
);
|
|
if (timeDelta > m_connectionTimeout)
|
|
{
|
|
Logger::Get().Fatal<SocketException>(
|
|
LOGID("TcpConnection"),
|
|
"Timed out while trying to connect to " + host + ":"
|
|
+ std::to_string(port) + "."
|
|
);
|
|
}
|
|
fd = socket(addr.ai_family, SOCK_STREAM, 0);
|
|
if (fd < 0)
|
|
{
|
|
Logger::Get().Fatal<SocketException>(
|
|
LOGID("TcpConnection"),
|
|
"Failed to create socket - Error (" + std::to_string(errno)
|
|
+ ") - " + std::strerror(errno)
|
|
);
|
|
}
|
|
|
|
if (connect(fd, addr.ai_addr, addr.ai_addrlen) == 0)
|
|
{
|
|
m_fd = fd;
|
|
return;
|
|
}
|
|
else
|
|
{
|
|
if ((errno == EINPROGRESS) || (errno == EWOULDBLOCK))
|
|
{
|
|
close(fd);
|
|
std::this_thread::sleep_for(std::chrono::seconds{1});
|
|
}
|
|
else if (errno == EALREADY)
|
|
{
|
|
m_fd = fd;
|
|
return;
|
|
}
|
|
else
|
|
{
|
|
close(fd);
|
|
Logger::Get().Fatal<SocketException>(
|
|
LOGID("TcpConnection"),
|
|
"Failed to connect to " + host + ":"
|
|
+ std::to_string(port) + " - Error ("
|
|
+ std::to_string(errno) + ") - " + strerror(errno)
|
|
);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
m_fd = fd;
|
|
}
|
|
|
|
void TcpConnection::Disconnect()
|
|
{
|
|
if (m_fd == 0) return;
|
|
close(m_fd);
|
|
m_fd = 0;
|
|
}
|
|
|
|
int TcpConnection::FileDescriptor() const
|
|
{
|
|
return m_fd;
|
|
}
|
|
|
|
std::string TcpConnection::Read(size_t amount)
|
|
{
|
|
const auto startTime = std::chrono::system_clock::now();
|
|
std::string result;
|
|
while(true)
|
|
{
|
|
std::string buffer;
|
|
buffer.resize(amount);
|
|
const auto bytesRead = read(m_fd, &buffer[0], buffer.size());
|
|
if (bytesRead == 0)
|
|
{
|
|
if (m_ioTimeout == std::chrono::milliseconds{0}) break;
|
|
const auto currentTime = std::chrono::system_clock::now();
|
|
const auto timeDelta =
|
|
std::chrono::duration_cast<std::chrono::milliseconds>(
|
|
currentTime - startTime
|
|
);
|
|
if (timeDelta > m_ioTimeout) break;
|
|
}
|
|
else if (bytesRead >= 0)
|
|
{
|
|
buffer.resize(bytesRead);
|
|
result += buffer;
|
|
if (result.size() == amount) break; // we're done here.
|
|
}
|
|
else
|
|
{
|
|
Logger::Get().Fatal<SocketException>(
|
|
LOGID("TcpConnection"),
|
|
"Error while reading from TCP socket (" + std::to_string(errno)
|
|
+ ") - " + std::strerror(errno)
|
|
);
|
|
}
|
|
}
|
|
return result;
|
|
}
|
|
|
|
void TcpConnection::Write(const std::string& data)
|
|
{
|
|
const auto startTime = std::chrono::system_clock::now();
|
|
std::string buffer(data);
|
|
while(true)
|
|
{
|
|
auto bytesWritten = write(m_fd, &buffer[0], buffer.size());
|
|
if (bytesWritten == 0)
|
|
{
|
|
const auto currentTime = std::chrono::system_clock::now();
|
|
const auto timeDelta =
|
|
std::chrono::duration_cast<std::chrono::milliseconds>(
|
|
currentTime - startTime
|
|
);
|
|
if ((timeDelta > m_ioTimeout)
|
|
|| m_ioTimeout == std::chrono::milliseconds{0})
|
|
{
|
|
Logger::Get().Fatal<SocketException>(
|
|
LOGID("TcpConnection"),
|
|
"Timed out writing to TCP socket."
|
|
);
|
|
}
|
|
}
|
|
else if (bytesWritten >= 0)
|
|
{
|
|
if (bytesWritten > buffer.size()) bytesWritten = buffer.size();
|
|
buffer.erase(0, bytesWritten);
|
|
if (buffer.empty()) return;
|
|
}
|
|
else
|
|
{
|
|
Logger::Get().Fatal<SocketException>(
|
|
LOGID("TcpConnection"),
|
|
"Error writing to tcp socket (" + std::to_string(errno)
|
|
+ ") - " + std::strerror(errno)
|
|
);
|
|
}
|
|
}
|
|
}
|
|
|
|
} // namespace usenetsearch
|