Source code for asyncutils.channels

  1from asyncutils.constants import _NO_DEFAULT
  2from asyncutils._internal import log as L, patch as P
  3from asyncutils._internal.compat import Queue, QueueEmpty, QueueShutDown
  4from asyncutils._internal.helpers import copy_and_clear, filter_out, get_loop_and_set, subscriptable, fullname
  5from asyncutils._internal.submodules import channels_all as __all__
  6from _functools import partial
  7from _weakrefset import WeakSet
  8import asyncio as I, asyncutils as A
  9from collections import defaultdict, deque, namedtuple
 10from itertools import repeat, starmap
 11from sys import addaudithook, audit
@subscriptable
class Observable(A.LoopContextMixin):
    """Asynchronous observer registry.

    Observers are callables returning awaitables (see :meth:`subscribe_syncf` for plain
    functions); :meth:`notify` calls all of them concurrently with the same arguments. While a
    notification is being delivered the observable is *notifying*: further :meth:`notify` calls
    are put on the accumulation queue (if there is one) and replayed afterwards, or otherwise wait
    for the delivery to end. Removals requested during a delivery through
    :meth:`unsubscribe_eventually` are applied once it ends.

    :param init_observers: observers subscribed from the start.
    :param maxsize: size bound passed to the accumulation ``Queue``; ``None`` disables accumulation.
    """
    __slots__ = '_data', '_event', '_lock', '_queue', '_to_remove'

    @property
    def idle(self):
        """Whether no notification is currently being delivered."""
        return self._event.is_set()

    @property
    def notifying(self):
        """Whether a notification is currently being delivered."""
        return not self.idle

    async def notify(self, *a, _ret_exc_=False, **k):
        """Call every observer concurrently with ``*a, **k``.

        If another delivery is in progress, the call is queued (returning the queue ``put`` result)
        or, without an accumulation queue, first waits for that delivery to finish. Afterwards,
        queued notifications are drained and deferred unsubscriptions are applied.

        :param _ret_exc_: forwarded to :func:`asyncio.gather` as ``return_exceptions``.
        """
        # NOTE(review): this class defines no __bool__/__len__; the check relies on whatever the base
        # mixin provides (otherwise an instance is always truthy).
        if not self: return
        async with self._lock:
            if self.notifying:
                if (q := self._queue) is None: await self.wait_until_idle()
                else: return await q.put((_ret_exc_, a, k))
        # Outside the lock: handle_unsubscriptions() acquires it again and asyncio.Lock is not reentrant.
        self._event.clear()
        try:
            await self._notify_helper(_ret_exc_, a, k)
            await self.handle_notifications()
        finally:
            self._event.set()
            await self.handle_unsubscriptions()

    async def notify_sequential(self, *a, _silent_=False, _persistent_=False, **k):
        """Await observers one by one, yielding each one's result.

        When an observer raises: with ``_silent_`` it is skipped (``_persistent_``) or iteration
        stops; without ``_silent_`` the error is logged and skipped (``_persistent_``) or re-raised.
        """
        for observer in self._data.copy():
            try:
                yield await observer(*a, **k)
            except Exception:
                if _silent_:
                    if _persistent_: continue
                    break
                if not _persistent_: raise
                L.exception('error in observer')

    async def wait_for_next(self, timeout=None, strict=False):
        """Wait for the next notification and return its ``(args, kwargs)``.

        :param timeout: seconds to wait; :class:`TimeoutError` is raised when it elapses.
        :param strict: forwarded to :meth:`unsubscribe_nowait` when the temporary observer is removed.
        """
        async def f(*a, **k):  # noqa: RUF029
            # Several notifications (or a late one after a timeout) can reach f before it is
            # unsubscribed; only the first may resolve F, otherwise set_result would raise.
            if not F.done(): F.set_result((a, k))
        F, u = self.make_fut(), self.subscribe_nowait(f)
        try: return await I.wait_for(F, timeout)
        finally: u(strict)

    async def wait_until_idle(self, timeout=None):
        """Wait (up to *timeout* seconds) until no notification is being delivered."""
        await I.wait_for(self._event.wait(), timeout)

    async def subscribe(self, observer):
        """Subscribe *observer* once idle; return a callable that unsubscribes it."""
        await self.wait_until_idle()
        return self.subscribe_nowait(observer)

    async def unsubscribe(self, observer, strict=False):
        """Unsubscribe *observer* once idle (see :meth:`unsubscribe_nowait` for *strict*)."""
        await self.wait_until_idle()
        self.unsubscribe_nowait(observer, strict)

    async def handle_notifications(self):
        """Deliver every notification currently waiting in the accumulation queue."""
        while (q := self._queue) is not None:
            try: await self._notify_helper(*q.get_nowait())
            except QueueEmpty: break

    async def handle_unsubscriptions(self):
        """Apply the removals deferred by :meth:`unsubscribe_eventually`."""
        async with self._lock:
            self._data -= (s := self._to_remove)
            s.clear()

    async def _notify_helper(self, r, a, k):
        # Run every observer concurrently; r is gather's return_exceptions flag.
        await I.gather(*self.make_multiple(obs(*a, **k) for obs in self._data.copy()), return_exceptions=r)

    def __init__(self, init_observers=(), maxsize=0):
        audit('asyncutils.channels.Observable', maxsize)
        self._data, self._lock, self._to_remove = set(init_observers), I.Lock(), set()
        self._queue = None if maxsize is None else Queue(maxsize)
        self._event = I.Event()
        # asyncio.Event starts cleared, which would make a fresh observable look busy: subscribe()
        # and unsubscribe() would wait forever and notify() would enqueue/wait instead of delivering.
        self._event.set()

    def __iter__(self): return self._data.__iter__()

    def __aiter__(self): return A.iter_to_agen(self._data)

    async def __setup__(self): A.LoopContextMixin.__init__(self)

    async def __cleanup__(self): await I.gather(self.handle_notifications(), self.handle_unsubscriptions())

    def start_accumulation(self):
        """Create an unbounded accumulation queue if none exists; return whether one was created."""
        # Done synchronously: restart_accumulation() is a coroutine and cannot be driven from here.
        if self._queue is not None: return False
        self._queue = Queue()
        return True

    async def restart_accumulation(self, flush=True):
        """Replace the accumulation queue with a fresh unbounded one, delivering pending items first if *flush*."""
        if flush: await self.handle_notifications()
        self._queue = Queue()

    def subscribe_nowait(self, observer):
        """Subscribe *observer* immediately; return a callable that unsubscribes it."""
        self._data.add(observer)
        return partial(self.unsubscribe_nowait, observer)

    def unsubscribe_eventually(self, observer, asap=True):
        """Remove *observer* now if idle and *asap*, otherwise after the current delivery."""
        if asap and self._event.is_set(): self.unsubscribe_nowait(observer)
        else: self._to_remove.add(observer)

    def unsubscribe_nowait(self, observer, strict=False):
        """Remove *observer* immediately; with *strict*, a missing observer raises :class:`KeyError`."""
        getattr(self._data, 'remove' if strict else 'discard')(observer)

    def subscribe_syncf(self, observer):
        """Subscribe a plain function, wrapped with ``A.to_async``."""
        return self.subscribe_nowait(A.to_async(observer))

    def ntimes(self, observer, n=None):
        """Subscribe *observer* for its next *n* successful calls; return an unsubscribing callable.

        :raises ValueError: if *n* is not positive.
        """
        if n is None: n = A.getcontext().OBSERVABLE_DEFAULT_NTIMES_N
        if n <= 0: raise ValueError('asyncutils.channels.Observable.ntimes: n must be positive')
        async def wrapper(*a, **k):
            nonlocal n
            # Queued notifications may still reach us after the quota is used up but before the
            # deferred removal is applied.
            if n <= 0: return  # ty: ignore[unsupported-operator]
            await observer(*a, **k); n -= 1  # ty: ignore[unsupported-operator]
            # unsubscribe() would wait for idleness, but we run inside a delivery: that deadlocks.
            if n <= 0: self.unsubscribe_eventually(wrapper)
        return self.subscribe_nowait(wrapper)

    def filter(self, pred, ret_exc=False):
        """Return a new observable notified only when ``pred(*a, **k)`` is true."""
        out = type(self)()
        f = partial(out._notify_helper, ret_exc)
        async def filtered(*a, **k):
            if pred(*a, **k): await f(a, k)
        self.subscribe_nowait(filtered)
        return out

    def map(self, transform, ret_exc=False):
        """Return a new observable notified with ``transform(*a, **k)``, which must return ``(args, kwargs)``."""
        out = type(self)()
        f = partial(out._notify_helper, ret_exc)
        async def mapped(*a, **k): await f(*transform(*a, **k))
        self.subscribe_nowait(mapped)
        return out

    def debounce(self, delay, ret_exc=False):
        """Return a new observable notified *delay* seconds after the last of a burst of notifications."""
        out = type(self)()
        f, t = partial(out._notify_helper, ret_exc), None
        async def debounced(*a, **k):
            nonlocal t
            if t is not None: await A.safe_cancel(t)
            async def notifier():
                with A.ignore_cancellation:
                    await I.sleep(delay)
                    await f(a, k)
            t = self.make(notifier())
        self.subscribe_nowait(debounced)
        return out

    def throttle(self, interval, ret_exc=False):
        """Return a new observable notified at most once per *interval* seconds (loop time)."""
        out = type(self)()
        # -inf guarantees the very first notification passes whatever the loop clock's origin is.
        f, last = partial(out._notify_helper, ret_exc), float('-inf')
        async def throttled(*a, **k):
            nonlocal last
            with A.event_loop.from_flags(0) as l:
                if (now := l.time()) - last >= interval:
                    last = now
                    await f(a, k)
        self.subscribe_nowait(throttled)
        return out

    def buffer(self, count, ret_exc=False):
        """Return a new observable notified in batches once *count* (at least 1) notifications have accumulated."""
        out = type(self)()
        # partial supplies _notify_helper's return_exceptions flag so each (a, k) pair can be starmapped.
        f, b, c = partial(out._notify_helper, ret_exc), [], max(1, count)
        async def buffered(*a, **k):
            b.append((a, k))
            if len(b) >= c: await I.gather(*starmap(f, copy_and_clear(b)), return_exceptions=ret_exc)
        self.subscribe_nowait(buffered)
        return out

    def at_change(self, key=lambda *a, **k: (a, frozenset(k.items())), ret_exc=False):
        """Return a new observable notified only when ``key(*a, **k)`` differs from the previous one."""
        out = type(self)()
        f, last = partial(out._notify_helper, ret_exc), object()
        async def distinct(*a, **k):
            nonlocal last
            if (cur := key(*a, **k)) != last:
                last = cur
                await f(a, k)
        self.subscribe_nowait(distinct)
        return out

    def fork(self, ret_exc=False):
        """Return a new observable whose :meth:`notify` is called for every notification of this one."""
        out = type(self)()
        self.subscribe_nowait(partial(out.notify, _ret_exc_=ret_exc))
        return out

    def merge(*obs, ret_exc=False):
        """Return a new observable notified whenever any of *obs* is (usable as ``a.merge(b, ...)``)."""
        p = partial((out := type(obs[0])()).notify, _ret_exc_=ret_exc)
        for o in obs: o._data.add(p)
        return out
class EventBus(A.LoopContextMixin):
    """Publish/subscribe bus keyed by event type.

    Subscribers are kept in per-event-type ``WeakSet``\\ s; the ``None`` key (:attr:`WILDCARD`)
    holds subscribers called for every published event. Publishing runs the registered
    middlewares over the data, then calls the event's subscribers and the wildcard subscribers
    concurrently. Optional extras: per-event publish counters (*tracking*), republishing Python
    audit events on the bus (*auditing*), and an async :meth:`event_stream`.
    """
    __slots__ = ('_auditing', '_handler', '_is_shutdown', '_lock', '_max_concurrent', '_middlewares', '_published',
                 '_publishers', '_sem', '_stream_queue', '_subscribers', '_tracking', 'auditor', 'name')

    def __init__(self, name=None, *, handler=None, max_concurrent=None, tracking_stats=False):
        """Create a bus.

        :param name: label appended to the bus's qualified name; a running ``#N`` counter is used if falsy.
        :param handler: receives exceptions raised by subscribers during safe publishing (awaited if it
            returns a coroutine); defaults to ignoring them.
        :param max_concurrent: maximum number of safe callbacks running at once.
        :param tracking_stats: whether publish counts are tracked from the start.
        """
        if max_concurrent is None: max_concurrent = A.getcontext().EVENT_BUS_DEFAULT_MAX_CONCURRENT
        def auditor(*a, f=self.is_auditing, _=self.sync_start_publish):
            # Installed with sys.addaudithook by start_audit; republishes audit events while auditing.
            if f(): _(*a)
        audit('asyncutils.channels.EventBus', name, id(self))
        self.auditor = auditor
        self._subscribers = subscribers = defaultdict(WeakSet)
        subscribers[None] = WeakSet()
        self._published = defaultdict(int)
        self._middlewares = []
        self._publishers = set()
        self.name = f'{fullname(self)} {name or self._inc_cnt()}'
        self._lock = I.Lock()
        self._auditing = False
        self._handler = handler or (lambda _: None)
        self._max_concurrent = max_concurrent  # needed to tell running callbacks from free slots
        self._sem = I.Semaphore(max_concurrent)
        self._is_shutdown = False
        self._tracking = tracking_stats

    def raise_for_shutdown(self):
        """Raise ``A.BusShutDown`` if :meth:`shutdown` has been called."""
        if self._is_shutdown: raise A.BusShutDown(f'{self.name} is shutting down')

    def get_event_stats(self):
        """Return a copy of the per-event publish counters; raise ``A.BusStatsError`` when not tracking."""
        if self._tracking: return self._published.copy()
        raise A.BusStatsError(f'{self.name} is not tracking event stats')

    def subscribers_for(self, event_type):
        """Return a copy of the subscribers registered for *event_type*."""
        return self._subscribers[event_type].copy()

    def event_names(self):
        """Return the set of non-wildcard event types that have a subscriber entry."""
        (s := set(self._subscribers)).discard(None)
        return s

    def has_subscribers(self, event_type):
        """Whether *event_type* has at least one live subscriber."""
        return bool(self._subscribers[event_type])

    @staticmethod
    def is_valid_event_type(event_type):
        """Whether *event_type* is ``None`` (wildcard) or a string."""
        return event_type is None or isinstance(event_type, str)

    def is_subscribed(self, subscriber, event_type=_NO_DEFAULT):
        """Whether *subscriber* is subscribed to *event_type*, or to any event type if it is omitted."""
        if event_type is _NO_DEFAULT: return any(subscriber in i for i in self._subscribers.values())
        return subscriber in self._subscribers.get(event_type, ())

    @property
    def total_subscribers(self):
        """Number of subscriptions over all event types, wildcards included."""
        return sum(map(len, self._subscribers.values()))

    @property
    def wildcards(self):
        """A copy of the wildcard subscribers."""
        return self.subscribers_for(None)

    @property
    def wildcard_count(self):
        """Number of wildcard subscribers."""
        return len(self._subscribers[None])

    @property
    def active_tasks(self):
        """Number of safe callbacks currently holding a concurrency slot."""
        return self._max_concurrent - self._sem._value

    @property
    def stream_queue(self):
        """Queue that :meth:`feed_event` writes to, created lazily."""
        if (r := getattr(self, '_stream_queue', None)) is None: self._stream_queue = r = Queue()
        return r

    @stream_queue.setter
    def stream_queue(self, val, /): self._stream_queue = val

    def is_auditing(self):
        """Whether audit events are currently republished on this bus."""
        return self._auditing

    auditing = property(is_auditing, lambda self, val, /: (self.start_audit if val else self.stop_audit)())

    def start_audit(self):
        """Start republishing audit events, installing the audit hook the first time."""
        if self._auditing: return
        audit('asyncutils.channels.EventBus.start_audit', id(self))
        if not getattr(a := self.auditor, 'added', False):
            # Audit hooks cannot be removed, so the hook is added once and gated by _auditing;
            # restarting after stop_audit() must therefore only flip the flag.
            addaudithook(a)
            a.added = True  # ty: ignore[unresolved-attribute]
        self._auditing = True

    def stop_audit(self):
        """Stop republishing audit events (the installed hook stays but becomes inert)."""
        audit('asyncutils.channels.EventBus.stop_audit', id(self))
        self._auditing = False

    def add_middleware(self, middleware):
        """Append *middleware* and return its cookie for :meth:`remove_middleware`.

        Middlewares are called as ``middleware(event_type, data)`` and return (or resolve to) the new data.
        """
        cookie = len(m := self._middlewares)
        m.append((middleware, None))
        return cookie

    def remove_middleware(self, cookie, *, result=None, strict=False):
        """Disable the middleware registered under *cookie*.

        For a temporary middleware whose ``until`` future is still pending, that future is resolved
        with *result*; if it was already done, its result is returned instead. Otherwise *result* is
        returned. With *strict*, removing an already-removed cookie raises :class:`ValueError`.
        """
        m = self._middlewares
        entry, m[cookie] = m[cookie], None
        if entry is None:
            if strict: raise ValueError(cookie)
            return result
        if (F := entry[1]) is None: return result  # permanent middleware: there is no `until` future
        if F.done(): return F.result()
        F.set_result(result)
        return result

    def add_temp_middleware(self, middleware, until):
        """Append *middleware*, skipped once the future *until* is done; return its cookie."""
        cookie = len(m := self._middlewares)
        m.append((middleware, until))
        return cookie

    @(c := A.dualcontextmanager(use_existing_executor=False, create_executor=False, strict=False))
    def audit_context(self):
        """Audit inside the context; auditing is stopped on exit only if it was started here."""
        o = not self._auditing
        try:
            if o: self.start_audit()
            yield
        finally:
            if o: self.stop_audit()

    @c
    def tracking_context(self, stats_receiver=None):
        """Track publish counts inside the context, if not already tracking.

        On exit (only when tracking was started here) tracking stops and the counters are cleared;
        a given *stats_receiver* future receives the result of ``stop_tracking(True)``.
        """
        o = not self._tracking
        try:
            if o: self.start_tracking()
            yield
        finally:
            if o:
                if stats_receiver is None: self.stop_tracking()
                else: stats_receiver.set_result(self.stop_tracking(True))

    def start_tracking(self):
        """Start counting publishes per event type."""
        self._tracking = True

    def stop_tracking(self, ret_stats=False):
        """Stop counting and clear the counters, returning ``copy_and_clear`` of them if *ret_stats*."""
        self._tracking = False
        return copy_and_clear(self._published) if ret_stats else self._published.clear()

    def subscribe(self, subscriber, /, event_type=None):
        """Subscribe *subscriber* (weakly held) to *event_type* (``None``: every event); return it."""
        self.raise_for_shutdown()
        self._subscribers[event_type].add(subscriber)
        return subscriber

    def unsubscribe(self, subscriber, /, event_type=None):
        """Remove *subscriber* from *event_type*; return whether it was subscribed."""
        self.raise_for_shutdown()
        try:
            self._subscribers[event_type].remove(subscriber)
            return True
        except KeyError:
            return False

    def subscribe_to(self, event_type):
        """Return a decorator subscribing its argument to *event_type*."""
        return partial(self.subscribe, event_type=event_type)

    def subscriber_count(self, event_type):
        """Number of subscribers for *event_type*."""
        return len(self._subscribers[event_type])

    async def _publish_helper(self, d, s, subs, *extra, f=I.gather):
        # Call every subscriber in subs with the data plus extra (the event type, for wildcards),
        # through _safe_callback when s (safe mode) is true.
        # NOTE(review): the safe path passes the event type before the data, the unsafe path after
        # it — confirm which order subscribers expect.
        await f(*((self._safe_callback(i, d, *extra) for i in subs) if s else (i(d, *extra) for i in subs)))

    async def publish(self, event_type, data=None, *, wait=True, **k):
        """Publish *event_type* with *data*; ``**k`` is forwarded to :meth:`sync_start_publish`.

        With *wait*, waits for delivery and raises an :class:`ExceptionGroup` of middleware errors
        (if any were collected) or ``A.BusTimeout`` if the publish timed out.
        """
        p, f = self.sync_start_publish(event_type, data, **k)
        if not wait: return
        try:
            await p
            if f: raise ExceptionGroup(f'errors occurred in publishing middlewares of {self.name}', f) from None
            L.info('%s: publishing of event %r succeeded', self.name, event_type)
            L.debug('final data: %r', data)
        except TimeoutError:
            raise A.BusTimeout(f'publishing of event {event_type!r} in {self.name} took too long') from None
        finally:
            await A.safe_cancel(p)

    def sync_start_publish(self, event_type, data=None, *, safe=None, timeout=None, chaperone=None):
        """Schedule a publish of *event_type* and return ``(task, errors)`` without waiting.

        The task runs each active middleware in order (``data = middleware(event_type, data)``,
        awaited if it returns a coroutine), then calls the event's and the wildcard subscribers
        concurrently, all within *timeout*. Middleware exceptions go to *chaperone*; by default they
        are collected (exception groups flattened one level) into the returned ``errors`` list.

        :param safe: route subscriber calls through :meth:`_safe_callback`; defaults to the context's
            ``EVENT_BUS_PUBLISH_DEFAULT_SAFE``.
        :raises A.BusShutDown: if the bus has been shut down.
        """
        self.raise_for_shutdown()
        errors = []
        if safe is None: safe = A.getcontext().EVENT_BUS_PUBLISH_DEFAULT_SAFE
        report = chaperone if chaperone is not None else (
            lambda e, /: errors.extend(e.exceptions) if isinstance(e, BaseExceptionGroup) else errors.append(e))
        async def run():
            d = data
            for entry in self._middlewares:
                if entry is None: continue  # removed through remove_middleware
                m, until = entry
                if until is not None and until.done(): continue  # expired temporary middleware
                try:
                    r = m(event_type, d)
                    if I.iscoroutine(r): r = await r
                except A.CRITICAL: raise A.Critical
                except I.CancelledError: raise  # must propagate so publish timeouts/cancellation work
                except Exception as e: report(e)  # noqa: BLE001
                except BaseException as e: raise A.BusPublishingError(self, m) from e  # ty: ignore[invalid-argument-type]
                else: d = r  # only a successful middleware replaces the data
            if self._tracking: self._published[event_type] += 1
            subs = self._subscribers
            specific, wild = subs[event_type].copy(), subs[None].copy()
            helper = partial(self._publish_helper, d, safe)
            await I.gather(helper(specific), helper(wild, event_type))
        publishers = self._publishers
        p = self.make(I.wait_for(run(), timeout))
        publishers.add(p)
        p.add_done_callback(publishers.discard)
        return p, errors

    async def wait_for_event(self, event_type, *, timeout=None, condition=lambda _: True):
        """Start waiting for an *event_type* event whose data satisfies *condition*.

        Returns a task resolving to the matching event's data; the task raises :class:`TimeoutError`
        if none arrives within *timeout*. *condition* may be a plain or an async predicate. The
        temporary subscription is removed however the wait ends.
        """
        F = self.loop.create_future()
        async def handler(d):
            if F.done(): return
            if I.iscoroutine(c := condition(d)): c = await c
            # Re-check: another event may have resolved F while condition was being awaited.
            if c and not F.done(): F.set_result(d)
        t = self.subscribe_until(F, handler, event_type)
        async def waiter():
            try:
                return await I.wait_for(F, timeout)
            finally:
                if not F.done(): F.cancel()
                await I.wait((t,))  # let subscribe_until remove the handler without re-raising its outcome
        return self.make(waiter())

    def subscribe_until(self, fut, subscriber, event_type=None, *, till_permanent=None, _=A.ignore_cancellation.combined(TimeoutError)):  # noqa: B008
        """Subscribe *subscriber* to *event_type* until *fut* completes (or *till_permanent* seconds pass).

        Returns a task resolving to *fut*'s result; when ``_`` suppresses a cancellation or timeout the
        task returns ``None``. The subscription is removed however the wait ends, unless the bus was
        shut down meanwhile (shutdown clears every subscription itself).

        :raises RuntimeError: if *fut* is already done.
        """
        if fut.done(): raise RuntimeError('asyncutils.channels.EventBus.subscribe_until: fut is already done')
        async def f():
            try:
                with _: return await I.wait_for(fut, till_permanent)
            finally:
                # unsubscribe() raises BusShutDown after shutdown, which already dropped all subscribers.
                if not self._is_shutdown: self.unsubscribe(subscriber, event_type)
        self.subscribe(subscriber, event_type)
        return self.make(f())

    async def feed_event(self, *d, timeout=None):
        """Put an event payload (the single argument, or the argument tuple) on :attr:`stream_queue`.

        If the put times out on a full queue, the oldest item is dropped to make room.
        """
        item = d[0] if len(d) == 1 else d
        if (q := self.stream_queue).full(): L.warning('event stream buffer full')
        try:
            await I.wait_for(q.put(item), timeout)
        except QueueShutDown:
            L.info('event stream is closing', exc_info=True)
        except TimeoutError:
            if q.full():
                L.warning('event stream data lost', exc_info=True)
                q.get_nowait()
                q.put_nowait(item)  # same normalized payload as the regular path

    async def event_stream(self, event_type=None, *, timeout=_NO_DEFAULT, item_timeout=_NO_DEFAULT, bufsize=None):
        """Async generator yielding payloads of *event_type* events (every event when ``None``).

        Stops when the stream queue is shut down or no item arrives within *item_timeout*; the
        subscription is removed when the generator finishes or is closed.
        """
        self.raise_for_shutdown()
        if not self._auditing: audit('asyncutils.channels.EventBus.event_stream', id(self), event_type)
        ctx = A.getcontext()
        put_timeout = ctx.EVENT_BUS_STREAM_DEFAULT_TIMEOUT if timeout is _NO_DEFAULT else timeout
        F = self.loop.create_future()
        # Not awaited here: the subscription task only finishes once F is resolved in `finally`.
        t = self.subscribe_until(F, partial(self.feed_event, timeout=put_timeout), event_type)
        self.stream_queue = q = Queue(ctx.EVENT_BUS_STREAM_DEFAULT_BUFFER_SIZE if bufsize is None else bufsize)
        if _NO_DEFAULT.is_(item_timeout): item_timeout = ctx.EVENT_BUS_STREAM_DEFAULT_ITEM_TIMEOUT
        try:
            while True: yield await I.wait_for(q.get(), item_timeout)
        except QueueShutDown:
            L.info('event stream of %s has been shut down', self.name, exc_info=True)
        except TimeoutError:
            L.exception('event stream of %s is stopping because of timeout in waiting for item', self.name)
        finally:
            if not F.done(): F.set_result(None)
            await t

    async def shutdown(self, immediate=False, *, timeout=None, preserve_stats=False):
        """Shut the bus down; later calls are no-ops.

        Stops auditing, drops middlewares and subscribers (and the stats unless *preserve_stats*),
        shuts down the stream queue (*immediate* is forwarded to its ``shutdown``), waits up to
        *timeout* seconds for running safe callbacks to release their slots, then cancels any
        still-running publish tasks.
        """
        if self._is_shutdown: return
        self._is_shutdown = True
        acquire = self._sem.acquire
        self.stop_audit()
        self._middlewares.clear()
        self.clear()
        if not preserve_stats: self.clear_stats()
        try:
            async with I.timeout(timeout):
                self.stream_queue.shutdown(immediate)
                # Holding every slot means no safe callback is still running.
                for _ in repeat(None, self._max_concurrent): await acquire()
        except TimeoutError:
            L.exception('%s shutdown timed out, some tasks may be incomplete', self.name)
        finally:
            # Snapshot: done callbacks discard tasks from the set while they are being cancelled.
            if p := self._publishers: await A.safe_cancel_batch(tuple(p))
            del self._lock, self._handler, self._sem, self._publishers

    async def handle_exception(self, e):
        """Pass *e* to the configured handler, awaiting it if it returns a coroutine."""
        if I.iscoroutine(e := self._handler(e)): await e

    def clear(self, event_type=_NO_DEFAULT):
        """Drop all subscriptions, or only *event_type*'s (returning its removed set, if any)."""
        if event_type is _NO_DEFAULT: return self._subscribers.clear()
        return self._subscribers.pop(event_type, None)

    def clear_all(self):
        """Drop all subscriptions and publish counters."""
        self.clear()
        self.clear_stats()

    def clear_wildcards(self):
        """Drop the wildcard subscriptions."""
        return self.clear(None)

    def clear_stats(self):
        """Reset the publish counters."""
        self._published.clear()

    async def _safe_callback(self, c, d, t=None, i=None):
        # Call subscriber c under the concurrency semaphore, awaiting a coroutine result with timeout i.
        # NOTE(review): t (event type) is passed before the data through filter_out — confirm how
        # filter_out treats t=None versus the _NO_DEFAULT sentinel.
        # Timeouts are logged, critical errors escalate to A.Critical, anything else goes to handle_exception.
        try:
            async with self._sem:
                if I.iscoroutine(r := c(*filter_out(t, s=_NO_DEFAULT), d)): await I.wait_for(r, i)
        except TimeoutError: L.warning('callback %s timed out', fullname(c), exc_info=True)
        except A.CRITICAL: raise A.Critical
        except BaseException as e: await self.handle_exception(e)  # noqa: BLE001

    async def __setup__(self): super().__init__()

    def __cleanup__(self): return self.shutdown(immediate=True)

    P.patch_classmethod_signatures((_ := lambda _, /, f='#%d', c=__import__('itertools').count(1).__next__: f % c(), ''))  # noqa: B008
    P.patch_method_signatures((__init__, 'name=None, *, handler=None, max_concurrent=128, tracking_stats=False'), (subscribe_until, 'fut, subscriber, event_type=None, *, till_permanent=None'))
    WILDCARD, _inc_cnt = None, classmethod(_)
    del _, c
@subscriptable
class Rendezvous:
    """Hand-off point pairing a ``put`` with a ``get`` (or two :meth:`exchange` calls swapping values).

    Pending getters (futures) and putters (``(value, future)`` pairs) wait in FIFO deques; a
    background task periodically drops entries whose futures are already done.
    """
    __slots__ = '_getters', '_lock', '_loop', '_putters', '_task'

    def __init__(self, *, loop=None, lock=None):
        self._getters, self._putters = deque(), deque()
        self._loop = get_loop_and_set() if loop is None else loop
        self._lock = I.Lock() if lock is None else lock
        self._make_task()

    async def _maintainer(self):
        # I.sleep bound to the interval, so f() is I.sleep(interval).
        f, g = I.sleep.__get__(A.getcontext().RENDEZVOUS_MAINTENANCE_INTERVAL), self.cleanup
        while True:
            await f()
            g()

    async def put(self, v, /, *, timeout=None):
        """Offer *v* to a getter and return whether it was taken.

        Returns False on timeout or if the pending put was cancelled (e.g. by :meth:`reset`);
        cancellation of the calling task itself still propagates.
        """
        try:
            await self.raising_put(v, timeout=timeout)
        except TimeoutError:
            return False
        except I.CancelledError:
            if (task := I.current_task()) is not None and task.cancelling(): raise
            return False
        return True

    async def raising_put(self, v, /, *, timeout):
        """Like :meth:`put`, but timeouts and cancellation propagate."""
        await I.wait_for(await I.shield(self._put_helper(v)), timeout)

    async def get(self, default=_NO_DEFAULT, *, timeout=None, _=100):
        """Receive a value from a putter, waiting up to *timeout* seconds.

        With no pending putter, a given *default* and no *timeout*, *default* is returned at once;
        on timeout *default* is returned if given, else :class:`TimeoutError` propagates.
        The ``_`` parameter is unused.
        """
        putters = self._putters
        while putters:
            v, F = putters.popleft()
            if not F.done():
                F.set_result(None)  # releases the waiting putter
                return v
        if timeout is None and default is not _NO_DEFAULT: return default
        self._getters.append(F := self._loop.create_future())
        try:
            return await I.wait_for(F, timeout)
        except TimeoutError:
            if default is _NO_DEFAULT: raise
            return default

    def __length_hint__(self): return len(self._getters) + len(self._putters)

    def state_snapshot(self, _=namedtuple('StateSnapshot', 'num_getters num_putters num_ops idle', module='asyncutils.channels')):
        """Prune finished entries and return counts of pending getters/putters."""
        self.cleanup()
        t = len(self._getters), len(self._putters)
        return _(*t, sum(t), not any(t))

    def cleanup(self):
        """Drop getters and putters whose futures are already done."""
        # Filter in place: exchange() and get() hold references to these deques across awaits,
        # so rebinding the attributes would strand entries appended to the old deque.
        live_getters = [F for F in self._getters if not F.done()]
        self._getters.clear()
        self._getters.extend(live_getters)
        live_putters = [t for t in self._putters if not t[1].done()]
        self._putters.clear()
        self._putters.extend(live_putters)

    async def exchange(self, v, /, *, timeout=None, asap=False):
        """Swap *v* with a partner and return the value received.

        If a getter is waiting, it receives *v* and a value is then taken with :meth:`get`.
        Otherwise *v* is queued as a putter and a getter is registered; unless *asap*, the call
        first waits for *v* to be taken. Everything is bounded by *timeout*; on failure the queued
        entries are withdrawn.
        """
        g = self._getters
        async with I.timeout(timeout):
            async with self._lock:
                partner = None
                while g:
                    if not (F := g.popleft()).done():
                        partner = F
                        break
                if partner is None:
                    # Register our value and our receiving future together; doing the put through
                    # _put_helper after appending our getter would hand v to ourselves.
                    self._putters.append((v, sent := self._loop.create_future()))
                    g.append(received := self._loop.create_future())
            if partner is not None:
                partner.set_result(v)
                return await self.get()
            try:
                if not asap: await sent
                return await received
            except BaseException:
                sent.cancel()
                received.cancel()
                raise

    async def _put_helper(self, v, /):
        # Hand v to the oldest live getter, or queue it as a putter; return the future to wait on
        # (already resolved with v in the first case).
        g = self._getters
        async with self._lock:
            while g:
                if not (F := g.popleft()).done():
                    F.set_result(v)
                    break
            else:
                self._putters.append((v, F := self._loop.create_future()))
        return F

    async def reset(self, _=partial(A.safe_cancel_batch, disembowel=True)):
        """Cancel every pending getter and putter and restart the maintenance task."""
        async with self._lock:
            await I.gather(A.safe_cancel_batch(self._getters, disembowel=True), A.safe_cancel_batch(F async for _, F in A.adisembowel(self._putters)))
        await A.safe_cancel(self._task)
        self._make_task()

    def _make_task(self): self._task = self._loop.create_task(self._maintainer())

    P.patch_method_signatures((reset, ''), (state_snapshot, ''), (_maintainer, ''))
# `P` (asyncutils._internal.patch) was only needed for the signature patching done in the class
# bodies above; remove it from the module namespace.
del P