diff --git a/ircstates/__init__.py b/ircstates/__init__.py index b206b08..6a13823 100644 --- a/ircstates/__init__.py +++ b/ircstates/__init__.py @@ -2,5 +2,5 @@ from .server import Server, ServerDisconnectedException from .user import User from .channel import Channel from .channel_user import ChannelUser -from .casemap import casefold +from .casemap import casefold, CaseMap from .emit import * diff --git a/ircstates/casemap.py b/ircstates/casemap.py index 0782453..8cdba1e 100644 --- a/ircstates/casemap.py +++ b/ircstates/casemap.py @@ -1,24 +1,24 @@ -import string -from typing import List -from cachetools import cached, LRUCache +from enum import Enum +from string import ascii_lowercase, ascii_uppercase +from typing import Dict, List -ASCII_UPPER = list(string.ascii_uppercase) -ASCII_LOWER = list(string.ascii_lowercase) -RFC1459_UPPER = ASCII_UPPER+list("[]^\\") -RFC1459_LOWER = ASCII_LOWER+list("{}~|") +class CaseMap(Enum): + ASCII = "ascii" + RFC1459 = "rfc1459" -def _replace(s: str, upper: List[str], lower: List[str]): - out = "" - for char in s: - if char in upper: - out += lower[upper.index(char)] - else: - out += char - return out +def _make_trans(upper: str, lower: str): + return str.maketrans(dict(zip(upper, lower))) -@cached(cache=LRUCache(maxsize=1024)) -def casefold(mapping: str, s: str): - if mapping == "rfc1459": - return _replace(s, RFC1459_UPPER, RFC1459_LOWER) - elif mapping == "ascii": - return _replace(s, ASCII_UPPER, ASCII_LOWER) +CASEMAPS: Dict[CaseMap, Dict[int, str]] = { + CaseMap.ASCII: _make_trans( + r"ABCDEFGHIJKLMNOPQRSTUVWXYZ", + r"abcdefghijklmnopqrstuvwxyz" + ), + CaseMap.RFC1459: _make_trans( + r"ABCDEFGHIJKLMNOPQRSTUVWXYZ\[]^", + r"abcdefghijklmnopqrstuvwxyz|{}~" + ) +} +def casefold(casemap_name: CaseMap, s: str): + casemap = CASEMAPS[casemap_name] + return s.translate(casemap) diff --git a/ircstates/isupport/__init__.py b/ircstates/isupport/__init__.py index e0c3260..17e9383 100644 --- a/ircstates/isupport/__init__.py +++ b/ircstates/isupport/__init__.py @@ -1,5 +1,6 @@ from typing import Dict, List, Optional from .tokens import ChanModes, Prefix +from ..casemap import CaseMap CASEMAPPINGS = ["rfc1459", "ascii"] @@ -31,7 +32,7 @@ class ISupport(object): prefix = Prefix(["o", "v"], ["@", "+"]) modes: int = 3 # -1 if "no limit" - casemapping: str = "rfc1459" + casemapping: CaseMap = CaseMap.RFC1459 chantypes: List[str] = ["#"] statusmsg: List[str] = [] @@ -75,8 +76,7 @@ class ISupport(object): self.watch = int(value) if value else -1 elif key == "CASEMAPPING": - if value in CASEMAPPINGS: - self.casemapping = value + self.casemapping = CaseMap(value) elif key == "CHANTYPES": self.chantypes = list(value) diff --git a/test/casemap.py b/test/casemap.py index 740988f..9da28ac 100644 --- a/test/casemap.py +++ b/test/casemap.py @@ -3,11 +3,11 @@ import ircstates, irctokens class CaseMapTestMethod(unittest.TestCase): def test_rfc1459(self): - lower = ircstates.casefold("rfc1459", "ÀTEST[]^\\") + lower = ircstates.casefold(ircstates.CaseMap.RFC1459, "ÀTEST[]^\\") self.assertEqual(lower, "Àtest{}~|") def test_ascii(self): - lower = ircstates.casefold("ascii", "ÀTEST[]~\\") + lower = ircstates.casefold(ircstates.CaseMap.ASCII, "ÀTEST[]~\\") self.assertEqual(lower, "Àtest[]~\\") class CaseMapTestCommands(unittest.TestCase): diff --git a/test/isupport.py b/test/isupport.py index 8b68170..b57c29c 100644 --- a/test/isupport.py +++ b/test/isupport.py @@ -58,25 +58,6 @@ class ISUPPORTTest(unittest.TestCase): server.parse_tokens(irctokens.tokenise("005 * MODES=5 *")) self.assertEqual(server.isupport.modes, 5) - def test_rfc1459(self): - server = ircstates.Server("test") - server.parse_tokens(irctokens.tokenise("001 nickname *")) - self.assertEqual(server.isupport.casemapping, "rfc1459") - server.parse_tokens(irctokens.tokenise("005 * CASEMAPPING=rfc1459 *")) - self.assertEqual(server.isupport.casemapping, "rfc1459") - - def test_ascii(self): - server = ircstates.Server("test") - server.parse_tokens(irctokens.tokenise("001 nickname *")) - server.parse_tokens(irctokens.tokenise("005 * CASEMAPPING=ascii *")) - self.assertEqual(server.isupport.casemapping, "ascii") - - def test_fallback_to_rfc1459(self): - server = ircstates.Server("test") - server.parse_tokens(irctokens.tokenise("001 nickname *")) - server.parse_tokens(irctokens.tokenise("005 * CASEMAPPING=asd *")) - self.assertEqual(server.isupport.casemapping, "rfc1459") - def test_network(self): server = ircstates.Server("test") server.parse_tokens(irctokens.tokenise("001 nickname *")) @@ -149,3 +130,23 @@ class ISUPPORTTest(unittest.TestCase): self.assertEqual(server.isupport.nicklen, 9) server.parse_tokens(irctokens.tokenise("005 * NICKLEN=16 *")) self.assertEqual(server.isupport.nicklen, 16) + +class ISupportTestCasemapping(unittest.TestCase): + def test_rfc1459(self): + server = ircstates.Server("test") + server.parse_tokens(irctokens.tokenise("001 nickname *")) + self.assertEqual(server.isupport.casemapping, ircstates.CaseMap.RFC1459) + server.parse_tokens(irctokens.tokenise("005 * CASEMAPPING=rfc1459 *")) + self.assertEqual(server.isupport.casemapping, ircstates.CaseMap.RFC1459) + + def test_ascii(self): + server = ircstates.Server("test") + server.parse_tokens(irctokens.tokenise("001 nickname *")) + server.parse_tokens(irctokens.tokenise("005 * CASEMAPPING=ascii *")) + self.assertEqual(server.isupport.casemapping, ircstates.CaseMap.ASCII) + + def test_unknown(self): + server = ircstates.Server("test") + server.parse_tokens(irctokens.tokenise("001 nickname *")) + with self.assertRaises(ValueError): + server.parse_tokens(irctokens.tokenise("005 * CASEMAPPING=asd *"))