diff --git a/ircrobots/asyncs.py b/ircrobots/asyncs.py index ce216c8..54d0b3b 100644 --- a/ircrobots/asyncs.py +++ b/ircrobots/asyncs.py @@ -1,7 +1,8 @@ from asyncio import Future -from irctokens import Line from typing import (Any, Awaitable, Callable, Generator, Generic, Optional, TypeVar) + +from irctokens import Line from .matching import IMatchResponse from .interface import IServer from .ircv3 import TAG_LABEL @@ -17,8 +18,10 @@ class MaybeAwait(Generic[TEvent]): class WaitFor(object): def __init__(self, - response: IMatchResponse): - self.response = response + response: IMatchResponse, + deadline: float): + self.response = response + self.deadline = deadline self._label: Optional[str] = None self._our_fut: "Future[Line]" = Future() diff --git a/ircrobots/bot.py b/ircrobots/bot.py index 6ea6b97..cdfe3ab 100644 --- a/ircrobots/bot.py +++ b/ircrobots/bot.py @@ -36,13 +36,12 @@ class Bot(IBot): return server async def _run_server(self, server: Server): - async with anyio.create_task_group() as tg: - async def _read(): - while True: - async for line, emit in server._read_lines(): - pass - await tg.spawn(_read) - await tg.spawn(server._send_lines) + try: + async with anyio.create_task_group() as tg: + await tg.spawn(server._read_lines) + await tg.spawn(server._send_lines) + except ServerDisconnectedException: + server.disconnected = True await self.disconnected(server) diff --git a/ircrobots/server.py b/ircrobots/server.py index 858c6b1..1623876 100644 --- a/ircrobots/server.py +++ b/ircrobots/server.py @@ -58,16 +58,16 @@ class Server(IServer): rate_limit=100, period=THROTTLE_TIME) self.sasl_state = SASLResult.NONE - self.last_read = -1.0 + self.last_read = monotonic() self._sent_count: int = 0 self._send_queue: PriorityQueue[SentLine] = PriorityQueue() self.desired_caps: Set[ICapability] = set([]) - self._read_queue: Deque[Tuple[Line, Optional[Emit]]] = deque() + self._read_queue: Deque[Line] = deque() - self._wait_fors: List[WaitFor] = [] - self._wait_for_fut: Dict[asyncio.Task, Future[bool]] = {} + self._wait_for: Optional[Tuple[Awaitable, WaitFor]] = None + self._wait_for_fut: Optional[Future[WaitFor]] = None self._pending_who: Deque[str] = deque() @@ -202,73 +202,93 @@ class Server(IServer): line = await self.wait_for(end) - async def _next_lines(self) -> AsyncIterable[Line]: - ping_sent = False + async def _read_line(self, timeout: float) -> Optional[Line]: while True: + if self._read_queue: + return self._read_queue.popleft() + try: - async with timeout_(PING_TIMEOUT): + async with timeout_(timeout): data = await self._reader.read(1024) except asyncio.TimeoutError: - if ping_sent: - data = b"" # empty data means the socket disconnected - else: - ping_sent = True - await self.send(build("PING", ["hello"])) - continue + return None self.last_read = monotonic() - ping_sent = False - - try: - lines = self.recv(data) - except ServerDisconnectedException: - self.disconnected = True - raise - + lines = self.recv(data) for line in lines: - yield line + self._read_queue.append(line) async def _line_or_wait(self, line_aw: asyncio.Task - ) -> Optional[Awaitable]: - wait_for_fut: Future[bool] = Future() - self._wait_for_fut[line_aw] = wait_for_fut + ) -> Optional[Tuple[Awaitable, WaitFor]]: + wait_for_fut: Future[WaitFor] = Future() + self._wait_for_fut = wait_for_fut done, pend = await asyncio.wait([line_aw, wait_for_fut], return_when=asyncio.FIRST_COMPLETED) - del self._wait_for_fut[line_aw] + self._wait_for_fut = None if wait_for_fut.done(): new_line_aw = list(pend)[0] - return new_line_aw + return (new_line_aw, wait_for_fut.result()) else: return None - async def _read_lines(self) -> AsyncIterable[Tuple[Line, Optional[Emit]]]: - async with anyio.create_task_group() as tg: - async for line in self._next_lines(): + async def _read_lines(self): + waiting_lines: List[Tuple[Line, Optional[Emit]]] = [] + sent_ping = True + while True: + now = monotonic() + timeouts: List[float] = [] + timeouts.append((self.last_read+PING_TIMEOUT)-now) + if self._wait_for is not None: + _, wait_for = self._wait_for + timeouts.append(wait_for.deadline-now) + + line = await self._read_line(max([0.1, min(timeouts)])) + if line is None: + now = monotonic() + since = now-self.last_read + + if self._wait_for is not None: + aw, wait_for = self._wait_for + if wait_for.deadline <= now: + self._wait_for = None + await aw + + if since >= PING_TIMEOUT: + if since >= (PING_TIMEOUT*2): + raise ServerDisconnectedException() + elif not sent_ping: + await self.send(build("PING", ["hello"])) + continue + else: emit = self.parse_tokens(line) + + waiting_lines.append((line, emit)) self.line_preread(line) - for i, wait_for in enumerate(self._wait_fors): + if self._wait_for is not None: + aw, wait_for = self._wait_for if wait_for.match(self, line): wait_for.resolve(line) - self._wait_fors.pop(i) + self._wait_for = await self._line_or_wait(aw) + if self._wait_for is not None: + continue + else: + continue + + for i in range(len(waiting_lines)): + line, emit = waiting_lines.pop(0) + line_aw = self._on_read(line, emit) + self._wait_for = await self._line_or_wait(line_aw) + if self._wait_for is not None: break - line_aw = asyncio.create_task(self._on_read(line, emit)) - new_wait = await self._line_or_wait(line_aw) - if new_wait is not None: - async def _aw(): - await new_wait - await tg.spawn(_aw) - - yield (line, emit) - async def wait_for(self, response: Union[IMatchResponse, Set[IMatchResponse]], sent_aw: Optional[Awaitable[SentLine]]=None, - timeout: float=WAIT_TIMEOUT + timeout: float=WAIT_TIMEOUT ) -> Line: response_obj: IMatchResponse @@ -277,13 +297,14 @@ class Server(IServer): else: response_obj = response - our_wait_for = WaitFor(response_obj) - self._wait_fors.append(our_wait_for) - - cur_task = asyncio.current_task() - if cur_task is not None and cur_task in self._wait_for_fut: - wait_for_fut = self._wait_for_fut[cur_task] - wait_for_fut.set_result(True) + deadline = monotonic()+timeout + our_wait_for = WaitFor(response_obj, deadline) + if self._wait_for_fut is not None: + self._wait_for_fut.set_result(our_wait_for) + else: + cur_task = asyncio.current_task() + if cur_task is not None: + self._wait_for = (cur_task, our_wait_for) if sent_aw is not None: sent_line = await sent_aw @@ -297,8 +318,7 @@ class Server(IServer): if (line.command in ["PRIVMSG", "NOTICE", "TAGMSG"] and not self.cap_agreed(CAP_ECHO)): new_line = line.with_source(self.hostmask()) - emit = self.parse_tokens(new_line) - self._read_queue.append((new_line, emit)) + self._read_queue.append(new_line) async def _send_lines(self): while True: