UsenetSearch/src/Serialize.cpp

587 lines
15 KiB
C++

/*
UsenetSearch is Free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
UsenetSearch is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with UsenetSearch. If not, see <https://www.gnu.org/licenses/>.
*/
#include "usenetsearch/Serialize.h"
#include "usenetsearch/Application.h"
#include "usenetsearch/Database.h"
#include "usenetsearch/Logger.h"
#include "usenetsearch/ScopeExit.h"
#include "usenetsearch/UsenetClient.h"
#include <fcntl.h>
#include <sys/file.h>
#include <sys/types.h>
#include <sys/stat.h>
#include <unistd.h>
#include <cerrno>
#include <chrono>
#include <codecvt>
#include <cstdint>
#include <cstring>
#include <locale>
#include <fstream>
#include <thread>
#include <vector>
namespace usenetsearch {
// Class implementation --------------------------------------------------------
// Construct a closed file wrapper. No descriptor is acquired here; Open()
// does that. `lockOnOpen` controls whether Open()/Close() also take and
// release a whole-file flock().
SerializableFile::SerializableFile(bool lockOnOpen)
    : m_lockOnOpen(lockOnOpen)
{
    // Intentionally empty.
}
// Destroy the wrapper, closing (and, if configured, unlocking) the file.
// Destructors are implicitly noexcept, so any exception raised by Close()
// (e.g. a failed unlock reported through Logger::Fatal) must be contained
// here; letting it escape would call std::terminate.
SerializableFile::~SerializableFile()
{
    try
    {
        Close();
    }
    catch (...)
    {
        // NOTE(review): presumably Logger::Fatal already logged the failure
        // before throwing; nothing more can be done during destruction.
    }
}
// Acquire an exclusive advisory flock() on the whole open file, retrying
// every 50ms while another holder owns it. Aborts (via Logger::Fatal) if the
// file is not open, if the application is asked to stop while waiting, or on
// any other flock() error. Sets m_locked on success.
void SerializableFile::FileLock()
{
    if (!m_fd)
    {
        Logger::Get().Fatal<FileIOException>(
            LOGID("SerializableFile"),
            "Attempt to write to a file (" + m_fileName.string()
            + " ) that isn't open."
        );
    }
    // flock() signals failure by returning -1 and setting errno; it never
    // returns EWOULDBLOCK/EINTR itself. LOCK_NB makes a contended lock fail
    // with EWOULDBLOCK instead of blocking, so ShouldStop() can be polled
    // between attempts.
    for (;;)
    {
        if (Application::Get().ShouldStop())
        {
            Logger::Get().Fatal<FileIOException>(
                LOGID("SerializableFile"),
                "Interrupt while trying to establish a lock on file ("
                + m_fileName.string() + ")."
            );
        }
        if (flock(m_fd, LOCK_EX | LOCK_NB) == 0) break;
        // Capture errno before any other call can overwrite it.
        const int err = errno;
        if ((err != EWOULDBLOCK) && (err != EINTR))
        {
            Logger::Get().Fatal<FileIOException>(
                LOGID("SerializableFile"),
                "Error trying to lock file: " + m_fileName.string()
                + " : " + std::strerror(err)
            );
        }
        std::this_thread::sleep_for(std::chrono::milliseconds(50));
    }
    m_locked = true;
}
// Release the whole-file flock() held on the open descriptor and clear
// m_locked. Aborts via Logger::Fatal if the file is not open or if
// flock(LOCK_UN) reports an error.
void SerializableFile::FileUnlock()
{
    // A zero descriptor is this class's "not open" sentinel.
    if (m_fd == 0)
    {
        Logger::Get().Fatal<FileIOException>(
            LOGID("SerializableFile"),
            "Attempt to write to a file (" + m_fileName.string()
            + " ) that isn't open."
        );
    }
    if (flock(m_fd, LOCK_UN) != 0)
    {
        Logger::Get().Fatal<FileIOException>(
            LOGID("SerializableFile"),
            "Error trying to unlock file: " + m_fileName.string()
            + " : " + std::strerror(errno)
        );
    }
    m_locked = false;
}
// Report whether a descriptor is currently held (zero means "closed").
bool SerializableFile::IsOpen() const
{
    const bool closed = (m_fd == 0);
    return !closed;
}
void SerializableFile::Close()
{
if (m_fd)
{
if (m_lockOnOpen) FileUnlock();
close(m_fd);
m_fd = 0;
}
}
void SerializableFile::Open(const std::string& fileName)
{
#ifdef __linux__
const int flags{O_RDWR|O_NOATIME|O_CREAT};
#else
const int flags{O_RDWR|O_CREAT};
#endif
int ret = open(fileName.c_str(), flags, 0644);
if (ret < 0)
{
Logger::Get().Fatal<FileIOException>(
LOGID("SerializableFile"),
"Could not open file: " + fileName
+ " : " + std::string{std::strerror(errno)}
);
}
m_fd = ret;
if (m_lockOnOpen) FileLock();
}
// Take a lockf(F_LOCK) record lock on [offset, offset + size) of the open
// file, blocking until it is granted. The cursor position is preserved.
// Per POSIX lockf(), a size of 0 covers from `offset` to end of file and
// beyond. Aborts via Logger::Fatal if the file is not open or locking fails.
void SerializableFile::RangeLock(size_t offset, size_t size) const
{
    if (m_fd == 0)
    {
        Logger::Get().Fatal<FileIOException>(
            LOGID("SerializableFile"),
            "Attempt to write to a file (" + m_fileName.string()
            + " ) that isn't open."
        );
    }
    // lockf() works relative to the current offset: jump to the region's
    // start, lock, then return to where the caller was.
    const size_t savedPos = Tell();
    Seek(offset, std::ios_base::beg);
    const bool lockFailed = (lockf(m_fd, F_LOCK, size) == -1);
    if (lockFailed)
    {
        Logger::Get().Fatal<FileIOException>(
            LOGID("SerializableFile"),
            "Could not acquire a write lock on file: " + m_fileName.string()
            + " : " + std::strerror(errno)
        );
    }
    Seek(savedPos, std::ios_base::beg);
}
// Release a lockf() record lock on [offset, offset + size) previously taken
// with RangeLock(). The cursor position is preserved. Aborts via
// Logger::Fatal if the file is not open or unlocking fails.
void SerializableFile::RangeUnlock(size_t offset, size_t size) const
{
    if (m_fd == 0)
    {
        Logger::Get().Fatal<FileIOException>(
            LOGID("SerializableFile"),
            "Attempt to write to a file (" + m_fileName.string()
            + " ) that isn't open."
        );
    }
    // Same offset dance as RangeLock(): lockf() is cursor-relative.
    const size_t savedPos = Tell();
    Seek(offset, std::ios_base::beg);
    const bool unlockFailed = (lockf(m_fd, F_ULOCK, size) == -1);
    if (unlockFailed)
    {
        Logger::Get().Fatal<FileIOException>(
            LOGID("SerializableFile"),
            "Could not release a write lock on file: " + m_fileName.string()
            + " : " + std::strerror(errno)
        );
    }
    Seek(savedPos, std::ios_base::beg);
}
// Read exactly `size` bytes from the current cursor position and return
// them as a string, holding a lockf() range lock over the bytes while
// reading. Short reads are continued until the request is satisfied.
// Aborts via Logger::Fatal if the file is not open, on a read error, or if
// end of file is reached before `size` bytes were read.
std::string SerializableFile::ReadStr(size_t size) const
{
    if (!m_fd)
    {
        Logger::Get().Fatal<FileIOException>(
            LOGID("SerializableFile"),
            "Attempt to write to a file (" + m_fileName.string()
            + " ) that isn't open."
        );
    }
    const auto startPos = Tell();
    RangeLock(startPos, size);
    ScopeExit unlock([&](){
        RangeUnlock(startPos, size);
    });
    std::string result(size, '\0');
    size_t bytesRead{0};
    while (bytesRead < size)
    {
        // Continue after the bytes already received; a short read must not
        // overwrite the front of the buffer or over-read past `size`.
        const auto readNow = read(m_fd, &result[bytesRead], size - bytesRead);
        if (readNow == -1)
        {
            const int err = errno;
            // Interrupted before any data arrived: simply retry.
            if (err == EINTR) continue;
            Logger::Get().Fatal<FileIOException>(
                LOGID("SerializableFile"),
                "Error while reading from file: " + m_fileName.string()
                + " : " + std::strerror(err)
            );
        }
        if (readNow == 0)
        {
            // EOF before the requested length: the data is truncated. Without
            // this check the loop would spin forever.
            Logger::Get().Fatal<FileIOException>(
                LOGID("SerializableFile"),
                "Unexpected end of file while reading from file: "
                + m_fileName.string()
            );
        }
        bytesRead += static_cast<size_t>(readNow);
    }
    return result;
}
// Read one byte at the cursor and return it as an unsigned 8-bit value.
std::uint8_t SerializableFile::ReadInt8() const
{
    const std::string byte = ReadStr(sizeof(std::uint8_t));
    return static_cast<std::uint8_t>(byte[0]);
}
// Read 4 bytes at the cursor and reinterpret them, in host byte order, as an
// unsigned 32-bit integer (the inverse of Write(std::uint32_t)).
std::uint32_t SerializableFile::ReadInt32() const
{
    const std::string str = ReadStr(sizeof(std::uint32_t));
    // Copy the bytes out instead of dereferencing a reinterpret_cast'ed
    // char pointer: that violates strict aliasing and may be misaligned.
    std::uint32_t result{0};
    std::memcpy(&result, str.data(), sizeof(result));
    return result;
}
// Read 8 bytes at the cursor and reinterpret them, in host byte order, as an
// unsigned 64-bit integer (the inverse of Write(std::uint64_t)).
std::uint64_t SerializableFile::ReadInt64() const
{
    const std::string str = ReadStr(sizeof(std::uint64_t));
    // memcpy is the well-defined way to type-pun raw bytes; dereferencing a
    // reinterpret_cast'ed pointer is UB (aliasing) and may be misaligned.
    std::uint64_t result{0};
    std::memcpy(&result, str.data(), sizeof(result));
    return result;
}
// Move the file cursor with lseek(). `direction` maps std::ios_base::beg to
// SEEK_SET and std::ios_base::end to SEEK_END; anything else is treated as
// relative to the current position (SEEK_CUR). Aborts via Logger::Fatal if
// the file is not open or lseek() fails.
void SerializableFile::Seek(
    size_t offset,
    std::ios_base::seekdir direction) const
{
    if (m_fd == 0)
    {
        Logger::Get().Fatal<FileIOException>(
            LOGID("SerializableFile"),
            "Attempt to write to a file (" + m_fileName.string()
            + " ) that isn't open."
        );
    }
    int whence = SEEK_CUR;
    if (direction == std::ios_base::beg)
    {
        whence = SEEK_SET;
    }
    else if (direction == std::ios_base::end)
    {
        whence = SEEK_END;
    }
    if (lseek(m_fd, offset, whence) == -1)
    {
        Logger::Get().Fatal<FileIOException>(
            LOGID("SerializableFile"),
            "Could not set cursor position in file: " + m_fileName.string()
            + " : " + std::strerror(errno)
        );
    }
}
// Return the file length in bytes by seeking to the end and reading the
// cursor, then restoring the caller's cursor position.
size_t SerializableFile::Size() const
{
    const auto callerPos = Tell();
    Seek(0, std::ios_base::end);
    const auto lengthInBytes = Tell();
    // Put the cursor back where the caller left it.
    Seek(callerPos, std::ios_base::beg);
    return lengthInBytes;
}
// Return the current cursor offset, obtained via lseek(fd, 0, SEEK_CUR).
// Aborts via Logger::Fatal if the file is not open or lseek() fails.
size_t SerializableFile::Tell() const
{
    if (m_fd == 0)
    {
        Logger::Get().Fatal<FileIOException>(
            LOGID("SerializableFile"),
            "Attempt to write to a file (" + m_fileName.string()
            + " ) that isn't open."
        );
    }
    // A zero-length relative seek reports the position without moving it.
    const auto position = lseek(m_fd, 0, SEEK_CUR);
    if (position == -1)
    {
        Logger::Get().Fatal<FileIOException>(
            LOGID("SerializableFile"),
            "Could not get cursor position in file: " + m_fileName.string()
            + " : " + std::strerror(errno)
        );
    }
    return position;
}
void SerializableFile::Write(std::uint8_t val) const
{
Write(reinterpret_cast<const char*>(&val), sizeof(std::uint8_t));
}
void SerializableFile::Write(std::uint64_t val) const
{
Write(reinterpret_cast<const char*>(&val), sizeof(std::uint64_t));
}
void SerializableFile::Write(std::uint32_t val) const
{
Write(reinterpret_cast<const char*>(&val), sizeof(std::uint32_t));
}
void SerializableFile::Write(const std::string& str) const
{
Write(str.c_str(), str.size());
}
// Write `size` bytes starting at `bytes` to the cursor position, looping
// until every byte has been accepted by write(). Aborts via Logger::Fatal if
// the file is not open or write() fails (EINTR is retried).
void SerializableFile::Write(const char* bytes, size_t size) const
{
    if (!m_fd)
    {
        Logger::Get().Fatal<FileIOException>(
            LOGID("SerializableFile"),
            "Attempt to write to a file (" + m_fileName.string()
            + " ) that isn't open."
        );
    }
    size_t written{0};
    while (written < size)
    {
        // Resume after what was already written; re-sending from `bytes`
        // on a partial write would duplicate data in the file.
        const auto writtenNow = write(
            m_fd,
            reinterpret_cast<const void*>(bytes + written),
            size - written
        );
        if (writtenNow > 0)
        {
            written += static_cast<size_t>(writtenNow);
            continue;
        }
        const int err = errno;
        // Interrupted by a signal before anything was written: retry.
        if ((writtenNow == -1) && (err == EINTR)) continue;
        // Any other error, or a zero-byte write (which would otherwise spin
        // forever), is fatal.
        Logger::Get().Fatal<FileIOException>(
            LOGID("SerializableFile"),
            "Failure while writing to file: " + m_fileName.string()
            + " : " + std::strerror(err)
        );
    }
}
// Serialization of primitive types --------------------------------------------
// Write one raw byte.
SerializableFile& operator<<(SerializableFile& out, const std::uint8_t& obj)
{
    out.Write(std::uint8_t{obj});
    return out;
}
// Read one raw byte.
SerializableFile& operator>>(SerializableFile& in, std::uint8_t& obj)
{
    const std::uint8_t value = in.ReadInt8();
    obj = value;
    return in;
}
// Write a 32-bit unsigned integer in host byte order.
SerializableFile& operator<<(SerializableFile& out, const std::uint32_t& obj)
{
    out.Write(std::uint32_t{obj});
    return out;
}
// Read a 32-bit unsigned integer in host byte order.
SerializableFile& operator>>(SerializableFile& in, std::uint32_t& obj)
{
    const std::uint32_t value = in.ReadInt32();
    obj = value;
    return in;
}
// Write a 64-bit unsigned integer in host byte order.
SerializableFile& operator<<(SerializableFile& out, const std::uint64_t& obj)
{
    out.Write(std::uint64_t{obj});
    return out;
}
// Read a 64-bit unsigned integer in host byte order.
SerializableFile& operator>>(SerializableFile& in, std::uint64_t& obj)
{
    const std::uint64_t value = in.ReadInt64();
    obj = value;
    return in;
}
// Serialization of stl classes ------------------------------------------------
// Serialize a string as a 64-bit byte count followed by the raw bytes.
SerializableFile& operator<<(SerializableFile& out, const std::string& str)
{
    // The length prefix is always 64 bits, matching the ReadInt64() in
    // operator>>. The explicit cast also keeps the Write() call unambiguous
    // on platforms where size_t is a different type than std::uint64_t.
    out.Write(static_cast<std::uint64_t>(str.size()));
    out.Write(str);
    return out;
}
// Deserialize a string written by operator<<: a 64-bit byte count followed
// by that many raw bytes. Replaces the previous contents of `str`.
SerializableFile& operator>>(SerializableFile& in, std::string& str)
{
    const auto length = in.ReadInt64();
    str = in.ReadStr(length);
    return in;
}
// Serialize a wide string as UTF-8: a 64-bit byte count followed by the
// encoded bytes.
// NOTE(review): codecvt_utf8_utf16 treats wchar_t values as UTF-16 code
// units; on platforms with 32-bit wchar_t, characters outside the BMP may not
// round-trip. Confirm this matches how wstrings are produced elsewhere.
SerializableFile& operator<<(SerializableFile& out, const std::wstring& str)
{
    std::wstring_convert<std::codecvt_utf8_utf16<wchar_t>> conv;
    const std::string narrowString = conv.to_bytes(str);
    // Explicit 64-bit length prefix: matches ReadInt64() in operator>> and
    // avoids an ambiguous Write() where size_t != std::uint64_t.
    out.Write(static_cast<std::uint64_t>(narrowString.size()));
    out.Write(narrowString);
    return out;
}
// Deserialize a wide string written by operator<<: read the 64-bit byte
// count and the UTF-8 bytes, then widen them. Replaces `str`.
SerializableFile& operator>>(SerializableFile& in, std::wstring& str)
{
    const std::uint64_t byteCount = in.ReadInt64();
    const std::string utf8Bytes = in.ReadStr(byteCount);
    std::wstring_convert<std::codecvt_utf8_utf16<wchar_t>> converter;
    str = converter.from_bytes(utf8Bytes);
    return in;
}
// Serialize a vector of 64-bit values as a 64-bit element count followed by
// each element.
SerializableFile& operator<<(
    SerializableFile& out, const std::vector<std::uint64_t>& arr)
{
    // Cast so the count always goes through the uint64_t overload, matching
    // the std::uint64_t read in operator>>; size_t alone may be ambiguous.
    out << static_cast<std::uint64_t>(arr.size());
    for (const auto value: arr) out << value;
    return out;
}
// Deserialize a vector written by operator<<: a 64-bit element count, then
// that many 64-bit values. Elements are APPENDED to `arr`; existing contents
// are not cleared.
SerializableFile& operator>>(
    SerializableFile& in, std::vector<std::uint64_t>& arr)
{
    std::uint64_t remaining{0};
    in >> remaining;
    while (remaining-- > 0)
    {
        std::uint64_t element{0};
        in >> element;
        arr.push_back(element);
    }
    return in;
}
// Serialization of usenetsearch classes ---------------------------------------
// Serialize an ArticleEntry framed by control bytes:
// SOH(1) STX(2) <hash bytes> <newsgroupID> <articleID> ETX(3) EOT(4).
SerializableFile& operator<<(
    SerializableFile& out, const ArticleEntry& obj)
{
    // Frame header.
    out.Write(std::uint8_t{1});
    out.Write(std::uint8_t{2});
    // Payload: every element of the hash as a byte, then the identifiers.
    for (const std::uint8_t hashByte: obj.hash) out.Write(hashByte);
    out << obj.newsgroupID;
    out << obj.articleID;
    // Frame footer.
    out.Write(std::uint8_t{3});
    out.Write(std::uint8_t{4});
    return out;
}
// Deserialize an ArticleEntry written by operator<<. Validates the SOH/STX
// header and ETX/EOT footer bytes, aborting via Logger::Fatal with a
// SerializeException on mismatch. Reads exactly 16 hash elements.
SerializableFile& operator>>(SerializableFile& in, ArticleEntry& obj)
{
    std::uint8_t soh{};
    std::uint8_t stx{};
    in >> soh >> stx;
    const bool headerOk = (soh == 1) && (stx == 2);
    if (!headerOk)
    {
        Logger::Get().Fatal<SerializeException>(
            LOGID("SerializableFile"),
            "Bad magic number in entry header."
        );
    }
    // NOTE(review): the writer emits every element of obj.hash while this
    // reads a fixed 16; the two only agree if hash has exactly 16 entries.
    for (std::size_t idx = 0; idx < 16; ++idx)
    {
        in >> obj.hash[idx];
    }
    in >> obj.newsgroupID;
    in >> obj.articleID;
    std::uint8_t etx{};
    std::uint8_t eot{};
    in >> etx >> eot;
    const bool footerOk = (etx == 3) && (eot == 4);
    if (!footerOk)
    {
        Logger::Get().Fatal<SerializeException>(
            LOGID("SerializableFile"),
            "Bad magic number in entry footer."
        );
    }
    return in;
}
// Serialize an NntpHeader as its article ID followed by its subject.
SerializableFile& operator<<(SerializableFile& out, const NntpHeader& obj)
{
    // Use the operator<< overload set, exactly mirroring `in >> articleID`
    // in operator>> and the other serializers here, so both directions
    // always dispatch on the field's actual type.
    out << obj.articleID;
    out << obj.subject;
    return out;
}
// Deserialize an NntpHeader; field order mirrors operator<<: article ID
// first, then the subject.
SerializableFile& operator>>(SerializableFile& in, NntpHeader& obj)
{
    SerializableFile& stream = in;
    stream >> obj.articleID;
    stream >> obj.subject;
    return stream;
}
// Serialize an NntpListEntry framed by control bytes:
// SOH(1) STX(2) id lastIndexedArticle count high low name status ETX(3) EOT(4).
SerializableFile& operator<<(SerializableFile& out, const NntpListEntry& obj)
{
    // Frame header: SOH, STX.
    out.Write(std::uint8_t{1});
    out.Write(std::uint8_t{2});
    // Payload fields, in the order operator>> expects them.
    out << obj.id;
    out << obj.lastIndexedArticle;
    out << obj.count;
    out << obj.high;
    out << obj.low;
    out << obj.name;
    out << obj.status;
    // Frame footer: ETX, EOT.
    out.Write(std::uint8_t{3});
    out.Write(std::uint8_t{4});
    return out;
}
// Deserialize an NntpListEntry written by operator<<. Validates the SOH/STX
// header and ETX/EOT footer bytes, aborting via Logger::Fatal with a
// SerializeException on mismatch.
SerializableFile& operator>>(SerializableFile& in, NntpListEntry& obj)
{
    std::uint8_t soh{};
    std::uint8_t stx{};
    in >> soh >> stx;
    if (!((soh == 1) && (stx == 2)))
    {
        Logger::Get().Fatal<SerializeException>(
            LOGID("SerializableFile"),
            "Bad magic number in NNTP entry header."
        );
    }
    in >> obj.id;
    in >> obj.lastIndexedArticle;
    in >> obj.count;
    in >> obj.high;
    in >> obj.low;
    in >> obj.name;
    in >> obj.status;
    std::uint8_t etx{};
    std::uint8_t eot{};
    in >> etx >> eot;
    if (!((etx == 3) && (eot == 4)))
    {
        Logger::Get().Fatal<SerializeException>(
            LOGID("SerializableFile"),
            "Bad magic number in NNTP entry footer."
        );
    }
    return in;
}
} // namespace usenetsearch