simplify wait_for

This commit is contained in:
jesopo 2021-02-18 14:50:01 +00:00
parent d0e0314169
commit 2cab5b3002
2 changed files with 36 additions and 104 deletions

View File

@ -19,17 +19,9 @@ class MaybeAwait(Generic[TEvent]):
class WaitFor(object):
def __init__(self,
response: IMatchResponse,
deadline: float):
label: Optional[str]=None):
self.response = response
self.deadline = deadline
self._label: Optional[str] = None
self._our_fut: "Future[Line]" = Future()
def __await__(self) -> Generator[Any, None, Line]:
return self._our_fut.__await__()
def with_label(self, label: str):
self._label = label
self._label = label
def match(self, server: IServer, line: Line):
if (self._label is not None and
@ -39,6 +31,3 @@ class WaitFor(object):
label == self._label):
return True
return self.response.match(server, line)
def resolve(self, line: Line):
self._our_fut.set_result(line)

View File

@ -28,7 +28,7 @@ from .interface import ITCPTransport, ITCPReader, ITCPWriter
THROTTLE_RATE = 4 # lines
THROTTLE_TIME = 2 # seconds
PING_TIMEOUT = 60 # seconds
PING_INTERVAL = 60 # seconds
WAIT_TIMEOUT = 20 # seconds
JOIN_ERR_FIRST = [
@ -64,7 +64,9 @@ class Server(IServer):
self._send_queue: PriorityQueue[SentLine] = PriorityQueue()
self.desired_caps: Set[ICapability] = set([])
self._read_queue: Deque[Line] = deque()
self.read_lock = asyncio.Lock()
self._read_queue: Deque[Line] = deque()
self._process_queue: Deque[Line] = deque()
self._wait_for: Optional[Tuple[Awaitable, WaitFor]] = None
self._wait_for_fut: Optional[Future[WaitFor]] = None
@ -260,94 +262,44 @@ class Server(IServer):
else:
await self.send(build("WHO", [chan]))
async def _read_line(self, timeout: float) -> Optional[Line]:
async def _read_line(self) -> Line:
while True:
if self._read_queue:
return self._read_queue.popleft()
async with self.read_lock:
if self._read_queue:
return self._read_queue.popleft()
try:
async with timeout_(timeout):
data = await self._reader.read(1024)
except asyncio.TimeoutError:
return None
self.last_read = monotonic()
lines = self.recv(data)
for line in lines:
self._read_queue.append(line)
async def _line_or_wait(self,
line_aw: asyncio.Task
) -> Optional[Tuple[Awaitable, WaitFor]]:
wait_for_fut: Future[WaitFor] = Future()
self._wait_for_fut = wait_for_fut
done, pend = await asyncio.wait([line_aw, wait_for_fut],
return_when=asyncio.FIRST_COMPLETED)
self._wait_for_fut = None
if wait_for_fut.done():
new_line_aw = list(pend)[0]
return (new_line_aw, wait_for_fut.result())
else:
return None
data = await self._reader.read(1024)
lines = self.recv(data)
# last_read under self.recv() as recv might throw Disconnected
self.last_read = monotonic()
for line in lines:
self._read_queue.append(line)
async def _read_lines(self):
waiting_lines: List[Tuple[Line, Optional[Emit]]] = []
sent_ping = False
while True:
now = monotonic()
timeouts: List[float] = []
timeouts.append((self.last_read+PING_TIMEOUT)-now)
if self._wait_for is not None:
_, wait_for = self._wait_for
timeouts.append(wait_for.deadline-now)
line = await self._read_line(max([0.1, min(timeouts)]))
if line is None:
now = monotonic()
since = now-self.last_read
if self._wait_for is not None:
aw, wait_for = self._wait_for
if wait_for.deadline <= now:
self._wait_for = None
await aw
if since >= PING_TIMEOUT:
if since >= (PING_TIMEOUT*2):
raise ServerDisconnectedException()
elif not sent_ping:
if not self._process_queue:
try:
async with timeout_(PING_INTERVAL):
line = await self._read_line()
except asyncio.TimeoutError:
if not sent_ping:
sent_ping = True
await self.send(build("PING", ["hello"]))
continue
else:
sent_ping = False
emit = self.parse_tokens(line)
waiting_lines.append((line, emit))
self.line_preread(line)
if self._wait_for is not None:
aw, wait_for = self._wait_for
if wait_for.match(self, line):
wait_for.resolve(line)
self._wait_for = await self._line_or_wait(aw)
if self._wait_for is not None:
continue
else:
continue
raise ServerDisconnectedException()
else:
sent_ping = False
self._process_queue.append(line)
for i in range(len(waiting_lines)):
line, emit = waiting_lines.pop(0)
line_aw = self._on_read(line, emit)
self._wait_for = await self._line_or_wait(line_aw)
if self._wait_for is not None:
break
line = self._process_queue.popleft()
emit = self.parse_tokens(line)
await self._on_read(line, emit)
async def wait_for(self,
response: Union[IMatchResponse, Set[IMatchResponse]],
sent_aw: Optional[Awaitable[SentLine]]=None,
label: Optional[str]=None,
timeout: float=WAIT_TIMEOUT
) -> Line:
@ -357,22 +309,13 @@ class Server(IServer):
else:
response_obj = response
deadline = monotonic()+timeout
our_wait_for = WaitFor(response_obj, deadline)
if self._wait_for_fut is not None:
self._wait_for_fut.set_result(our_wait_for)
else:
cur_task = asyncio.current_task()
if cur_task is not None:
self._wait_for = (cur_task, our_wait_for)
if sent_aw is not None:
sent_line = await sent_aw
label = str(sent_line.id)
our_wait_for.with_label(label)
wait_for = WaitFor(response_obj, label)
async with timeout_(timeout):
return (await our_wait_for)
while True:
line = await self._read_line()
self._process_queue.append(line)
if wait_for.match(self, line):
return line
async def _on_send_line(self, line: Line):
if (line.command in ["PRIVMSG", "NOTICE", "TAGMSG"] and