Initial commit

This commit is contained in:
John Sennesael 2021-09-19 16:05:16 -05:00
commit e1c1ba031e
15 changed files with 918 additions and 0 deletions

5
.gitignore vendored Normal file
View File

@ -0,0 +1,5 @@
[syntax=glob]
*.o
*.*~
*.sw*
build/*

39
CMakeLists.txt Normal file
View File

@ -0,0 +1,39 @@
cmake_minimum_required(VERSION 3.5)
set (CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} -O0 -ggdb -fno-omit-frame-pointer -fsanitize=address")
set (CMAKE_LINKER_FLAGS_DEBUG "${CMAKE_LINKER_FLAGS_DEBUG} -fno-omit-frame-pointer -fsanitize=address")
#set (CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} -O0 -ggdb -fno-omit-frame-pointer")
#set (CMAKE_LINKER_FLAGS_DEBUG "${CMAKE_LINKER_FLAGS_DEBUG} -fno-omit-frame-pointer")
if(NOT DEFINED CMAKE_BUILD_TYPE)
set(CMAKE_BUILD_TYPE "DEBUG" CACHE STRING "Type of build: Debug|Release")
endif()
project(usenetsearch CXX)
set(CMAKE_CXX_STANDARD 17)
list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake")
find_package(OpenSSL REQUIRED)
add_executable(usenetsearch
"src/Dns.cpp"
"src/Except.cpp"
"src/IoSocket.cpp"
"src/main.cpp"
"src/SSLConnection.cpp"
"src/TcpConnection.cpp"
"src/UsenetClient.cpp"
)
target_link_libraries(usenetsearch
${OPENSSL_LIBRARIES}
)
target_include_directories(usenetsearch
PRIVATE
include
PUBLIC
${OPENSSL_INCLUDE_DIR}
)

View File

@ -0,0 +1,27 @@
#pragma once
#include <chrono>
#include <string>
#include <vector>
#include <netdb.h> // struct addrinfo
#include "usenetsearch/Except.h"
namespace usenetsearch {
struct DnsResolveException: public UsenetSearchException
{
DnsResolveException(int errorCode, const std::string& message):
UsenetSearchException(errorCode, message){}
virtual ~DnsResolveException() = default;
};
std::vector<struct addrinfo> DnsResolve(
const std::string& host,
int port,
const std::chrono::milliseconds& timeOut = std::chrono::milliseconds{5000}
);
} // namespace usenetsearch

View File

@ -0,0 +1,22 @@
#pragma once
#include <stdexcept>
namespace usenetsearch {
class UsenetSearchException: public std::exception
{
int m_errorCode;
std::string m_message;
public:
UsenetSearchException(int errorCode, const std::string& message);
virtual ~UsenetSearchException() = default;
int Code() const;
virtual const char* what() const noexcept override;
};
} // namespace usenetsearch

View File

@ -0,0 +1,33 @@
#pragma once
#include <chrono>
#include <string>
namespace usenetsearch {
class IoSocket
{
public:
virtual ~IoSocket() = default;
std::chrono::milliseconds ConnectionTimeout() const;
void ConnectionTimeout(const std::chrono::milliseconds& timeOut);
std::chrono::milliseconds IoTimeout() const;
void IoTimeout(const std::chrono::milliseconds& timeOut);
virtual std::string Read(size_t amount) = 0;
std::string ReadUntil(std::string deliminator);
virtual void Write(const std::string& data) = 0;
protected:
std::chrono::milliseconds m_connectionTimeout{10000};
std::chrono::milliseconds m_ioTimeout{5000};
};
} // namespace usenetsearch

View File

@ -0,0 +1,41 @@
#pragma once
#include <memory>
#include <openssl/ssl.h>
#include "usenetsearch/IoSocket.h"
#include "usenetsearch/TcpConnection.h"
namespace usenetsearch {
struct SSLException: public UsenetSearchException
{
SSLException(int errorCode, const std::string& message):
UsenetSearchException(errorCode, message){}
virtual ~SSLException() = default;
};
class SSLConnection : public IoSocket
{
enum class SSLReturnState{ RETRY, SUCCESS };
std::chrono::milliseconds m_connectionTimeout{10000};
std::chrono::milliseconds m_ioTimeout{10000};
std::shared_ptr<SSL> m_ssl;
std::shared_ptr<SSL_CTX> m_sslContext;
std::unique_ptr<TcpConnection> m_tcpConnection;
SSLReturnState CheckSSLReturn(int ret);
public:
SSLConnection(std::unique_ptr<TcpConnection> connection);
void Connect();
void Disconnect();
std::string Read(size_t amount);
void Write(const std::string& data);
};
} // namespace usenetsearch

View File

@ -0,0 +1,37 @@
#pragma once
#include <chrono>
#include <string>
#include "usenetsearch/Except.h"
#include "usenetsearch/IoSocket.h"
namespace usenetsearch {
struct SocketException: public UsenetSearchException
{
SocketException(int errorCode, const std::string& message):
UsenetSearchException(errorCode, message){}
virtual ~SocketException() = default;
};
class TcpConnection: public IoSocket
{
int m_fd{0};
public:
virtual ~TcpConnection();
void Connect(const std::string& host, std::uint16_t port);
void Disconnect();
int FileDescriptor() const;
std::string Read(size_t amount);
void Write(const std::string& data);
};
} // namespace usenetsearch

View File

@ -0,0 +1,51 @@
#pragma once
#include <cstdint>
#include <codecvt>
#include <locale>
#include <memory>
#include <string>
#include "usenetsearch/SSLConnection.h"
#include "usenetsearch/TcpConnection.h"
namespace usenetsearch {
struct UsenetClientException: public UsenetSearchException
{
UsenetClientException(int errorCode, const std::string& message):
UsenetSearchException(errorCode, message){}
virtual ~UsenetClientException() = default;
};
struct NntpMessage
{
std::uint16_t code;
std::wstring message;
};
class UsenetClient
{
std::wstring_convert<std::codecvt_utf8_utf16<wchar_t>> m_conv;
std::unique_ptr<SSLConnection> m_ssl;
std::unique_ptr<TcpConnection> m_tcp;
bool m_useSSL{false};
bool IsError(const NntpMessage& msg) const;
NntpMessage ReadLine();
void Write(const std::wstring& message);
std::wstring ReadUntil(const std::wstring& deliminator);
public:
void Authenticate(const std::wstring& user, const std::wstring& password);
void Connect(
const std::string& host,
std::uint16_t port,
bool useSSL = false
);
};
} // namespace usenetsearch

77
src/Dns.cpp Normal file
View File

@ -0,0 +1,77 @@
#include <cerrno>
#include <thread>
#include "usenetsearch/Dns.h"
namespace usenetsearch {
std::vector<struct addrinfo> DnsResolve(
const std::string& host,
int port,
const std::chrono::milliseconds& timeOut)
{
std::vector<struct addrinfo> results;
struct addrinfo* firstResult = nullptr;
struct addrinfo hints{};
struct addrinfo* result = nullptr;
hints.ai_family = AF_UNSPEC; // ipv6 or ipv4
hints.ai_flags = AI_V4MAPPED | AI_ALL;
const auto start_time = std::chrono::system_clock::now();
while (true)
{
// Handle timeout.
const auto time = std::chrono::system_clock::now();
const auto deltaT =
std::chrono::duration_cast<std::chrono::milliseconds>(
time - start_time
);
if (deltaT > timeOut) break;
// Try to resolve.
const int getAddrInfoResult = getaddrinfo(
host.c_str(),
std::to_string(port).c_str(),
&hints,
&result
);
firstResult = result;
if (getAddrInfoResult == EAI_AGAIN)
{
std::this_thread::sleep_for(std::chrono::milliseconds{500});
continue;
}
if (getAddrInfoResult == 0)
{
break; // success
}
else
{
throw DnsResolveException(getAddrInfoResult,
"Could not resolve host " + host + ": - Error ("
+ std::to_string(getAddrInfoResult) + ") - "
+ gai_strerror(getAddrInfoResult)
);
}
}
if (result == nullptr)
{
throw DnsResolveException(ETIMEDOUT,
"Timed out trying to resolve host: " + host
);
}
while (result != nullptr)
{
results.emplace_back(*result);
result = result->ai_next;
}
if (firstResult != nullptr)
{
freeaddrinfo(firstResult);
}
return results;
}
} // namespace usenetsearch

21
src/Except.cpp Normal file
View File

@ -0,0 +1,21 @@
#include "usenetsearch/Except.h"
namespace usenetsearch {
UsenetSearchException::UsenetSearchException(int errorCode, const std::string& message)
{
m_errorCode = errorCode;
m_message = message;
}
int UsenetSearchException::Code() const
{
return m_errorCode;
}
const char* UsenetSearchException::what() const noexcept
{
return m_message.c_str();
}
} // namespace usenetsearch

46
src/IoSocket.cpp Normal file
View File

@ -0,0 +1,46 @@
#include <chrono>
#include <string>
#include "usenetsearch/IoSocket.h"
namespace usenetsearch {
std::chrono::milliseconds IoSocket::ConnectionTimeout() const
{
return m_connectionTimeout;
}
void IoSocket::ConnectionTimeout(const std::chrono::milliseconds& timeOut)
{
m_connectionTimeout = timeOut;
}
std::chrono::milliseconds IoSocket::IoTimeout() const
{
return m_ioTimeout;
}
void IoSocket::IoTimeout(const std::chrono::milliseconds& timeOut)
{
m_ioTimeout = timeOut;
}
std::string IoSocket::ReadUntil(std::string deliminator)
{
std::string result;
while(true)
{
std::string buffer = Read(1);
result += buffer;
if (result.length() >= deliminator.length())
{
if (result.substr(result.length() - deliminator.length())
== deliminator)
{
return result;
}
}
}
}
} // namespace usenetsearch

182
src/SSLConnection.cpp Normal file
View File

@ -0,0 +1,182 @@
#include <cerrno>
#include <cstring>
#include <memory>
#include <openssl/ssl.h>
#include <openssl/err.h>
#include "usenetsearch/SSLConnection.h"
namespace usenetsearch {
SSLConnection::SSLConnection(std::unique_ptr<TcpConnection> connection)
{
m_tcpConnection = std::move(connection);
}
SSLConnection::SSLReturnState SSLConnection::CheckSSLReturn(int ret)
{
int result = SSL_get_error(m_ssl.get(), ret);
if (result == SSL_ERROR_NONE) return SSLReturnState::SUCCESS;
if (result == SSL_ERROR_SYSCALL)
{
if (errno == 0) return SSLReturnState::SUCCESS;
throw SSLException(errno, std::string{"SSL error: "}
+ std::strerror(errno));
}
if (result == SSL_ERROR_SSL)
{
const auto errorCode = ERR_get_error();
const std::string errorString = ERR_error_string(
errorCode,
nullptr
);
throw SSLException(errorCode,
"SSL error: " + errorString);
}
return SSLReturnState::RETRY;
}
void SSLConnection::Connect()
{
if (m_tcpConnection == nullptr)
{
throw SSLException(EBADF,
"Null tcp connection when attempting ssl connect."
);
}
int fd = m_tcpConnection->FileDescriptor();
if (!fd)
{
throw SSLException(EBADF,
"Bad file descriptor (" + std::to_string(fd)
+ ") when attempting to ssl connect."
);
}
Disconnect();
m_sslContext = std::shared_ptr<SSL_CTX>(
SSL_CTX_new(TLS_client_method()),
[](auto p){ SSL_CTX_free(p); }
);
m_ssl = std::shared_ptr<SSL>(
SSL_new(m_sslContext.get()),
[](auto p){ SSL_free(p); }
);
SSL_set_fd(m_ssl.get(), fd);
const auto startTime = std::chrono::system_clock::now();
while(true)
{
const auto currentTime = std::chrono::system_clock::now();
const auto timeDelta =
std::chrono::duration_cast<std::chrono::milliseconds>(
currentTime - startTime
);
if (timeDelta > m_connectionTimeout)
{
throw SSLException(
ETIMEDOUT,
"Timed out while trying to establish SSL connection."
);
}
ERR_clear_error();
const int status = SSL_connect(m_ssl.get());
if (status == 1) break;
if (CheckSSLReturn(status) == SSLReturnState::SUCCESS) break;
}
}
void SSLConnection::Disconnect()
{
if (m_ssl != nullptr) m_ssl.reset();
if (m_sslContext != nullptr) m_sslContext.reset();
}
std::string SSLConnection::Read(size_t amount)
{
if (m_sslContext == nullptr)
{
throw SSLException(EBADF,
"Attempted to write over SSL socket without SSL context."
);
}
std::string result;
const auto startTime = std::chrono::system_clock::now();
while(true)
{
std::string buffer;
buffer.resize(amount - result.size());
const auto currentTime = std::chrono::system_clock::now();
const auto timeDelta =
std::chrono::duration_cast<std::chrono::milliseconds>(
currentTime - startTime
);
if (timeDelta > m_ioTimeout) return result;
ERR_clear_error();
int bytesRead = SSL_read(m_ssl.get(), &buffer[0], buffer.size());
if (bytesRead > 0)
{
buffer.resize(bytesRead);
result += buffer;
if (result.size() == amount) break; // done here.
}
else
{
CheckSSLReturn(bytesRead);
}
}
return result;
}
void SSLConnection::Write(const std::string& data)
{
if (m_sslContext == nullptr)
{
throw SSLException(EBADF,
"Attempted to write over SSL socket without SSL context."
);
}
std::string buffer(data);
const auto startTime = std::chrono::system_clock::now();
while(true)
{
const auto currentTime = std::chrono::system_clock::now();
const auto timeDelta =
std::chrono::duration_cast<std::chrono::milliseconds>(
currentTime - startTime
);
if (timeDelta > m_ioTimeout)
{
throw SSLException(
ETIMEDOUT,
"Timed out while trying to write to SSL connection."
);
}
ERR_clear_error();
int bytesWritten = SSL_write(
m_ssl.get(),
buffer.c_str(),
buffer.size()
);
if (bytesWritten == buffer.size())
{
return; // If we wrote the entire buffer, we're done.
}
else if (bytesWritten > 0)
{
// If we wrote a partial buffer, pop off what we wrote, try again.
if (bytesWritten > buffer.size()) bytesWritten = buffer.size();
buffer.erase(0, bytesWritten);
if (buffer.empty()) return;
}
else
{
// status < 0 indicates error.
CheckSSLReturn(bytesWritten);
}
}
}
} // namespace usenetsearch

177
src/TcpConnection.cpp Normal file
View File

@ -0,0 +1,177 @@
#include <cerrno>
#include <chrono>
#include <cstdint>
#include <cstring>
#include <string>
#include <thread>
#include <unistd.h> // close(), read(), write()
#include "usenetsearch/Dns.h"
#include "usenetsearch/TcpConnection.h"
namespace usenetsearch {
TcpConnection::~TcpConnection()
{
Disconnect();
}
void TcpConnection::Connect(const std::string& host, std::uint16_t port)
{
int fd{0};
struct sockaddr_in serv_addr{};
serv_addr.sin_family = AF_INET;
// Resolve host (may resolve to multiple ip's)
const std::vector<struct addrinfo> addresses = DnsResolve(host, port);
if (addresses.empty())
{
throw DnsResolveException(EDESTADDRREQ,
"The provided host (" + host + ") did not resolve to an address."
);
}
// If we have an open socket close it.
Disconnect();
// Try each resolved IP in sequence until it works.
const auto startTime = std::chrono::system_clock::now();
for (auto& addr: addresses)
{
while(true)
{
const auto currentTime = std::chrono::system_clock::now();
const auto timeDelta =
std::chrono::duration_cast<std::chrono::milliseconds>(
currentTime - startTime
);
if (timeDelta > m_connectionTimeout)
{
throw SocketException(
ETIMEDOUT,
"Timed out while trying to connect to " + host + ":"
+ std::to_string(port) + "."
);
}
fd = socket(addr.ai_family, SOCK_STREAM, 0);
if (fd < 0)
{
throw SocketException(errno,
"Failed to create socket - Error (" + std::to_string(errno)
+ ") - " + std::strerror(errno)
);
}
if (connect(fd, addr.ai_addr, addr.ai_addrlen) == 0)
{
m_fd = fd;
return;
}
else
{
if ((errno == EINPROGRESS) || (errno == EWOULDBLOCK))
{
close(fd);
std::this_thread::sleep_for(std::chrono::seconds{1});
}
else if (errno == EALREADY)
{
m_fd = fd;
return;
}
else
{
close(fd);
throw SocketException(errno,
"Failed to connect to " + host + ":"
+ std::to_string(port) + " - Error ("
+ std::to_string(errno) + ") - " + strerror(errno)
);
}
}
}
}
m_fd = fd;
}
void TcpConnection::Disconnect()
{
if (m_fd == 0) return;
close(m_fd);
m_fd = 0;
}
int TcpConnection::FileDescriptor() const
{
return m_fd;
}
std::string TcpConnection::Read(size_t amount)
{
const auto startTime = std::chrono::system_clock::now();
std::string result;
while(true)
{
std::string buffer;
buffer.resize(amount);
const auto currentTime = std::chrono::system_clock::now();
const auto timeDelta =
std::chrono::duration_cast<std::chrono::milliseconds>(
currentTime - startTime
);
if (timeDelta > m_ioTimeout) break;
const auto bytesRead = read(m_fd, &buffer[0], buffer.size());
if (bytesRead >= 0)
{
buffer.resize(bytesRead);
result += buffer;
if (result.size() == amount) break; // we're done here.
}
else
{
throw SocketException(errno,
"Error while reading from TCP socket (" + std::to_string(errno)
+ ") - " + std::strerror(errno)
);
}
}
return result;
}
void TcpConnection::Write(const std::string& data)
{
const auto startTime = std::chrono::system_clock::now();
std::string buffer(data);
while(true)
{
const auto currentTime = std::chrono::system_clock::now();
const auto timeDelta =
std::chrono::duration_cast<std::chrono::milliseconds>(
currentTime - startTime
);
if (timeDelta > m_ioTimeout)
{
throw SocketException(ETIMEDOUT,
"Timed out writing to TCP socket."
);
}
auto bytesWritten = write(m_fd, &buffer[0], buffer.size());
if (bytesWritten >= 0)
{
if (bytesWritten > buffer.size()) bytesWritten = buffer.size();
buffer.erase(0, bytesWritten);
if (buffer.empty()) return;
}
else
{
throw SocketException(errno,
"Error writing to tcp socket (" + std::to_string(errno)
+ ") - " + std::strerror(errno)
);
}
}
}
} // namespace usenetsearch

130
src/UsenetClient.cpp Normal file
View File

@ -0,0 +1,130 @@
#include <codecvt>
#include <locale>
#include <memory>
#include <string>
#include "usenetsearch/Except.h"
#include "usenetsearch/UsenetClient.h"
namespace usenetsearch {
void UsenetClient::Authenticate(
const std::wstring& user,
const std::wstring& password)
{
// Send user name
Write(L"AUTHINFO USER " + user + L"\r\n");
auto response = ReadLine();
if (IsError(response))
{
throw UsenetClientException(
response.code,
"Error authenticating with NNTP server: "
+ m_conv.to_bytes(response.message)
);
}
// Send password
Write(L"AUTHINFO PASS " + password + L"\r\n");
response = ReadLine();
if (IsError(response))
{
throw UsenetClientException(
response.code,
"Error authenticating with NNTP server: "
+ m_conv.to_bytes(response.message)
);
}
}
void UsenetClient::Connect(
const std::string& host,
std::uint16_t port,
bool useSSL)
{
// Establish connection.
m_useSSL = useSSL;
try
{
m_tcp = std::make_unique<usenetsearch::TcpConnection>();
m_tcp->Connect(host, port);
if (useSSL)
{
m_ssl = std::make_unique<SSLConnection>(std::move(m_tcp));
m_ssl->Connect();
}
}
catch (const UsenetSearchException& e)
{
throw UsenetClientException(
e.Code(),
"Error while trying to connect to host: " + host + ":"
+ std::to_string(port) + " - " + e.what()
);
}
// Read server banner.
const auto serverHello = ReadLine();
if (IsError(serverHello))
{
throw UsenetClientException(
serverHello.code,
"Error received from NNTP server: "
+ m_conv.to_bytes(serverHello.message)
);
}
}
bool UsenetClient::IsError(const NntpMessage& msg) const
{
if (msg.code >= 400) return true;
return false;
}
NntpMessage UsenetClient::ReadLine()
{
NntpMessage result{};
std::wstring line;
line = ReadUntil(L"\r\n");
if (line.length() < 4)
{
throw UsenetSearchException(EPROTONOSUPPORT,
"NNTP protocol error - invalid response from server: "
+ m_conv.to_bytes(line));
}
std::wstring codeStr = line.substr(0, 3);
result.code = std::stoi(codeStr);
result.message = line.substr(4, line.length());
return result;
}
std::wstring UsenetClient::ReadUntil(const std::wstring& deliminator)
{
std::wstring result;
const std::string deliminatorStr = m_conv.to_bytes(deliminator);
std::string resultStr;
if (m_useSSL)
{
resultStr = m_ssl->ReadUntil(deliminatorStr);
}
else
{
resultStr = m_tcp->ReadUntil(deliminatorStr);
}
result = m_conv.from_bytes(resultStr);
return result;
}
void UsenetClient::Write(const std::wstring& message)
{
const std::string toSend = m_conv.to_bytes(message);
if (m_useSSL)
{
m_ssl->Write(toSend);
}
else
{
m_tcp->Write(toSend);
}
}
} // namespace usenetsearch

30
src/main.cpp Normal file
View File

@ -0,0 +1,30 @@
#include <iostream>
#include <memory>
#include "usenetsearch/UsenetClient.h"
int main(int argc, char* argv[])
{
(void) argc;
(void) argv;
std::string host = "news.newshosting.com";
std::uint16_t port = 443;
bool useSSL = true;
usenetsearch::UsenetClient client;
try
{
client.Connect(host, port, useSSL);
client.Authenticate(L"xxxxxxx", L"yyyyy");
}
catch (const std::exception& e)
{
std::cerr << e.what() << std::endl;;
return 1;
}
std::cout << "success." << std::endl;
return 0;
}