str.maketrans is a much faster casefold; make casemaps an Enum

This commit is contained in:
jesopo 2022-01-07 18:53:14 +00:00
parent bc7c4d75a8
commit ea9c0c2d1f
5 changed files with 47 additions and 46 deletions

View File

@ -2,5 +2,5 @@ from .server import Server, ServerDisconnectedException
from .user import User
from .channel import Channel
from .channel_user import ChannelUser
from .casemap import casefold
from .casemap import casefold, CaseMap
from .emit import *

View File

@ -1,24 +1,24 @@
import string
from typing import List
from cachetools import cached, LRUCache
from enum import Enum
from string import ascii_lowercase, ascii_uppercase
from typing import Dict, List
# Upper- and lower-case ASCII letters as lists of single characters. The two
# lists are index-aligned: position i in *_UPPER pairs with position i in
# *_LOWER (this is how _replace looks up the lowercase counterpart).
ASCII_UPPER = list(string.ascii_uppercase)
ASCII_LOWER = list(string.ascii_lowercase)
# The rfc1459 tables extend the ASCII pairs with four more:
# "[" -> "{", "]" -> "}", "^" -> "~" and "\" -> "|".
RFC1459_UPPER = ASCII_UPPER+list("[]^\\")
RFC1459_LOWER = ASCII_LOWER+list("{}~|")
class CaseMap(Enum):
    """Casemapping schemes accepted by ``casefold``.

    Each member's value is the scheme's lowercase name as a string, so a
    member can be obtained from that string by value lookup, e.g.
    ``CaseMap("ascii")``. Looking up any other string raises ``ValueError``
    (standard ``Enum`` behaviour).
    """

    ASCII = "ascii"
    RFC1459 = "rfc1459"
def _replace(s: str, upper: List[str], lower: List[str]):
out = ""
for char in s:
if char in upper:
out += lower[upper.index(char)]
else:
out += char
return out
def _make_trans(upper: str, lower: str):
return str.maketrans(dict(zip(upper, lower)))
@cached(cache=LRUCache(maxsize=1024))
def casefold(mapping: str, s: str):
    """Return *s* lowercased according to the casemapping named *mapping*.

    The function is wrapped by ``cached`` with an ``LRUCache`` holding up to
    1024 entries.

    Args:
        mapping: ``"rfc1459"`` or ``"ascii"``.
        s: the text to fold.

    Returns:
        The folded string.

    Raises:
        ValueError: if *mapping* is not one of the names above. Previously
            this case fell off the end of the function and returned ``None``.
    """
    if mapping == "rfc1459":
        return _replace(s, RFC1459_UPPER, RFC1459_LOWER)
    if mapping == "ascii":
        return _replace(s, ASCII_UPPER, ASCII_LOWER)
    raise ValueError(f"unknown casemapping: {mapping!r}")
# Translation tables for str.translate, one per supported casemapping.
# The alphabets come from the string module rather than being typed out by
# hand; the resulting tables are identical to the literal versions.
CASEMAPS: Dict[CaseMap, Dict[int, str]] = {
    # ascii: only A-Z fold to a-z.
    CaseMap.ASCII: _make_trans(ascii_uppercase, ascii_lowercase),
    # rfc1459: A-Z plus "\" -> "|", "[" -> "{", "]" -> "}", "^" -> "~".
    CaseMap.RFC1459: _make_trans(
        ascii_uppercase + r"\[]^",
        ascii_lowercase + r"|{}~",
    ),
}
def casefold(casemap_name: CaseMap, s: str) -> str:
    """Return *s* folded to lowercase under the given casemapping.

    The fold is a single ``str.translate`` pass using the table stored in
    ``CASEMAPS`` for *casemap_name*; characters without an entry in that
    table are left unchanged.

    Args:
        casemap_name: the casemapping to apply.
        s: the text to fold.

    Raises:
        KeyError: if *casemap_name* has no entry in ``CASEMAPS``.
    """
    return s.translate(CASEMAPS[casemap_name])

View File

@ -1,5 +1,6 @@
from typing import Dict, List, Optional
from .tokens import ChanModes, Prefix
from ..casemap import CaseMap
CASEMAPPINGS = ["rfc1459", "ascii"]
@ -31,7 +32,7 @@ class ISupport(object):
prefix = Prefix(["o", "v"], ["@", "+"])
modes: int = 3 # -1 if "no limit"
casemapping: str = "rfc1459"
casemapping: CaseMap = CaseMap.RFC1459
chantypes: List[str] = ["#"]
statusmsg: List[str] = []
@ -75,8 +76,7 @@ class ISupport(object):
self.watch = int(value) if value else -1
elif key == "CASEMAPPING":
if value in CASEMAPPINGS:
self.casemapping = value
self.casemapping = CaseMap(value)
elif key == "CHANTYPES":
self.chantypes = list(value)

View File

@ -3,11 +3,11 @@ import ircstates, irctokens
# Exercises ircstates.casefold directly, once per casemapping.
class CaseMapTestMethod(unittest.TestCase):
def test_rfc1459(self):
# NOTE(review): `lower` is assigned twice in a row. This looks like the
# removed (string name) and added (CaseMap member) sides of the diff shown
# together; confirm only the CaseMap call remains in the real file.
lower = ircstates.casefold("rfc1459", "ÀTEST[]^\\")
lower = ircstates.casefold(ircstates.CaseMap.RFC1459, "ÀTEST[]^\\")
# Expected: A-Z lowercased, "[]^\" folded to "{}~|", and "À" left as-is.
self.assertEqual(lower, "Àtest{}~|")
def test_ascii(self):
# NOTE(review): same doubled assignment as in test_rfc1459 above.
lower = ircstates.casefold("ascii", "ÀTEST[]~\\")
lower = ircstates.casefold(ircstates.CaseMap.ASCII, "ÀTEST[]~\\")
# Expected: only A-Z lowercased; "[]~\" and "À" left as-is.
self.assertEqual(lower, "Àtest[]~\\")
class CaseMapTestCommands(unittest.TestCase):

View File

@ -58,25 +58,6 @@ class ISUPPORTTest(unittest.TestCase):
server.parse_tokens(irctokens.tokenise("005 * MODES=5 *"))
self.assertEqual(server.isupport.modes, 5)
# Checks that isupport.casemapping equals the string "rfc1459" both before
# any CASEMAPPING token is parsed and after CASEMAPPING=rfc1459 is parsed.
# NOTE(review): compares against a plain string; presumably the pre-change
# side of the diff — confirm.
def test_rfc1459(self):
server = ircstates.Server("test")
server.parse_tokens(irctokens.tokenise("001 nickname *"))
self.assertEqual(server.isupport.casemapping, "rfc1459")
server.parse_tokens(irctokens.tokenise("005 * CASEMAPPING=rfc1459 *"))
self.assertEqual(server.isupport.casemapping, "rfc1459")
# Checks that parsing a 005 line carrying CASEMAPPING=ascii sets
# isupport.casemapping to the string "ascii".
# NOTE(review): compares against a plain string; presumably the pre-change
# side of the diff — confirm.
def test_ascii(self):
server = ircstates.Server("test")
server.parse_tokens(irctokens.tokenise("001 nickname *"))
server.parse_tokens(irctokens.tokenise("005 * CASEMAPPING=ascii *"))
self.assertEqual(server.isupport.casemapping, "ascii")
# Checks that an unrecognised CASEMAPPING value ("asd") leaves
# isupport.casemapping at "rfc1459" rather than storing the unknown value.
# NOTE(review): presumably the pre-change side of the diff — confirm.
def test_fallback_to_rfc1459(self):
server = ircstates.Server("test")
server.parse_tokens(irctokens.tokenise("001 nickname *"))
server.parse_tokens(irctokens.tokenise("005 * CASEMAPPING=asd *"))
self.assertEqual(server.isupport.casemapping, "rfc1459")
def test_network(self):
server = ircstates.Server("test")
server.parse_tokens(irctokens.tokenise("001 nickname *"))
@ -149,3 +130,23 @@ class ISUPPORTTest(unittest.TestCase):
self.assertEqual(server.isupport.nicklen, 9)
server.parse_tokens(irctokens.tokenise("005 * NICKLEN=16 *"))
self.assertEqual(server.isupport.nicklen, 16)
# Tests how the CASEMAPPING token of a 005 line is reflected in
# server.isupport.casemapping, expressed as ircstates.CaseMap members.
class ISupportTestCasemapping(unittest.TestCase):
# casemapping is CaseMap.RFC1459 before any CASEMAPPING token is parsed,
# and still CaseMap.RFC1459 after CASEMAPPING=rfc1459 is parsed.
def test_rfc1459(self):
server = ircstates.Server("test")
server.parse_tokens(irctokens.tokenise("001 nickname *"))
self.assertEqual(server.isupport.casemapping, ircstates.CaseMap.RFC1459)
server.parse_tokens(irctokens.tokenise("005 * CASEMAPPING=rfc1459 *"))
self.assertEqual(server.isupport.casemapping, ircstates.CaseMap.RFC1459)
# CASEMAPPING=ascii switches casemapping to CaseMap.ASCII.
def test_ascii(self):
server = ircstates.Server("test")
server.parse_tokens(irctokens.tokenise("001 nickname *"))
server.parse_tokens(irctokens.tokenise("005 * CASEMAPPING=ascii *"))
self.assertEqual(server.isupport.casemapping, ircstates.CaseMap.ASCII)
# An unrecognised CASEMAPPING value ("asd") is expected to make
# parse_tokens raise ValueError.
def test_unknown(self):
server = ircstates.Server("test")
server.parse_tokens(irctokens.tokenise("001 nickname *"))
with self.assertRaises(ValueError):
server.parse_tokens(irctokens.tokenise("005 * CASEMAPPING=asd *"))