1import asyncio as I, asyncutils as A, asyncutils._internal.log as L
2from asyncutils.constants import _NO_DEFAULT
3from asyncutils._internal.helpers import check_methods, fullname, get_loop_and_set
4from asyncutils._internal.submodules import locksmiths_all as __all__
5from enum import IntEnum
6from sys import audit
# Outcome codes of ``LocksmithBase.force``; functional-API members are numbered from 1 in declaration order.
ForceResult = IntEnum(
    'ForceResult',
    'UNFORCABLE NO_CURRENT_TASK OWNER_COMPLETED ALREADY_BEING_FORCED FAILURE RELEASED_WITH_FALSE SUCCESS RELEASED',
    module=__name__,
)
# Outcome codes of ``LocksmithBase.recognize_lock``; also numbered from 1.
RecognitionResult = IntEnum(
    'RecognitionResult',
    'FAILED_PRELIM FAILED_ACK ALREADY_RECOGNIZED SUCCESS',
    module=__name__,
)

# Successful members, keyed by enum type.  The two enums must NOT share one set: IntEnum members hash and
# compare as plain ints, so ForceResult.OWNER_COMPLETED (3) would match RecognitionResult.ALREADY_RECOGNIZED (3)
# and ForceResult.ALREADY_BEING_FORCED (4) would match RecognitionResult.SUCCESS (4).
_SUCCESSFUL_RESULTS = {
    ForceResult: frozenset((ForceResult.SUCCESS, ForceResult.RELEASED)),
    RecognitionResult: frozenset((RecognitionResult.ALREADY_RECOGNIZED, RecognitionResult.SUCCESS)),
}


def succeeded(result, /):
    """Return True if *result* is a successful ``ForceResult`` or ``RecognitionResult`` member.

    The check is made against the successful members of ``type(result)`` only, so a failure code of one
    enum is never mistaken for a success code of the other merely because their integer values coincide.
    Anything that is not a member of either enum (including a bare ``int`` or ``bool``) yields False.
    """
    return result in _SUCCESSFUL_RESULTS.get(type(result), ())
[docs]
class LocksmithBase:
    """Base class for locksmiths: objects that can forcibly release locks they have recognized.

    A lock must first be registered through :meth:`recognize_lock`; only recognized locks that are
    currently held can be forced (see :meth:`can_force_lock_held`).  Most of the behaviour is split into
    small hook methods (``get_info``, ``lock_busy``, ``throw_fallback``, ...) that subclasses may override.
    """
    __slots__ = '_lock', '_loop', '_recognized'
    # Lock type -> handler registry.  Deliberately a class-level mutable shared by every class that does
    # not override it; populated through ``register_handler`` and consulted in ``_force_except``.
    handlers = {}  # noqa: RUF012

    @classmethod
    def register_handler(cls, h, /, *, shadow=True):
        """Return a decorator that registers *h* as the handler for the decorated type.

        The decorator returns the type unchanged, so it can be applied directly to a class definition.

        :param h: the handler; ``_force_except`` calls it with the lock and awaits the result if it is a coroutine.
        :param shadow: when true, silently replace any handler already registered for the type.
        :raises TypeError: (from the decorator) if the decorated object is not a type.
        :raises KeyError: (from the decorator) if *shadow* is false and a different handler is already registered.
        """
        def register(t, H=cls.handlers, h=h):
            if not isinstance(t, type):
                raise TypeError('asyncutils.locksmiths.LocksmithBase: non-type cannot be registered')
            if shadow:
                H[t] = h
            else:
                existing = H.setdefault(t, h)
                # Re-registering the very same handler is accepted; only a conflicting one is an error.
                if existing is not h:
                    raise KeyError('asyncutils.locksmiths.LocksmithBase: handler for type already registered', t, existing)
            return t
        return register

    @property
    def currently_recognized(self):
        """An immutable snapshot of the locks currently recognized by this locksmith."""
        return frozenset(self._recognized)

    def __init__(self, loop=None, ltyp=I.Lock):
        """Create a locksmith.

        :param loop: event loop to use; when falsy, the result of ``get_loop_and_set()`` is used instead.
        :param ltyp: zero-argument factory for the internal lock that serializes ``recognize_lock`` and ``force``.
        """
        from weakref import WeakSet  # public home of the class the original reached via '_weakrefset'
        internal = ltyp()
        # The internal lock itself is part of the initial recognized set.  The set holds weak references,
        # so recognizing a lock does not keep it alive.
        recognized, loop = WeakSet((internal,)), loop or get_loop_and_set()
        self._recognized, self._loop, self._lock = recognized, loop, internal

    async def recognize_lock(self, l, /):
        """Try to add *l* to the set of locks this locksmith may force.

        :returns: ``RecognitionResult.FAILED_PRELIM`` if :meth:`preliminary_check_lock` rejects *l*;
            ``ALREADY_RECOGNIZED`` if it is already in the set; ``FAILED_ACK`` if the lock's
            ``acknowledge_locksmith_lock_held`` hook raised; ``SUCCESS`` once *l* has been added.
            When the lock has a callable ``acknowledge_locksmith_lock_held`` that does not raise, the
            truthiness of its (awaited, if a coroutine) result is returned as a plain ``bool``.
        :raises A.Critical: if the acknowledgement hook raised one of ``A.CRITICAL``.
        """
        if not self.preliminary_check_lock(l):
            return RecognitionResult.FAILED_PRELIM
        async with self._lock:
            recognized = self._recognized
            if l in recognized:
                return RecognitionResult.ALREADY_RECOGNIZED
            ack = getattr(l, 'acknowledge_locksmith_lock_held', None)
            if callable(ack):
                # NOTE(review): this path returns a bare bool (True == 1 == FAILED_PRELIM as an int) and
                # never adds the lock to the recognized set, even when the hook answers truthily.  That
                # looks unintended, but the intent cannot be confirmed from this file; behaviour is kept.
                try:
                    answer = ack(self)
                    if I.iscoroutine(answer):
                        answer = await answer
                    return bool(answer)
                except A.CRITICAL:
                    raise A.Critical
                except BaseException:  # noqa: BLE001 - any other failure of the hook counts as a refusal
                    return RecognitionResult.FAILED_ACK
            recognized.add(l)
            return RecognitionResult.SUCCESS

    async def force(self, l, /, info=_NO_DEFAULT, *, purge_waiters=True):
        """Forcibly release the recognized, currently held lock *l*.

        An audit event ``asyncutils.locksmiths.LocksmithBase.force`` is raised with the ids of the
        locksmith and the lock before anything else happens.

        :param info: payload passed on to the owner when it has to be asked to release; defaults to
            the result of :meth:`get_info`.
        :param purge_waiters: when true, call :meth:`purge_waiters` after the release attempt,
            whatever its outcome (but not when the lock was unforcable).
        :returns: ``ForceResult.UNFORCABLE`` if :meth:`can_force_lock_held` is false;
            ``ForceResult.RELEASED`` if ``l.release()`` worked; the result of
            :meth:`release_returned_false` if it returned exactly ``False``; otherwise, when
            ``l.release()`` raised, the result of the fallback path ``_force_except``.
        :raises A.Critical: if ``l.release()`` raised one of ``A.CRITICAL``.
        """
        audit('asyncutils.locksmiths.LocksmithBase.force', id(self), id(l))
        async with self._lock:
            if not self.can_force_lock_held(l):
                return ForceResult.UNFORCABLE
            if info is _NO_DEFAULT:
                info = await self.get_info(l)
            try:
                released = l.release()
                if I.iscoroutine(released):
                    released = await released
            except A.CRITICAL:
                raise A.Critical
            except BaseException:  # noqa: BLE001 - a plain release was refused; fall back to asking the owner
                return await self._force_except(l, info)
            else:
                if released is False:
                    return await self.release_returned_false(l)
                return ForceResult.RELEASED
            finally:
                if purge_waiters:
                    await self.purge_waiters(l)

    async def _force_except(self, l, i, /):
        """Fallback used by :meth:`force` when ``l.release()`` raised.

        If the lock's owner (per :meth:`find_owner`) is the current task of this locksmith's loop, the
        owner is asked to release through ``_force_is_owner``; a truthy result of that is returned as
        is.  Otherwise the handler registered for exactly ``type(l)`` (no MRO lookup) is invoked, if
        any, and ``ForceResult.SUCCESS`` is returned.

        :raises A.Critical: if the handler raised one of ``A.CRITICAL``; other handler errors propagate.
        """
        owner = self.find_owner(l)
        current = I.current_task(self._loop)
        if owner is current:
            outcome = await self._force_is_owner(l, i, current)
            if outcome:
                return outcome
        try:
            handler = self.handlers.get(type(l))
            if callable(handler):
                pending = handler(l)
                if I.iscoroutine(pending):
                    await pending
        except A.CRITICAL:
            raise A.Critical
        # NOTE(review): ForceResult.FAILURE is never produced anywhere in this class.
        return ForceResult.SUCCESS

    async def _force_is_owner(self, l, i, o, /):
        """Ask the owning task *o* to release *l* by throwing an ``A.LockForceRequest`` into its coroutine.

        :returns: the result of :meth:`throw_fallback` if *o* is ``None``, of :meth:`eager_fallback` if
            the task has no coroutine, of :meth:`already_forcing` if this locksmith's own, different
            request came back; ``None`` in every other non-raising case (after the matching hook ran).
        """
        if o is None:
            return await self.throw_fallback(l)
        coro = o.get_coro()
        if coro is None:
            return await self.eager_fallback(l)
        # The request carries the future's ``set_result`` so that the receiver can send an answer back.
        answer = self._loop.create_future()
        request = A.LockForceRequest(self, answer.set_result, l, i)  # ty: ignore[invalid-argument-type]
        try:
            coro.throw(request)
        except A.CRITICAL as e:
            return self.task_raised_critical(l, e)
        except A.LockForceRequest as e:
            requester = e.requester
            if requester is not self:
                # A request issued by another locksmith surfaced instead of ours.
                await self.lock_busy(l, requester, {})
            elif e is request:
                # Our own request came straight back: the task did not handle it.
                await self.task_reraised_request(l)
            else:
                return await self.already_forcing(l)
        except BaseException as e:  # noqa: BLE001
            await self.task_raised_other(l, e)
        else:
            await self.answer_received(l, await answer)

    async def purge_waiters(self, l, /):
        """Hand the lock's ``_waiters`` collection, if present and non-empty, to ``A.safe_cancel_batch``."""
        waiters = getattr(l, '_waiters', None)
        if waiters:
            await A.safe_cancel_batch(waiters, disembowel=True)

    async def host(self, t, l, /, *, timeout1=_NO_DEFAULT, timeout2=_NO_DEFAULT, timeout3=_NO_DEFAULT):
        """Force *l*, acquire it, then run *t* while holding it and return the result of *t*.

        ``force(l, purge_waiters=False)`` and ``l.acquire()`` are started as tasks; once either has
        finished, the force result is awaited (``timeout1``), the acquisition is awaited, *t* is
        wrapped in a task, installed as the lock's owner via :meth:`patch_owner` and awaited
        (``timeout3``).  The lock is released afterwards if it is still held (see ``_wait_on``).

        Each timeout left at its default is taken from the matching index of
        ``A.getcontext().LOCKSMITH_BASE_DEFAULT_TIMEOUTS``.

        :raises TimeoutError: if one of the waits times out.
        """
        pending = self.force(l, purge_waiters=False), l.acquire()
        force_task, acquire_task = map(self.wrap_task, pending)
        await I.wait((force_task, acquire_task), return_when=I.FIRST_COMPLETED)
        defaults = A.getcontext().LOCKSMITH_BASE_DEFAULT_TIMEOUTS
        # NOTE(review): every ForceResult member is a non-zero IntEnum and therefore truthy, so with the
        # base-class hooks the ``else`` branch (and ``timeout2``) is never reached.  A success test may
        # have been intended here; behaviour is left unchanged pending confirmation.
        if await I.wait_for(force_task, defaults[0] if timeout1 is _NO_DEFAULT else timeout1):
            await acquire_task
        else:
            acquire_timeout = defaults[1] if timeout2 is _NO_DEFAULT else timeout2
            try:
                await I.wait_for(acquire_task, acquire_timeout)
            except TimeoutError:
                # Report the timeout actually applied, not the raw argument (which may be the sentinel).
                raise TimeoutError(f'{fullname(self)}.host: failed to acquire lock {l!r} within {acquire_timeout} seconds') from None
        task = self.wrap_task(t)
        self.patch_owner(task, l)
        return await I.wait_for(self._wait_on(task, l), defaults[2] if timeout3 is _NO_DEFAULT else timeout3)

    async def get_info(self, l, /):
        """Hook: build the default ``info`` payload for forcing *l* (a descriptive string)."""
        return f'potential deadlock situation involving {fullname(l)} at {id(l):#x}'

    async def lock_busy(self, l, r, _, /):
        """Hook: another requester *r* is already working on *l*; reported through ``L.info``."""
        await A.transient_block(self._loop, L.info, 'lock busy: %r; requesters: %r, %r', l, self, r)

    async def task_reraised_request(self, l, /):
        """Hook: the owning task let our request propagate unhandled; reported through ``L.warning``."""
        await A.transient_block(self._loop, L.warning, '%s.force: running task did not handle request to release %s at %#x properly', fullname(self), fullname(l), id(l))

    async def answer_received(self, l, a, /):
        """Hook: the owning task answered the request with *a*; reported through ``L.info``."""
        await A.transient_block(self._loop, L.info, '%r received answer %r from %r', self, a, l)

    async def throw_fallback(self, _, /):
        """Hook: there is no current task to throw the request into."""
        return ForceResult.NO_CURRENT_TASK

    async def eager_fallback(self, _, /):
        """Hook: the owning task has no coroutine object (``get_coro()`` returned ``None``)."""
        return ForceResult.OWNER_COMPLETED

    async def release_returned_false(self, _, /):
        """Hook: ``release()`` did not raise but returned exactly ``False``."""
        return ForceResult.RELEASED_WITH_FALSE

    async def already_forcing(self, _, /):
        """Hook: a different request from this same locksmith surfaced from the owning task."""
        return ForceResult.ALREADY_BEING_FORCED

    async def _wait_on(self, t, l, /):
        """Await *t* and return its result, releasing *l* afterwards if it is still locked."""
        try:
            return await t
        finally:
            if l.locked():
                released = l.release()
                if I.iscoroutine(released):
                    await released

    async def task_raised_other(self, l, e, /):
        """Hook: throwing the request raised *e*; reported through ``L.error`` unless it is a ``RuntimeError``."""
        if not isinstance(e, RuntimeError):
            await A.transient_block(self._loop, L.error, 'error encountered in attempt to force %s at %#x', fullname(l), id(l), exc_info=e)

    def wrap_task(self, a, /):
        """Schedule *a* (passed through ``A.wrap_in_coro``) as a task on this locksmith's loop."""
        return self._loop.create_task(A.wrap_in_coro(a))

    def patch_owner(self, t, l, /):
        """Record *t* as the owner of *l*, but only if the lock already has an ``_owner`` attribute."""
        if hasattr(l, '_owner'):
            l._owner = t

    def find_owner(self, l, /):
        """Return the lock's ``_owner`` attribute, or ``None`` if it has none."""
        return getattr(l, '_owner', None)

    def preliminary_check_lock(self, l, /):
        """Check via ``check_methods`` that *l* offers ``acquire``, ``release`` and ``locked``."""
        return check_methods(l, 'acquire', 'release', 'locked')

    def task_raised_critical(self, _, e, /):
        """Hook: the owning task raised a critical exception *e*; always raises ``A.Critical(e)``."""
        raise A.Critical(e) from None

    def can_force_lock_held(self, l, /):
        """Return whether *l* is recognized by this locksmith and currently locked."""
        return l in self._recognized and l.locked()