scrap deferred wait_for, actually catch server disconnection

This commit is contained in:
jesopo 2020-09-24 19:43:03 +00:00
parent eb9888d0c4
commit a264e4e347
3 changed files with 83 additions and 61 deletions

View File

@ -1,7 +1,8 @@
from asyncio import Future
from irctokens import Line
from typing import (Any, Awaitable, Callable, Generator, Generic, Optional,
TypeVar)
from irctokens import Line
from .matching import IMatchResponse
from .interface import IServer
from .ircv3 import TAG_LABEL
@ -17,8 +18,10 @@ class MaybeAwait(Generic[TEvent]):
class WaitFor(object):
def __init__(self,
response: IMatchResponse):
self.response = response
response: IMatchResponse,
deadline: float):
self.response = response
self.deadline = deadline
self._label: Optional[str] = None
self._our_fut: "Future[Line]" = Future()

View File

@ -36,13 +36,12 @@ class Bot(IBot):
return server
async def _run_server(self, server: Server):
async with anyio.create_task_group() as tg:
async def _read():
while True:
async for line, emit in server._read_lines():
pass
await tg.spawn(_read)
await tg.spawn(server._send_lines)
try:
async with anyio.create_task_group() as tg:
await tg.spawn(server._read_lines)
await tg.spawn(server._send_lines)
except ServerDisconnectedException:
server.disconnected = True
await self.disconnected(server)

View File

@ -58,16 +58,16 @@ class Server(IServer):
rate_limit=100, period=THROTTLE_TIME)
self.sasl_state = SASLResult.NONE
self.last_read = -1.0
self.last_read = monotonic()
self._sent_count: int = 0
self._send_queue: PriorityQueue[SentLine] = PriorityQueue()
self.desired_caps: Set[ICapability] = set([])
self._read_queue: Deque[Tuple[Line, Optional[Emit]]] = deque()
self._read_queue: Deque[Line] = deque()
self._wait_fors: List[WaitFor] = []
self._wait_for_fut: Dict[asyncio.Task, Future[bool]] = {}
self._wait_for: Optional[Tuple[Awaitable, WaitFor]] = None
self._wait_for_fut: Optional[Future[WaitFor]] = None
self._pending_who: Deque[str] = deque()
@ -202,73 +202,93 @@ class Server(IServer):
line = await self.wait_for(end)
async def _next_lines(self) -> AsyncIterable[Line]:
ping_sent = False
async def _read_line(self, timeout: float) -> Optional[Line]:
while True:
if self._read_queue:
return self._read_queue.popleft()
try:
async with timeout_(PING_TIMEOUT):
async with timeout_(timeout):
data = await self._reader.read(1024)
except asyncio.TimeoutError:
if ping_sent:
data = b"" # empty data means the socket disconnected
else:
ping_sent = True
await self.send(build("PING", ["hello"]))
continue
return None
self.last_read = monotonic()
ping_sent = False
try:
lines = self.recv(data)
except ServerDisconnectedException:
self.disconnected = True
raise
lines = self.recv(data)
for line in lines:
yield line
self._read_queue.append(line)
async def _line_or_wait(self,
line_aw: asyncio.Task
) -> Optional[Awaitable]:
wait_for_fut: Future[bool] = Future()
self._wait_for_fut[line_aw] = wait_for_fut
) -> Optional[Tuple[Awaitable, WaitFor]]:
wait_for_fut: Future[WaitFor] = Future()
self._wait_for_fut = wait_for_fut
done, pend = await asyncio.wait([line_aw, wait_for_fut],
return_when=asyncio.FIRST_COMPLETED)
del self._wait_for_fut[line_aw]
self._wait_for_fut = None
if wait_for_fut.done():
new_line_aw = list(pend)[0]
return new_line_aw
return (new_line_aw, wait_for_fut.result())
else:
return None
async def _read_lines(self) -> AsyncIterable[Tuple[Line, Optional[Emit]]]:
async with anyio.create_task_group() as tg:
async for line in self._next_lines():
async def _read_lines(self):
waiting_lines: List[Tuple[Line, Optional[Emit]]] = []
sent_ping = True
while True:
now = monotonic()
timeouts: List[float] = []
timeouts.append((self.last_read+PING_TIMEOUT)-now)
if self._wait_for is not None:
_, wait_for = self._wait_for
timeouts.append(wait_for.deadline-now)
line = await self._read_line(max([0.1, min(timeouts)]))
if line is None:
now = monotonic()
since = now-self.last_read
if self._wait_for is not None:
aw, wait_for = self._wait_for
if wait_for.deadline <= now:
self._wait_for = None
await aw
if since >= PING_TIMEOUT:
if since >= (PING_TIMEOUT*2):
raise ServerDisconnectedException()
elif not sent_ping:
await self.send(build("PING", ["hello"]))
continue
else:
emit = self.parse_tokens(line)
waiting_lines.append((line, emit))
self.line_preread(line)
for i, wait_for in enumerate(self._wait_fors):
if self._wait_for is not None:
aw, wait_for = self._wait_for
if wait_for.match(self, line):
wait_for.resolve(line)
self._wait_fors.pop(i)
self._wait_for = await self._line_or_wait(aw)
if self._wait_for is not None:
continue
else:
continue
for i in range(len(waiting_lines)):
line, emit = waiting_lines.pop(0)
line_aw = self._on_read(line, emit)
self._wait_for = await self._line_or_wait(line_aw)
if self._wait_for is not None:
break
line_aw = asyncio.create_task(self._on_read(line, emit))
new_wait = await self._line_or_wait(line_aw)
if new_wait is not None:
async def _aw():
await new_wait
await tg.spawn(_aw)
yield (line, emit)
async def wait_for(self,
response: Union[IMatchResponse, Set[IMatchResponse]],
sent_aw: Optional[Awaitable[SentLine]]=None,
timeout: float=WAIT_TIMEOUT
timeout: float=WAIT_TIMEOUT
) -> Line:
response_obj: IMatchResponse
@ -277,13 +297,14 @@ class Server(IServer):
else:
response_obj = response
our_wait_for = WaitFor(response_obj)
self._wait_fors.append(our_wait_for)
cur_task = asyncio.current_task()
if cur_task is not None and cur_task in self._wait_for_fut:
wait_for_fut = self._wait_for_fut[cur_task]
wait_for_fut.set_result(True)
deadline = monotonic()+timeout
our_wait_for = WaitFor(response_obj, deadline)
if self._wait_for_fut is not None:
self._wait_for_fut.set_result(our_wait_for)
else:
cur_task = asyncio.current_task()
if cur_task is not None:
self._wait_for = (cur_task, our_wait_for)
if sent_aw is not None:
sent_line = await sent_aw
@ -297,8 +318,7 @@ class Server(IServer):
if (line.command in ["PRIVMSG", "NOTICE", "TAGMSG"] and
not self.cap_agreed(CAP_ECHO)):
new_line = line.with_source(self.hostmask())
emit = self.parse_tokens(new_line)
self._read_queue.append((new_line, emit))
self._read_queue.append(new_line)
async def _send_lines(self):
while True: