allow supplying encoding to StatefulEncoder and StatefulDecoder (and tests!)

This commit is contained in:
jesopo 2020-03-11 17:38:30 +00:00
parent 09087d63d6
commit 86fad88f92
5 changed files with 35 additions and 21 deletions

View File

@ -2,7 +2,8 @@ import typing
from .protocol import Line, tokenise
class StatefulDecoder(object):
def __init__(self, fallback: str="iso-8859"):
def __init__(self, encoding: str="utf8", fallback: str="iso-8859"):
self._encoding = encoding
self._fallback = fallback
self.clear()
@ -23,13 +24,14 @@ class StatefulDecoder(object):
decode_lines: typing.List[str] = []
for line in lines:
try:
decode_lines.append(line.decode("utf8"))
decode_lines.append(line.decode(self._encoding))
except UnicodeDecodeError as e:
decode_lines.append(line.decode(self._fallback))
return [tokenise(l) for l in decode_lines]
class StatefulEncoder(object):
def __init__(self):
def __init__(self, encoding: str="utf8"):
self._encoding = encoding
self.clear()
def clear(self):
@ -40,7 +42,7 @@ class StatefulEncoder(object):
return self._buffer
def push(self, line: Line):
self._buffer += f"{line.format()}\r\n".encode("utf8")
self._buffer += f"{line.format()}\r\n".encode(self._encoding)
self._buffered_lines.append(line)
def pop(self, byte_count: int):

View File

@ -1,7 +1,7 @@
import unittest
import irctokens
class TestTags(unittest.TestCase):
class FormatTestTags(unittest.TestCase):
def test(self):
line = irctokens.format("PRIVMSG", ["#channel", "hello"],
tags={"id": "\\" + " " + ";" + "\r\n"})
@ -21,18 +21,18 @@ class TestTags(unittest.TestCase):
tags={"a": ""})
self.assertEqual(line, "@a PRIVMSG #channel hello")
class TestSource(unittest.TestCase):
class FormatTestSource(unittest.TestCase):
def test(self):
line = irctokens.format("PRIVMSG", ["#channel", "hello"],
source="nick!user@host")
self.assertEqual(line, ":nick!user@host PRIVMSG #channel hello")
class TestCommand(unittest.TestCase):
class FormatTestCommand(unittest.TestCase):
def test_lowercase(self):
line = irctokens.format("privmsg")
self.assertEqual(line, "PRIVMSG")
class TestTrailing(unittest.TestCase):
class FormatTestTrailing(unittest.TestCase):
def test_space(self):
line = irctokens.format("PRIVMSG", ["#channel", "hello world"])
self.assertEqual(line, "PRIVMSG #channel :hello world")

View File

@ -1,7 +1,7 @@
import unittest
import irctokens
class TestPartial(unittest.TestCase):
class DecodeTestPartial(unittest.TestCase):
def test(self):
d = irctokens.StatefulDecoder()
lines = d.push(b"PRIVMSG ")
@ -12,7 +12,7 @@ class TestPartial(unittest.TestCase):
line = irctokens.tokenise("PRIVMSG #channel hello")
self.assertEqual(lines, [line])
class TestMultiple(unittest.TestCase):
class DecodeTestMultiple(unittest.TestCase):
def test(self):
d = irctokens.StatefulDecoder()
lines = d.push(b"PRIVMSG #channel1 hello\r\n"
@ -24,15 +24,20 @@ class TestMultiple(unittest.TestCase):
self.assertEqual(lines[0], line1)
self.assertEqual(lines[1], line2)
class TestFallback(unittest.TestCase):
class DecodeTestEncoding(unittest.TestCase):
def test(self):
d = irctokens.StatefulDecoder(encoding="iso-8859-2")
lines = d.push("PRIVMSG #channel :hello Č\r\n".encode("iso-8859-2"))
line = irctokens.tokenise("PRIVMSG #channel :hello Č")
self.assertEqual(lines[0], line)
def test_fallback(self):
d = irctokens.StatefulDecoder(fallback="latin-1")
lines = d.push("PRIVMSG #channel hélló\r\n".encode("latin-1"))
self.assertEqual(len(lines), 1)
line = irctokens.tokenise("PRIVMSG #channel hélló")
self.assertEqual(lines[0], line)
class TestEmpty(unittest.TestCase):
class DecodeTestEmpty(unittest.TestCase):
def test_immediate(self):
d = irctokens.StatefulDecoder()
lines = d.push(b"")
@ -44,7 +49,7 @@ class TestEmpty(unittest.TestCase):
lines = d.push(b"")
self.assertIsNone(lines)
class TestClear(unittest.TestCase):
class DecodeTestClear(unittest.TestCase):
def test(self):
d = irctokens.StatefulDecoder()
d.push(b"PRIVMSG ")

View File

@ -1,14 +1,14 @@
import unittest
import irctokens
class TestPush(unittest.TestCase):
class EncodeTestPush(unittest.TestCase):
def test(self):
e = irctokens.StatefulEncoder()
line = irctokens.tokenise("PRIVMSG #channel hello")
e.push(line)
self.assertEqual(e.pending(), b"PRIVMSG #channel hello\r\n")
class TestPop(unittest.TestCase):
class EncodeTestPop(unittest.TestCase):
def test_partial(self):
e = irctokens.StatefulEncoder()
line = irctokens.tokenise("PRIVMSG #channel hello")
@ -32,9 +32,16 @@ class TestPop(unittest.TestCase):
lines = e.pop(1)
self.assertEqual(len(lines), 0)
class TestClear(unittest.TestCase):
class EncodeTestClear(unittest.TestCase):
def test(self):
e = irctokens.StatefulEncoder()
e.push(irctokens.tokenise("PRIVMSG #channel hello"))
e.clear()
self.assertEqual(e.pending(), b"")
class EncodeTestEncoding(unittest.TestCase):
def test(self):
e = irctokens.StatefulEncoder(encoding="iso-8859-2")
e.push(irctokens.tokenise("PRIVMSG #channel :hello Č"))
self.assertEqual(e.pending(),
"PRIVMSG #channel :hello Č\r\n".encode("iso-8859-2"))

View File

@ -1,7 +1,7 @@
import unittest
import irctokens
class TestTags(unittest.TestCase):
class TokenTestTags(unittest.TestCase):
def test_missing(self):
line = irctokens.tokenise("PRIVMSG #channel")
self.assertIsNone(line.tags)
@ -18,7 +18,7 @@ class TestTags(unittest.TestCase):
line = irctokens.tokenise(r"@id=1\\\:\r\n\s2 PRIVMSG #channel")
self.assertEqual(line.tags["id"], "1\\;\r\n 2")
class TestSource(unittest.TestCase):
class TokenTestSource(unittest.TestCase):
def test_without_tags(self):
line = irctokens.tokenise(":nick!user@host PRIVMSG #channel")
self.assertEqual(line.source, "nick!user@host")
@ -35,12 +35,12 @@ class TestSource(unittest.TestCase):
line = irctokens.tokenise("@id=123 PRIVMSG #channel")
self.assertIsNone(line.source)
class TestCommand(unittest.TestCase):
class TokenTestCommand(unittest.TestCase):
def test_lowercase(self):
line = irctokens.tokenise("privmsg #channel")
self.assertEqual(line.command, "PRIVMSG")
class TestParams(unittest.TestCase):
class TokenTestParams(unittest.TestCase):
def test_trailing(self):
line = irctokens.tokenise("PRIVMSG #channel :hello world")
self.assertEqual(line.params, ["#channel", "hello world"])
@ -54,7 +54,7 @@ class TestParams(unittest.TestCase):
self.assertEqual(line.command, "PRIVMSG")
self.assertEqual(line.params, [])
class TestAll(unittest.TestCase):
class TokenTestAll(unittest.TestCase):
def test_all(self):
line = irctokens.tokenise(
"@id=123 :nick!user@host PRIVMSG #channel :hello world")