Skip to content

Commit

Permalink
Port over gh-91048 for awaiter support in Meta Python 3.12
Browse files Browse the repository at this point in the history
Summary:
Context: We are deciding to stick with the old "bpf" way of doing async stack walking for Meta Python 3.12, and maybe 3.13. This involves continuing support for the awaiter pointer which is why we need this port. Ideally we get the python stack from the runtime and there is expected to be an internal implementation of this for Meta Python 3.13 or 3.14 latest.

This is essentially a copy-paste port of https://github.com/python/cpython/pull/103976/files from last year. The main change of this diff from the PR is that gi_awaiter is renamed to gi_ci_awaiter

Reviewed By: jbower-fb

Differential Revision: D57072594

fbshipit-source-id: ca0d2ee0ffca2a72e16d479ca395555ed13498af
  • Loading branch information
Aniket Panse authored and facebook-github-bot committed Jul 16, 2024
1 parent b39f5b4 commit 36a1893
Show file tree
Hide file tree
Showing 20 changed files with 924 additions and 2 deletions.
11 changes: 11 additions & 0 deletions Include/cpython/genobject.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,17 @@
extern "C" {
#endif

static inline int
_PyAwaitable_SetAwaiter(PyObject *receiver, PyObject *awaiter)
{
PyTypeObject *ty = Py_TYPE(receiver);
PyAsyncMethods *am = (PyAsyncMethods *) ty->tp_as_async;
if ((am != NULL) && (am->am_set_awaiter != NULL)) {
return am->am_set_awaiter(receiver, awaiter);
}
return 0;
}

/* --- Generators --------------------------------------------------------- */

/* _PyGenObject_HEAD defines the initial segment of generator
Expand Down
2 changes: 2 additions & 0 deletions Include/cpython/object.h
Original file line number Diff line number Diff line change
Expand Up @@ -124,12 +124,14 @@ typedef struct {
} PyMappingMethods;

typedef PySendResult (*sendfunc)(PyObject *iter, PyObject *value, PyObject **result);
typedef int (*setawaiterfunc)(PyObject *iter, PyObject *awaiter);

typedef struct {
unaryfunc am_await;
unaryfunc am_aiter;
unaryfunc am_anext;
sendfunc am_send;
setawaiterfunc am_set_awaiter;
} PyAsyncMethods;

typedef struct {
Expand Down
1 change: 1 addition & 0 deletions Include/internal/pycore_global_objects_fini_generated.h

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Include/internal/pycore_global_strings.h
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,7 @@ struct _Py_global_strings {
STRUCT_FOR_ID(__rtruediv__)
STRUCT_FOR_ID(__rxor__)
STRUCT_FOR_ID(__set__)
STRUCT_FOR_ID(__set_awaiter__)
STRUCT_FOR_ID(__set_name__)
STRUCT_FOR_ID(__setattr__)
STRUCT_FOR_ID(__setitem__)
Expand Down
1 change: 1 addition & 0 deletions Include/internal/pycore_runtime_init_generated.h

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 3 additions & 0 deletions Include/internal/pycore_unicodeobject_generated.h

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 4 additions & 0 deletions Include/typeslots.h
Original file line number Diff line number Diff line change
Expand Up @@ -86,3 +86,7 @@
/* New in 3.10 */
#define Py_am_send 81
#endif
#if !defined(Py_LIMITED_API) || Py_LIMITED_API+0 >= 0x030C0000
/* New in 3.12 */
#define Py_am_set_awaiter 82
#endif
63 changes: 63 additions & 0 deletions Lib/asyncio/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
'current_task', 'all_tasks',
'create_eager_task_factory', 'eager_task_factory',
'_register_task', '_unregister_task', '_enter_task', '_leave_task',
'get_async_stack',
)

import concurrent.futures
Expand All @@ -16,6 +17,7 @@
import inspect
import itertools
import types
import sys
import warnings
import weakref
from types import GenericAlias
Expand Down Expand Up @@ -732,6 +734,11 @@ def cancel(self, msg=None):
self._cancel_requested = True
return ret

def __set_awaiter__(self, awaiter):
for child in self._children:
if hasattr(child, "__set_awaiter__"):
child.__set_awaiter__(awaiter)


def gather(*coros_or_futures, return_exceptions=False):
"""Return a future aggregating results from the given coroutines/futures.
Expand Down Expand Up @@ -956,6 +963,62 @@ def callback():
return future


def get_async_stack():
"""Return the async call stack for the currently executing task as a list of
frames, with the most recent frame last.
The async call stack consists of the call stack for the currently executing
task, if any, plus the call stack formed by the transitive set of coroutines/async
generators awaiting the current task.
Consider the following example, where T represents a task, C represents
a coroutine, and A '->' B indicates A is awaiting B.
T0 +---> T1
| | |
C0 | C2
| | |
v | v
C1 | C3
| |
+-----|
The await stack from C3 would be C3, C2, C1, C0. In contrast, the
synchronous call stack while C3 is executing is only C3, C2.
"""
if not hasattr(sys, "_getframe"):
return []

task = current_task()
coro = task.get_coro()
coro_frame = coro.cr_frame

# Get the active portion of the stack
stack = []
frame = sys._getframe().f_back
while frame is not None:
stack.append(frame)
if frame is coro_frame:
break
frame = frame.f_back
assert frame is coro_frame

# Get the suspended portion of the stack
awaiter = coro.cr_awaiter
while awaiter is not None:
if hasattr(awaiter, "cr_frame"):
stack.append(awaiter.cr_frame)
awaiter = awaiter.cr_awaiter
elif hasattr(awaiter, "ag_frame"):
stack.append(awaiter.ag_frame)
awaiter = awaiter.ag_awaiter
else:
raise ValueError(f"Unexpected awaiter {awaiter}")

stack.reverse()
return stack


# WeakSet containing all alive tasks.
_all_tasks = weakref.WeakSet()


def create_eager_task_factory(custom_task_constructor):
"""Create a function suitable for use as a task factory on an event-loop.
Expand Down
34 changes: 34 additions & 0 deletions Lib/test/test_asyncgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -1903,5 +1903,39 @@ async def run():
self.loop.run_until_complete(run())


class AsyncGeneratorAwaiterTest(unittest.TestCase):
def setUp(self):
self.loop = asyncio.new_event_loop()
asyncio.set_event_loop(None)

def tearDown(self):
self.loop.close()
self.loop = None
asyncio.set_event_loop_policy(None)

def test_basic_await(self):
async def async_gen():
self.assertIs(agen_obj.ag_awaiter, awaiter_obj)
yield 10

async def awaiter(agen):
async for x in agen:
pass

agen_obj = async_gen()
awaiter_obj = awaiter(agen_obj)
self.assertIsNone(agen_obj.ag_awaiter)
self.loop.run_until_complete(awaiter_obj)

def test_set_invalid_awaiter(self):
async def async_gen():
yield "hi"

agen_obj = async_gen()
msg = "awaiter must be None, a coroutine, or an async generator"
with self.assertRaisesRegex(TypeError, msg):
agen_obj.__set_awaiter__("testing 123")


if __name__ == "__main__":
unittest.main()
86 changes: 86 additions & 0 deletions Lib/test/test_asyncio/test_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -2489,6 +2489,24 @@ def test_get_context(self):
finally:
loop.close()

def test_get_awaiter(self):
ctask = getattr(tasks, '_CTask', None)
if ctask is None or not issubclass(self.Task, ctask):
self.skipTest("Only subclasses of _CTask set cr_awaiter on wrapped coroutines")

async def coro():
self.assertIs(coro_obj.cr_awaiter, awaiter_obj)
return "ok"

async def awaiter(coro):
task = self.loop.create_task(coro)
return await task

coro_obj = coro()
awaiter_obj = awaiter(coro_obj)
self.assertIsNone(coro_obj.cr_awaiter)
self.assertEqual(self.loop.run_until_complete(awaiter_obj), "ok")
self.assertIsNone(coro_obj.cr_awaiter)

def add_subclass_tests(cls):
BaseTask = cls.Task
Expand Down Expand Up @@ -3237,6 +3255,22 @@ async def coro(s):
# NameError should not happen:
self.one_loop.call_exception_handler.assert_not_called()

def test_propagate_awaiter(self):
async def coro(idx):
self.assertIs(coro_objs[idx].cr_awaiter, awaiter_obj)
return "ok"

async def awaiter(coros):
tasks = [self.one_loop.create_task(c) for c in coros]
return await asyncio.gather(*tasks)

coro_objs = [coro(0), coro(1)]
awaiter_obj = awaiter(coro_objs)
self.assertIsNone(coro_objs[0].cr_awaiter)
self.assertIsNone(coro_objs[1].cr_awaiter)
self.assertEqual(self.one_loop.run_until_complete(awaiter_obj), ["ok", "ok"])
self.assertIsNone(coro_objs[0].cr_awaiter)
self.assertIsNone(coro_objs[1].cr_awaiter)

class RunCoroutineThreadsafeTests(test_utils.TestCase):
"""Test case for asyncio.run_coroutine_threadsafe."""
Expand Down Expand Up @@ -3449,5 +3483,57 @@ def tearDown(self):
super().tearDown()



class GetAsyncStackTests(test_utils.TestCase):
def setUp(self):
self.loop = asyncio.new_event_loop()
asyncio.set_event_loop(None)

def tearDown(self):
self.loop.close()
self.loop = None
asyncio.set_event_loop_policy(None)

def check_stack(self, frames, expected_funcs):
given = [f.f_code for f in frames]
expected = [f.__code__ for f in expected_funcs]
self.assertEqual(given, expected)

def test_single_task(self):
async def coro():
await coro2()

async def coro2():
stack = asyncio.get_async_stack()
self.check_stack(stack, [coro, coro2])

self.loop.run_until_complete(coro())

def test_cross_tasks(self):
async def coro():
t = asyncio.ensure_future(coro2())
await t

async def coro2():
t = asyncio.ensure_future(coro3())
await t

async def coro3():
stack = asyncio.get_async_stack()
self.check_stack(stack, [coro, coro2, coro3])

self.loop.run_until_complete(coro())

def test_cross_gather(self):
async def coro():
await asyncio.gather(coro2(), coro2())

async def coro2():
stack = asyncio.get_async_stack()
self.check_stack(stack, [coro, coro2])

self.loop.run_until_complete(coro())


if __name__ == '__main__':
unittest.main()
69 changes: 69 additions & 0 deletions Lib/test/test_coroutines.py
Original file line number Diff line number Diff line change
Expand Up @@ -2474,5 +2474,74 @@ async def foo():
self.assertEqual(foo().send(None), 1)



class CoroutineAwaiterTest(unittest.TestCase):
def test_basic_await(self):
async def coro():
self.assertIs(coro_obj.cr_awaiter, awaiter_obj)
return "success"

async def awaiter():
return await coro_obj

coro_obj = coro()
awaiter_obj = awaiter()
self.assertIsNone(coro_obj.cr_awaiter)
self.assertEqual(run_async(awaiter_obj), ([], "success"))

class FakeFuture:
def __await__(self):
return iter(["future"])

def test_coro_outlives_awaiter(self):
async def coro():
await self.FakeFuture()

async def awaiter(cr):
await cr

coro_obj = coro()
self.assertIsNone(coro_obj.cr_awaiter)
awaiter_obj = awaiter(coro_obj)
self.assertIsNone(coro_obj.cr_awaiter)

v1 = awaiter_obj.send(None)
self.assertEqual(v1, "future")
self.assertIs(coro_obj.cr_awaiter, awaiter_obj)

awaiter_id = id(awaiter_obj)
del awaiter_obj
self.assertEqual(id(coro_obj.cr_awaiter), awaiter_id)

def test_async_gen_awaiter(self):
async def coro():
self.assertIs(coro_obj.cr_awaiter, agen)
await self.FakeFuture()

async def async_gen(cr):
await cr
yield "hi"

coro_obj = coro()
self.assertIsNone(coro_obj.cr_awaiter)
agen = async_gen(coro_obj)
self.assertIsNone(coro_obj.cr_awaiter)

v1 = agen.asend(None).send(None)
self.assertEqual(v1, "future")

def test_set_invalid_awaiter(self):
async def coro():
return True

coro_obj = coro()
msg = "awaiter must be None, a coroutine, or an async generator"
with self.assertRaisesRegex(TypeError, msg):
coro_obj.__set_awaiter__("testing 123")
run_async(coro_obj)




if __name__=="__main__":
unittest.main()
Loading

0 comments on commit 36a1893

Please sign in to comment.