From 6a05370a12bebc3932eba99ecb349994eab37a34 Mon Sep 17 00:00:00 2001 From: jesopo Date: Wed, 12 May 2021 10:52:39 +0000 Subject: [PATCH] simplify wait_for --- ircrobots/server.py | 124 ++++++++++++++++---------------------------- requirements.txt | 1 + 2 files changed, 45 insertions(+), 80 deletions(-) diff --git a/ircrobots/server.py b/ircrobots/server.py index 73a61b6..14c8a40 100644 --- a/ircrobots/server.py +++ b/ircrobots/server.py @@ -6,6 +6,7 @@ from collections import deque from time import monotonic import anyio +from asyncio_rlock import RLock from asyncio_throttle import Throttler from async_timeout import timeout as timeout_ from ircstates import Emit, Channel, ChannelUser @@ -63,10 +64,13 @@ class Server(IServer): self._send_queue: PriorityQueue[SentLine] = PriorityQueue() self.desired_caps: Set[ICapability] = set([]) - self._read_queue: Deque[Line] = deque() + self._read_queue: Deque[Line] = deque() + self._process_queue: Deque[Line] = deque() - self._wait_for: Optional[Tuple[Awaitable, WaitFor]] = None - self._wait_for_fut: Optional[Future[WaitFor]] = None + self._read_lguard = RLock() + self.read_lock = self._read_lguard + self._read_lwork = asyncio.Lock() + self._wait_for = asyncio.Event() self._pending_who: Deque[str] = deque() self._alt_nicks: List[str] = [] @@ -273,76 +277,42 @@ class Server(IServer): self.last_read = monotonic() lines = self.recv(data) for line in lines: + self.line_preread(line) self._read_queue.append(line) - async def _line_or_wait(self, - line_aw: asyncio.Task - ) -> Optional[Tuple[Awaitable, WaitFor]]: - wait_for_fut: Future[WaitFor] = Future() - self._wait_for_fut = wait_for_fut - - done, pend = await asyncio.wait([line_aw, wait_for_fut], - return_when=asyncio.FIRST_COMPLETED) - self._wait_for_fut = None - - if wait_for_fut.done(): - new_line_aw = list(pend)[0] - return (new_line_aw, wait_for_fut.result()) - else: - return None - async def _read_lines(self): - waiting_lines: List[Tuple[Line, Optional[Emit]]] = [] - sent_ping = False + ping_sent = False while True: - now = monotonic() - timeouts: List[float] = [] - timeouts.append((self.last_read+PING_TIMEOUT)-now) - if self._wait_for is not None: - _, wait_for = self._wait_for - timeouts.append(wait_for.deadline-now) + async with self._read_lguard: + pass - line = await self._read_line(max([0.1, min(timeouts)])) - if line is None: - now = monotonic() - since = now-self.last_read + if not self._process_queue: + async with self._read_lwork: + read_aw = self._read_line(PING_TIMEOUT) + dones, notdones = await asyncio.wait( + [read_aw, self._wait_for.wait()], + return_when=asyncio.FIRST_COMPLETED + ) + self._wait_for.clear() - if self._wait_for is not None: - aw, wait_for = self._wait_for - if wait_for.deadline <= now: - self._wait_for = None - await aw + for done in dones: + if isinstance(done.result(), Line): + line = done.result() + self._process_queue.append(line) + elif done.result() is None: + if ping_sent: + await self.send(build("PING", ["hello"])) + ping_sent = True + else: + await self.disconnect() + raise ServerDisconnectedException() + for notdone in notdones: + notdone.cancel() - if since >= PING_TIMEOUT: - if since >= (PING_TIMEOUT*2): - raise ServerDisconnectedException() - elif not sent_ping: - sent_ping = True - await self.send(build("PING", ["hello"])) - continue else: - sent_ping = False + line = self._process_queue.popleft() emit = self.parse_tokens(line) - - waiting_lines.append((line, emit)) - self.line_preread(line) - - if self._wait_for is not None: - aw, wait_for = self._wait_for - if wait_for.match(self, line): - wait_for.resolve(line) - self._wait_for = await self._line_or_wait(aw) - if self._wait_for is not None: - continue - else: - continue - - for i in range(len(waiting_lines)): - line, emit = waiting_lines.pop(0) - line_aw = self._on_read(line, emit) - self._wait_for = await self._line_or_wait(line_aw) - if self._wait_for is not None: - break + await self._on_read(line, emit) async def wait_for(self, response: Union[IMatchResponse, Set[IMatchResponse]], @@ -356,22 +326,16 @@ class Server(IServer): else: response_obj = response - deadline = monotonic()+timeout - our_wait_for = WaitFor(response_obj, deadline) - if self._wait_for_fut is not None: - self._wait_for_fut.set_result(our_wait_for) - else: - cur_task = asyncio.current_task() - if cur_task is not None: - self._wait_for = (cur_task, our_wait_for) - - if sent_aw is not None: - sent_line = await sent_aw - label = str(sent_line.id) - our_wait_for.with_label(label) - - async with timeout_(timeout): - return (await our_wait_for) + async with self._read_lguard: + self._wait_for.set() + async with self._read_lwork: + async with timeout_(timeout): + while True: + line = await self._read_line(timeout) + if line: + self._process_queue.append(line) + if response_obj.match(self, line): + return line async def _on_send_line(self, line: Line): if (line.command in ["PRIVMSG", "NOTICE", "TAGMSG"] and diff --git a/requirements.txt b/requirements.txt index 170f4cc..771bd74 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,5 @@ anyio ~=2.0.2 +asyncio-rlock ~=0.1.0 asyncio-throttle ~=1.0.1 dataclasses ~=0.6; python_version<"3.7" ircstates ~=0.11.8