diff --git a/akernel/akernel.py b/akernel/akernel.py index 8313fbc..9df1a17 100644 --- a/akernel/akernel.py +++ b/akernel/akernel.py @@ -16,14 +16,14 @@ def install( None, "-c", help="Path to the cache directory, if mode is 'cache'." ), ): - name = "akernel" + kernel_name = "akernel" if mode: modes = mode.split("-") modes.sort() mode = "-".join(modes) - name += f"-{mode}" - display_name = f"Python 3 ({name})" - write_kernelspec(name, mode, display_name, cache_dir) + kernel_name += f"-{mode}" + display_name = f"Python 3 ({kernel_name})" + write_kernelspec(kernel_name, mode, display_name, cache_dir) @cli.command() diff --git a/akernel/cache.py b/akernel/cache.py index 6be5ee2..3d5361b 100644 --- a/akernel/cache.py +++ b/akernel/cache.py @@ -1,13 +1,14 @@ +from __future__ import annotations + import os import pickle import sys import zlib -from typing import Optional from zict import File, Func, LRU # type: ignore -def cache(cache_dir: Optional[str]): +def cache(cache_dir: str | None): if not cache_dir: cache_dir = os.path.join( sys.prefix, "share", "jupyter", "kernels", "akernel", "cache" diff --git a/akernel/code.py b/akernel/code.py index 80dec3c..43b11fd 100644 --- a/akernel/code.py +++ b/akernel/code.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import copy import gast # type: ignore @@ -90,8 +92,9 @@ def get_return_body(val): class Transform: - def __init__(self, code: str, react: bool = False) -> None: + def __init__(self, code: str, task_i: int | None = None, react: bool = False) -> None: self.gtree = gast.parse(code) + self.task_i = task_i self.react = react c = GlobalUseCollector() c.visit(self.gtree) @@ -115,9 +118,10 @@ def get_async_ast(self) -> gast.Module: new_body += self.gtree.body + body_globals_update_locals + last_statement else: new_body += self.gtree.body + body_globals_update_locals + name = "__async_cell__" if self.task_i is None else f"__async_cell{self.task_i}__" body = [ gast.AsyncFunctionDef( - name="__async_cell__", + name=name, args=gast.arguments( args=[], posonlyargs=[], diff --git a/akernel/execution.py b/akernel/execution.py index bb01090..bea45eb 100644 --- a/akernel/execution.py +++ b/akernel/execution.py @@ -1,6 +1,8 @@ +from __future__ import annotations + import hashlib import pickle -from typing import List, Dict, Tuple, Any, Optional +from typing import List, Dict, Tuple, Any from colorama import Fore, Style # type: ignore @@ -12,16 +14,17 @@ def pre_execute( code: str, globals_: Dict[str, Any], locals_: Dict[str, Any], + task_i: int | None = None, execution_count: int = 0, react: bool = False, - cache: Optional[Dict[str, Any]] = None, -) -> Tuple[List[str], Optional[SyntaxError], Dict[str, Any]]: + cache: Dict[str, Any] | None = None, +) -> Tuple[List[str], SyntaxError | None, Dict[str, Any]]: traceback = [] exception = None cache_info: Dict[str, Any] = {"cached": False} try: - transform = Transform(code, react) + transform = Transform(code, task_i, react) async_bytecode = transform.get_async_bytecode() exec(async_bytecode, globals_, locals_) except SyntaxError as e: @@ -100,7 +103,7 @@ def pre_execute( def cache_execution( - cache: Optional[Dict[str, Any]], + cache: Dict[str, Any] | None, cache_info: Dict[str, Any], globals_: Dict[str, Any], result: Any, @@ -132,9 +135,8 @@ async def execute( code: str, globals_: Dict[str, Any], locals_: Dict[str, Any], - chain: bool = False, react: bool = False, - cache: Optional[Dict[str, Any]] = None, + cache: Dict[str, Any] | None = None, ) -> Tuple[Any, List[str], bool]: result = None interrupted = False @@ -147,8 +149,6 @@ async def execute( if cache_info["cached"]: result = cache_info["result"] else: - if chain: - await locals_["__task__"]() try: result = await locals_["__async_cell__"]() except KeyboardInterrupt: diff --git a/akernel/kernel.py b/akernel/kernel.py index ee24b03..e6db2cf 100644 --- a/akernel/kernel.py +++ b/akernel/kernel.py @@ -4,7 +4,7 @@ import json from io import StringIO from contextvars import ContextVar -from typing import Dict, Any, List, Optional, Union, Awaitable, cast +from typing import Dict, Any, List, Union, Awaitable, cast from zmq.asyncio import Socket import comm # type: ignore @@ -41,22 +41,23 @@ class Kernel: key: str comm_manager: CommManager kernel_mode: str + cell_done: Dict[int, asyncio.Event] running_cells: Dict[int, asyncio.Task] task_i: int execution_count: int execution_state: str globals: Dict[str, Dict[str, Any]] locals: Dict[str, Dict[str, Any]] - _multi_kernel: Optional[bool] - _cache_kernel: Optional[bool] - _react_kernel: Optional[bool] + _multi_kernel: bool | None + _cache_kernel: bool | None + _react_kernel: bool | None kernel_initialized: List[str] - cache: Optional[Dict[str, Any]] + cache: Dict[str, Any] | None def __init__( self, kernel_mode: str, - cache_dir: Optional[str], + cache_dir: str | None, connection_file: str, ): global KERNEL @@ -78,6 +79,7 @@ def get_comm_manager(): self.globals = {} self.locals = {} self._chain_execution = not self.concurrent_kernel + self.cell_done = {} self.running_cells = {} self.task_i = 0 self.execution_count = 1 @@ -258,6 +260,7 @@ async def listen_shell(self) -> None: code, self.globals[namespace], self.locals[namespace], + self.task_i, self.execution_count, react=self.react_kernel, cache=self.cache, @@ -289,6 +292,7 @@ async def listen_shell(self) -> None: cache_info, ) ) + self.cell_done[self.task_i] = asyncio.Event() self.running_cells[self.task_i] = task self.task_i += 1 self.execution_count += 1 @@ -353,15 +357,17 @@ async def execute_and_finish( code: str, cache_info: Dict[str, Any], ) -> None: - if self._chain_execution: - await self.task() + prev_task_i = task_i - 1 + if self._chain_execution and prev_task_i in self.cell_done: + await self.cell_done[prev_task_i].wait() + del self.cell_done[prev_task_i] PARENT_VAR.set(parent) IDENTS_VAR.set(idents) parent_header = parent["header"] traceback, exception = [], None namespace = self.get_namespace(parent_header) try: - result = await self.locals[namespace]["__async_cell__"]() + result = await self.locals[namespace][f"__async_cell{task_i}__"]() except KeyboardInterrupt: self.interrupt() except Exception as e: @@ -371,6 +377,8 @@ async def execute_and_finish( self.show_result(result, self.globals[namespace], parent_header) cache_execution(self.cache, cache_info, self.globals[namespace], result) finally: + self.cell_done[task_i].set() + del self.locals[namespace][f"__async_cell{task_i}__"] self.finish_execution( idents, parent_header, @@ -385,8 +393,8 @@ def finish_execution( self, idents: List[bytes], parent_header: Dict[str, Any], - execution_count: Optional[int], - exception: Optional[Exception] = None, + execution_count: int | None, + exception: Exception | None = None, no_exec: bool = False, traceback: List[str] = [], result=None, diff --git a/akernel/kernelspec.py b/akernel/kernelspec.py index b245a45..c0a8644 100644 --- a/akernel/kernelspec.py +++ b/akernel/kernelspec.py @@ -1,11 +1,12 @@ +from __future__ import annotations + import os import sys import json -from typing import Optional def write_kernelspec( - dir_name: str, mode: str, display_name: str, cache_dir: Optional[str] + dir_name: str, mode: str, display_name: str, cache_dir: str | None ) -> None: argv = ["akernel", "launch"] if mode: diff --git a/akernel/message.py b/akernel/message.py index d9d313e..730103a 100644 --- a/akernel/message.py +++ b/akernel/message.py @@ -2,7 +2,7 @@ import hmac import hashlib from datetime import datetime, timezone -from typing import List, Dict, Tuple, Any, Optional, cast +from typing import List, Dict, Tuple, Any, cast from zmq.utils import jsonapi from zmq.asyncio import Socket @@ -86,7 +86,7 @@ def serialize(msg: Dict[str, Any], key: str, address: bytes = b"") -> List[bytes async def receive_message( sock: Socket, timeout: float = float("inf") -) -> Optional[Tuple[List[bytes], Dict[str, Any]]]: +) -> Tuple[List[bytes], Dict[str, Any]] | None: timeout *= 1000 # in ms ready = await sock.poll(timeout) if ready: @@ -101,7 +101,7 @@ def send_message( sock: Socket, key: str, address: bytes = b"", - buffers: Optional[List] = None, + buffers: List | None = None, ) -> None: to_send = serialize(msg, key, address) buffers = buffers or [] diff --git a/akernel/tests/conftest.py b/akernel/tests/conftest.py index fc53e83..cc0447b 100644 --- a/akernel/tests/conftest.py +++ b/akernel/tests/conftest.py @@ -12,6 +12,9 @@ def all_modes(request): mode = request.param kernel_name = "akernel" if mode: + modes = mode.split("-") + modes.sort() + mode = "-".join(modes) kernel_name += f"-{mode}" if mode == "cache": cache_dir = os.path.join( @@ -19,4 +22,4 @@ def all_modes(request): ) shutil.rmtree(cache_dir, ignore_errors=True) display_name = f"Python 3 ({kernel_name})" - write_kernelspec("akernel", mode, display_name, None) + write_kernelspec(kernel_name, mode, display_name, None) diff --git a/akernel/tests/test_execution.py b/akernel/tests/test_execution.py index b0e4a5a..9bd81c4 100644 --- a/akernel/tests/test_execution.py +++ b/akernel/tests/test_execution.py @@ -1,9 +1,11 @@ +from __future__ import annotations + import sys import time from textwrap import dedent import re from math import sin -from typing import List, Dict, Tuple, Any, Optional +from typing import List, Dict, Tuple, Any import pytest @@ -12,16 +14,15 @@ async def run( code: str, - globals_: Optional[Dict[str, Any]] = None, - chain: bool = False, + globals_: Dict[str, Any] | None = None, react: bool = False, - cache: Optional[Dict[str, Any]] = None, + cache: Dict[str, Any] | None = None, ) -> Tuple[Any, List[str], bool, Dict[str, Any], Dict[str, Any]]: if globals_ is None: globals_ = {} locals_: Dict[str, Any] = {} result, traceback, interrupted = await execute( - code, globals_, locals_, chain=chain, react=react, cache=cache + code, globals_, locals_, react=react, cache=cache ) if "__builtins__" in globals_: del globals_["__builtins__"] diff --git a/akernel/traceback.py b/akernel/traceback.py index e4e6791..baa1b89 100644 --- a/akernel/traceback.py +++ b/akernel/traceback.py @@ -1,6 +1,8 @@ +from __future__ import annotations + import sys import types -from typing import Optional, cast +from typing import cast from colorama import Fore, Style # type: ignore @@ -13,7 +15,7 @@ def get_traceback(code: str, exception, execution_count: int = 0): break tb = tb.tb_next stack = [] - frame: Optional[types.FrameType] = tb.tb_frame + frame: types.FrameType | None = tb.tb_frame while True: assert frame is not None stack.append(frame) @@ -31,7 +33,7 @@ def get_traceback(code: str, exception, execution_count: int = 0): with open(filename) as f: code = f.read() filename = f"{Fore.CYAN}File{Style.RESET_ALL} {Fore.GREEN}{filename}{Style.RESET_ALL}" - if frame.f_code.co_name == "__async_cell__": + if frame.f_code.co_name.startswith("__async_cell"): name = "" else: name = frame.f_code.co_name