simplify wait_for

This commit is contained in:
jesopo 2021-05-12 10:52:39 +00:00
parent 90fb4b7bba
commit 6a05370a12
2 changed files with 45 additions and 80 deletions

View File

@ -6,6 +6,7 @@ from collections import deque
from time import monotonic
import anyio
from asyncio_rlock import RLock
from asyncio_throttle import Throttler
from async_timeout import timeout as timeout_
from ircstates import Emit, Channel, ChannelUser
@ -63,10 +64,13 @@ class Server(IServer):
self._send_queue: PriorityQueue[SentLine] = PriorityQueue()
self.desired_caps: Set[ICapability] = set([])
self._read_queue: Deque[Line] = deque()
self._read_queue: Deque[Line] = deque()
self._process_queue: Deque[Line] = deque()
self._wait_for: Optional[Tuple[Awaitable, WaitFor]] = None
self._wait_for_fut: Optional[Future[WaitFor]] = None
self._read_lguard = RLock()
self.read_lock = self._read_lguard
self._read_lwork = asyncio.Lock()
self._wait_for = asyncio.Event()
self._pending_who: Deque[str] = deque()
self._alt_nicks: List[str] = []
@ -273,76 +277,42 @@ class Server(IServer):
self.last_read = monotonic()
lines = self.recv(data)
for line in lines:
self.line_preread(line)
self._read_queue.append(line)
async def _line_or_wait(self,
line_aw: asyncio.Task
) -> Optional[Tuple[Awaitable, WaitFor]]:
wait_for_fut: Future[WaitFor] = Future()
self._wait_for_fut = wait_for_fut
done, pend = await asyncio.wait([line_aw, wait_for_fut],
return_when=asyncio.FIRST_COMPLETED)
self._wait_for_fut = None
if wait_for_fut.done():
new_line_aw = list(pend)[0]
return (new_line_aw, wait_for_fut.result())
else:
return None
async def _read_lines(self):
waiting_lines: List[Tuple[Line, Optional[Emit]]] = []
sent_ping = False
ping_sent = False
while True:
now = monotonic()
timeouts: List[float] = []
timeouts.append((self.last_read+PING_TIMEOUT)-now)
if self._wait_for is not None:
_, wait_for = self._wait_for
timeouts.append(wait_for.deadline-now)
async with self._read_lguard:
pass
line = await self._read_line(max([0.1, min(timeouts)]))
if line is None:
now = monotonic()
since = now-self.last_read
if not self._process_queue:
async with self._read_lwork:
read_aw = self._read_line(PING_TIMEOUT)
dones, notdones = await asyncio.wait(
[read_aw, self._wait_for.wait()],
return_when=asyncio.FIRST_COMPLETED
)
self._wait_for.clear()
if self._wait_for is not None:
aw, wait_for = self._wait_for
if wait_for.deadline <= now:
self._wait_for = None
await aw
for done in dones:
if isinstance(done.result(), Line):
line = done.result()
self._process_queue.append(line)
elif done.result() is None:
if ping_sent:
await self.send(build("PING", ["hello"]))
ping_sent = True
else:
await self.disconnect()
raise ServerDisconnectedException()
for notdone in notdones:
notdone.cancel()
if since >= PING_TIMEOUT:
if since >= (PING_TIMEOUT*2):
raise ServerDisconnectedException()
elif not sent_ping:
sent_ping = True
await self.send(build("PING", ["hello"]))
continue
else:
sent_ping = False
line = self._process_queue.popleft()
emit = self.parse_tokens(line)
waiting_lines.append((line, emit))
self.line_preread(line)
if self._wait_for is not None:
aw, wait_for = self._wait_for
if wait_for.match(self, line):
wait_for.resolve(line)
self._wait_for = await self._line_or_wait(aw)
if self._wait_for is not None:
continue
else:
continue
for i in range(len(waiting_lines)):
line, emit = waiting_lines.pop(0)
line_aw = self._on_read(line, emit)
self._wait_for = await self._line_or_wait(line_aw)
if self._wait_for is not None:
break
await self._on_read(line, emit)
async def wait_for(self,
response: Union[IMatchResponse, Set[IMatchResponse]],
@ -356,22 +326,16 @@ class Server(IServer):
else:
response_obj = response
deadline = monotonic()+timeout
our_wait_for = WaitFor(response_obj, deadline)
if self._wait_for_fut is not None:
self._wait_for_fut.set_result(our_wait_for)
else:
cur_task = asyncio.current_task()
if cur_task is not None:
self._wait_for = (cur_task, our_wait_for)
if sent_aw is not None:
sent_line = await sent_aw
label = str(sent_line.id)
our_wait_for.with_label(label)
async with timeout_(timeout):
return (await our_wait_for)
async with self._read_lguard:
self._wait_for.set()
async with self._read_lwork:
async with timeout_(timeout):
while True:
line = await self._read_line(timeout)
if line:
self._process_queue.append(line)
if response_obj.match(self, line):
return line
async def _on_send_line(self, line: Line):
if (line.command in ["PRIVMSG", "NOTICE", "TAGMSG"] and

View File

@ -1,4 +1,5 @@
anyio ~=2.0.2
asyncio-rlock ~=0.1.0
asyncio-throttle ~=1.0.1
dataclasses ~=0.6; python_version<"3.7"
ircstates ~=0.11.8