allow suppling encoding to StatefulEncoder and StatefulDecoder (and tests!)

This commit is contained in:
jesopo 2020-03-11 17:38:30 +00:00
parent 09087d63d6
commit 86fad88f92
5 changed files with 35 additions and 21 deletions

View File

@ -2,7 +2,8 @@ import typing
from .protocol import Line, tokenise from .protocol import Line, tokenise
class StatefulDecoder(object): class StatefulDecoder(object):
def __init__(self, fallback: str="iso-8859"): def __init__(self, encoding: str="utf8", fallback: str="iso-8859"):
self._encoding = encoding
self._fallback = fallback self._fallback = fallback
self.clear() self.clear()
@ -23,13 +24,14 @@ class StatefulDecoder(object):
decode_lines: typing.List[str] = [] decode_lines: typing.List[str] = []
for line in lines: for line in lines:
try: try:
decode_lines.append(line.decode("utf8")) decode_lines.append(line.decode(self._encoding))
except UnicodeDecodeError as e: except UnicodeDecodeError as e:
decode_lines.append(line.decode(self._fallback)) decode_lines.append(line.decode(self._fallback))
return [tokenise(l) for l in decode_lines] return [tokenise(l) for l in decode_lines]
class StatefulEncoder(object): class StatefulEncoder(object):
def __init__(self): def __init__(self, encoding: str="utf8"):
self._encoding = encoding
self.clear() self.clear()
def clear(self): def clear(self):
@ -40,7 +42,7 @@ class StatefulEncoder(object):
return self._buffer return self._buffer
def push(self, line: Line): def push(self, line: Line):
self._buffer += f"{line.format()}\r\n".encode("utf8") self._buffer += f"{line.format()}\r\n".encode(self._encoding)
self._buffered_lines.append(line) self._buffered_lines.append(line)
def pop(self, byte_count: int): def pop(self, byte_count: int):

View File

@ -1,7 +1,7 @@
import unittest import unittest
import irctokens import irctokens
class TestTags(unittest.TestCase): class FormatTestTags(unittest.TestCase):
def test(self): def test(self):
line = irctokens.format("PRIVMSG", ["#channel", "hello"], line = irctokens.format("PRIVMSG", ["#channel", "hello"],
tags={"id": "\\" + " " + ";" + "\r\n"}) tags={"id": "\\" + " " + ";" + "\r\n"})
@ -21,18 +21,18 @@ class TestTags(unittest.TestCase):
tags={"a": ""}) tags={"a": ""})
self.assertEqual(line, "@a PRIVMSG #channel hello") self.assertEqual(line, "@a PRIVMSG #channel hello")
class TestSource(unittest.TestCase): class FormatTestSource(unittest.TestCase):
def test(self): def test(self):
line = irctokens.format("PRIVMSG", ["#channel", "hello"], line = irctokens.format("PRIVMSG", ["#channel", "hello"],
source="nick!user@host") source="nick!user@host")
self.assertEqual(line, ":nick!user@host PRIVMSG #channel hello") self.assertEqual(line, ":nick!user@host PRIVMSG #channel hello")
class TestCommand(unittest.TestCase): class FormatTestCommand(unittest.TestCase):
def test_lowercase(self): def test_lowercase(self):
line = irctokens.format("privmsg") line = irctokens.format("privmsg")
self.assertEqual(line, "PRIVMSG") self.assertEqual(line, "PRIVMSG")
class TestTrailing(unittest.TestCase): class FormatTestTrailing(unittest.TestCase):
def test_space(self): def test_space(self):
line = irctokens.format("PRIVMSG", ["#channel", "hello world"]) line = irctokens.format("PRIVMSG", ["#channel", "hello world"])
self.assertEqual(line, "PRIVMSG #channel :hello world") self.assertEqual(line, "PRIVMSG #channel :hello world")

View File

@ -1,7 +1,7 @@
import unittest import unittest
import irctokens import irctokens
class TestPartial(unittest.TestCase): class DecodeTestPartial(unittest.TestCase):
def test(self): def test(self):
d = irctokens.StatefulDecoder() d = irctokens.StatefulDecoder()
lines = d.push(b"PRIVMSG ") lines = d.push(b"PRIVMSG ")
@ -12,7 +12,7 @@ class TestPartial(unittest.TestCase):
line = irctokens.tokenise("PRIVMSG #channel hello") line = irctokens.tokenise("PRIVMSG #channel hello")
self.assertEqual(lines, [line]) self.assertEqual(lines, [line])
class TestMultiple(unittest.TestCase): class DecodeTestMultiple(unittest.TestCase):
def test(self): def test(self):
d = irctokens.StatefulDecoder() d = irctokens.StatefulDecoder()
lines = d.push(b"PRIVMSG #channel1 hello\r\n" lines = d.push(b"PRIVMSG #channel1 hello\r\n"
@ -24,15 +24,20 @@ class TestMultiple(unittest.TestCase):
self.assertEqual(lines[0], line1) self.assertEqual(lines[0], line1)
self.assertEqual(lines[1], line2) self.assertEqual(lines[1], line2)
class TestFallback(unittest.TestCase): class DecodeTestEncoding(unittest.TestCase):
def test(self): def test(self):
d = irctokens.StatefulDecoder(encoding="iso-8859-2")
lines = d.push("PRIVMSG #channel :hello Č\r\n".encode("iso-8859-2"))
line = irctokens.tokenise("PRIVMSG #channel :hello Č")
self.assertEqual(lines[0], line)
def test_fallback(self):
d = irctokens.StatefulDecoder(fallback="latin-1") d = irctokens.StatefulDecoder(fallback="latin-1")
lines = d.push("PRIVMSG #channel hélló\r\n".encode("latin-1")) lines = d.push("PRIVMSG #channel hélló\r\n".encode("latin-1"))
self.assertEqual(len(lines), 1) self.assertEqual(len(lines), 1)
line = irctokens.tokenise("PRIVMSG #channel hélló") line = irctokens.tokenise("PRIVMSG #channel hélló")
self.assertEqual(lines[0], line) self.assertEqual(lines[0], line)
class TestEmpty(unittest.TestCase): class DecodeTestEmpty(unittest.TestCase):
def test_immediate(self): def test_immediate(self):
d = irctokens.StatefulDecoder() d = irctokens.StatefulDecoder()
lines = d.push(b"") lines = d.push(b"")
@ -44,7 +49,7 @@ class TestEmpty(unittest.TestCase):
lines = d.push(b"") lines = d.push(b"")
self.assertIsNone(lines) self.assertIsNone(lines)
class TestClear(unittest.TestCase): class DecodeTestClear(unittest.TestCase):
def test(self): def test(self):
d = irctokens.StatefulDecoder() d = irctokens.StatefulDecoder()
d.push(b"PRIVMSG ") d.push(b"PRIVMSG ")

View File

@ -1,14 +1,14 @@
import unittest import unittest
import irctokens import irctokens
class TestPush(unittest.TestCase): class EncodeTestPush(unittest.TestCase):
def test(self): def test(self):
e = irctokens.StatefulEncoder() e = irctokens.StatefulEncoder()
line = irctokens.tokenise("PRIVMSG #channel hello") line = irctokens.tokenise("PRIVMSG #channel hello")
e.push(line) e.push(line)
self.assertEqual(e.pending(), b"PRIVMSG #channel hello\r\n") self.assertEqual(e.pending(), b"PRIVMSG #channel hello\r\n")
class TestPop(unittest.TestCase): class EncodeTestPop(unittest.TestCase):
def test_partial(self): def test_partial(self):
e = irctokens.StatefulEncoder() e = irctokens.StatefulEncoder()
line = irctokens.tokenise("PRIVMSG #channel hello") line = irctokens.tokenise("PRIVMSG #channel hello")
@ -32,9 +32,16 @@ class TestPop(unittest.TestCase):
lines = e.pop(1) lines = e.pop(1)
self.assertEqual(len(lines), 0) self.assertEqual(len(lines), 0)
class TestClear(unittest.TestCase): class EncodeTestClear(unittest.TestCase):
def test(self): def test(self):
e = irctokens.StatefulEncoder() e = irctokens.StatefulEncoder()
e.push(irctokens.tokenise("PRIVMSG #channel hello")) e.push(irctokens.tokenise("PRIVMSG #channel hello"))
e.clear() e.clear()
self.assertEqual(e.pending(), b"") self.assertEqual(e.pending(), b"")
class EncodeTestEncoding(unittest.TestCase):
def test(self):
e = irctokens.StatefulEncoder(encoding="iso-8859-2")
e.push(irctokens.tokenise("PRIVMSG #channel :hello Č"))
self.assertEqual(e.pending(),
"PRIVMSG #channel :hello Č\r\n".encode("iso-8859-2"))

View File

@ -1,7 +1,7 @@
import unittest import unittest
import irctokens import irctokens
class TestTags(unittest.TestCase): class TokenTestTags(unittest.TestCase):
def test_missing(self): def test_missing(self):
line = irctokens.tokenise("PRIVMSG #channel") line = irctokens.tokenise("PRIVMSG #channel")
self.assertIsNone(line.tags) self.assertIsNone(line.tags)
@ -18,7 +18,7 @@ class TestTags(unittest.TestCase):
line = irctokens.tokenise(r"@id=1\\\:\r\n\s2 PRIVMSG #channel") line = irctokens.tokenise(r"@id=1\\\:\r\n\s2 PRIVMSG #channel")
self.assertEqual(line.tags["id"], "1\\;\r\n 2") self.assertEqual(line.tags["id"], "1\\;\r\n 2")
class TestSource(unittest.TestCase): class TokenTestSource(unittest.TestCase):
def test_without_tags(self): def test_without_tags(self):
line = irctokens.tokenise(":nick!user@host PRIVMSG #channel") line = irctokens.tokenise(":nick!user@host PRIVMSG #channel")
self.assertEqual(line.source, "nick!user@host") self.assertEqual(line.source, "nick!user@host")
@ -35,12 +35,12 @@ class TestSource(unittest.TestCase):
line = irctokens.tokenise("@id=123 PRIVMSG #channel") line = irctokens.tokenise("@id=123 PRIVMSG #channel")
self.assertIsNone(line.source) self.assertIsNone(line.source)
class TestCommand(unittest.TestCase): class TokenTestCommand(unittest.TestCase):
def test_lowercase(self): def test_lowercase(self):
line = irctokens.tokenise("privmsg #channel") line = irctokens.tokenise("privmsg #channel")
self.assertEqual(line.command, "PRIVMSG") self.assertEqual(line.command, "PRIVMSG")
class TestParams(unittest.TestCase): class TokenTestParams(unittest.TestCase):
def test_trailing(self): def test_trailing(self):
line = irctokens.tokenise("PRIVMSG #channel :hello world") line = irctokens.tokenise("PRIVMSG #channel :hello world")
self.assertEqual(line.params, ["#channel", "hello world"]) self.assertEqual(line.params, ["#channel", "hello world"])
@ -54,7 +54,7 @@ class TestParams(unittest.TestCase):
self.assertEqual(line.command, "PRIVMSG") self.assertEqual(line.command, "PRIVMSG")
self.assertEqual(line.params, []) self.assertEqual(line.params, [])
class TestAll(unittest.TestCase): class TokenTestAll(unittest.TestCase):
def test_all(self): def test_all(self):
line = irctokens.tokenise( line = irctokens.tokenise(
"@id=123 :nick!user@host PRIVMSG #channel :hello world") "@id=123 :nick!user@host PRIVMSG #channel :hello world")