Skip to content

Commit

Permalink
Fix chained execution mode
Browse files Browse the repository at this point in the history
  • Loading branch information
davidbrochart committed Jul 13, 2023
1 parent 5e0e68f commit d267e3f
Show file tree
Hide file tree
Showing 10 changed files with 62 additions and 42 deletions.
8 changes: 4 additions & 4 deletions akernel/akernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,14 @@ def install(
None, "-c", help="Path to the cache directory, if mode is 'cache'."
),
):
name = "akernel"
kernel_name = "akernel"
if mode:
modes = mode.split("-")
modes.sort()
mode = "-".join(modes)
name += f"-{mode}"
display_name = f"Python 3 ({name})"
write_kernelspec(name, mode, display_name, cache_dir)
kernel_name += f"-{mode}"
display_name = f"Python 3 ({kernel_name})"
write_kernelspec(kernel_name, mode, display_name, cache_dir)


@cli.command()
Expand Down
5 changes: 3 additions & 2 deletions akernel/cache.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
from __future__ import annotations

import os
import pickle
import sys
import zlib
from typing import Optional

from zict import File, Func, LRU # type: ignore


def cache(cache_dir: Optional[str]):
def cache(cache_dir: str | None):
if not cache_dir:
cache_dir = os.path.join(
sys.prefix, "share", "jupyter", "kernels", "akernel", "cache"
Expand Down
8 changes: 6 additions & 2 deletions akernel/code.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import copy

import gast # type: ignore
Expand Down Expand Up @@ -90,8 +92,9 @@ def get_return_body(val):


class Transform:
def __init__(self, code: str, react: bool = False) -> None:
def __init__(self, code: str, task_i: int | None = None, react: bool = False) -> None:
self.gtree = gast.parse(code)
self.task_i = task_i
self.react = react
c = GlobalUseCollector()
c.visit(self.gtree)
Expand All @@ -115,9 +118,10 @@ def get_async_ast(self) -> gast.Module:
new_body += self.gtree.body + body_globals_update_locals + last_statement
else:
new_body += self.gtree.body + body_globals_update_locals
name = "__async_cell__" if self.task_i is None else f"__async_cell{self.task_i}__"
body = [
gast.AsyncFunctionDef(
name="__async_cell__",
name=name,
args=gast.arguments(
args=[],
posonlyargs=[],
Expand Down
18 changes: 9 additions & 9 deletions akernel/execution.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from __future__ import annotations

import hashlib
import pickle
from typing import List, Dict, Tuple, Any, Optional
from typing import List, Dict, Tuple, Any

from colorama import Fore, Style # type: ignore

Expand All @@ -12,16 +14,17 @@ def pre_execute(
code: str,
globals_: Dict[str, Any],
locals_: Dict[str, Any],
task_i: int | None = None,
execution_count: int = 0,
react: bool = False,
cache: Optional[Dict[str, Any]] = None,
) -> Tuple[List[str], Optional[SyntaxError], Dict[str, Any]]:
cache: Dict[str, Any] | None = None,
) -> Tuple[List[str], SyntaxError | None, Dict[str, Any]]:
traceback = []
exception = None
cache_info: Dict[str, Any] = {"cached": False}

try:
transform = Transform(code, react)
transform = Transform(code, task_i, react)
async_bytecode = transform.get_async_bytecode()
exec(async_bytecode, globals_, locals_)
except SyntaxError as e:
Expand Down Expand Up @@ -100,7 +103,7 @@ def pre_execute(


def cache_execution(
cache: Optional[Dict[str, Any]],
cache: Dict[str, Any] | None,
cache_info: Dict[str, Any],
globals_: Dict[str, Any],
result: Any,
Expand Down Expand Up @@ -132,9 +135,8 @@ async def execute(
code: str,
globals_: Dict[str, Any],
locals_: Dict[str, Any],
chain: bool = False,
react: bool = False,
cache: Optional[Dict[str, Any]] = None,
cache: Dict[str, Any] | None = None,
) -> Tuple[Any, List[str], bool]:
result = None
interrupted = False
Expand All @@ -147,8 +149,6 @@ async def execute(
if cache_info["cached"]:
result = cache_info["result"]
else:
if chain:
await locals_["__task__"]()
try:
result = await locals_["__async_cell__"]()
except KeyboardInterrupt:
Expand Down
30 changes: 19 additions & 11 deletions akernel/kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import json
from io import StringIO
from contextvars import ContextVar
from typing import Dict, Any, List, Optional, Union, Awaitable, cast
from typing import Dict, Any, List, Union, Awaitable, cast

from zmq.asyncio import Socket
import comm # type: ignore
Expand Down Expand Up @@ -41,22 +41,23 @@ class Kernel:
key: str
comm_manager: CommManager
kernel_mode: str
cell_done: Dict[int, asyncio.Event]
running_cells: Dict[int, asyncio.Task]
task_i: int
execution_count: int
execution_state: str
globals: Dict[str, Dict[str, Any]]
locals: Dict[str, Dict[str, Any]]
_multi_kernel: Optional[bool]
_cache_kernel: Optional[bool]
_react_kernel: Optional[bool]
_multi_kernel: bool | None
_cache_kernel: bool | None
_react_kernel: bool | None
kernel_initialized: List[str]
cache: Optional[Dict[str, Any]]
cache: Dict[str, Any] | None

def __init__(
self,
kernel_mode: str,
cache_dir: Optional[str],
cache_dir: str | None,
connection_file: str,
):
global KERNEL
Expand All @@ -78,6 +79,7 @@ def get_comm_manager():
self.globals = {}
self.locals = {}
self._chain_execution = not self.concurrent_kernel
self.cell_done = {}
self.running_cells = {}
self.task_i = 0
self.execution_count = 1
Expand Down Expand Up @@ -258,6 +260,7 @@ async def listen_shell(self) -> None:
code,
self.globals[namespace],
self.locals[namespace],
self.task_i,
self.execution_count,
react=self.react_kernel,
cache=self.cache,
Expand Down Expand Up @@ -289,6 +292,7 @@ async def listen_shell(self) -> None:
cache_info,
)
)
self.cell_done[self.task_i] = asyncio.Event()
self.running_cells[self.task_i] = task
self.task_i += 1
self.execution_count += 1
Expand Down Expand Up @@ -353,15 +357,17 @@ async def execute_and_finish(
code: str,
cache_info: Dict[str, Any],
) -> None:
if self._chain_execution:
await self.task()
prev_task_i = task_i - 1
if self._chain_execution and prev_task_i in self.cell_done:
await self.cell_done[prev_task_i].wait()
del self.cell_done[prev_task_i]
PARENT_VAR.set(parent)
IDENTS_VAR.set(idents)
parent_header = parent["header"]
traceback, exception = [], None
namespace = self.get_namespace(parent_header)
try:
result = await self.locals[namespace]["__async_cell__"]()
result = await self.locals[namespace][f"__async_cell{task_i}__"]()
except KeyboardInterrupt:
self.interrupt()
except Exception as e:
Expand All @@ -371,6 +377,8 @@ async def execute_and_finish(
self.show_result(result, self.globals[namespace], parent_header)
cache_execution(self.cache, cache_info, self.globals[namespace], result)
finally:
self.cell_done[task_i].set()
del self.locals[namespace][f"__async_cell{task_i}__"]
self.finish_execution(
idents,
parent_header,
Expand All @@ -385,8 +393,8 @@ def finish_execution(
self,
idents: List[bytes],
parent_header: Dict[str, Any],
execution_count: Optional[int],
exception: Optional[Exception] = None,
execution_count: int | None,
exception: Exception | None = None,
no_exec: bool = False,
traceback: List[str] = [],
result=None,
Expand Down
5 changes: 3 additions & 2 deletions akernel/kernelspec.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from __future__ import annotations

import os
import sys
import json
from typing import Optional


def write_kernelspec(
dir_name: str, mode: str, display_name: str, cache_dir: Optional[str]
dir_name: str, mode: str, display_name: str, cache_dir: str | None
) -> None:
argv = ["akernel", "launch"]
if mode:
Expand Down
6 changes: 3 additions & 3 deletions akernel/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import hmac
import hashlib
from datetime import datetime, timezone
from typing import List, Dict, Tuple, Any, Optional, cast
from typing import List, Dict, Tuple, Any, cast

from zmq.utils import jsonapi
from zmq.asyncio import Socket
Expand Down Expand Up @@ -86,7 +86,7 @@ def serialize(msg: Dict[str, Any], key: str, address: bytes = b"") -> List[bytes

async def receive_message(
sock: Socket, timeout: float = float("inf")
) -> Optional[Tuple[List[bytes], Dict[str, Any]]]:
) -> Tuple[List[bytes], Dict[str, Any]] | None:
timeout *= 1000 # in ms
ready = await sock.poll(timeout)
if ready:
Expand All @@ -101,7 +101,7 @@ def send_message(
sock: Socket,
key: str,
address: bytes = b"",
buffers: Optional[List] = None,
buffers: List | None = None,
) -> None:
to_send = serialize(msg, key, address)
buffers = buffers or []
Expand Down
5 changes: 4 additions & 1 deletion akernel/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,14 @@ def all_modes(request):
mode = request.param
kernel_name = "akernel"
if mode:
modes = mode.split("-")
modes.sort()
mode = "-".join(modes)
kernel_name += f"-{mode}"
if mode == "cache":
cache_dir = os.path.join(
sys.prefix, "share", "jupyter", "kernels", "akernel", "cache"
)
shutil.rmtree(cache_dir, ignore_errors=True)
display_name = f"Python 3 ({kernel_name})"
write_kernelspec("akernel", mode, display_name, None)
write_kernelspec(kernel_name, mode, display_name, None)
11 changes: 6 additions & 5 deletions akernel/tests/test_execution.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from __future__ import annotations

import sys
import time
from textwrap import dedent
import re
from math import sin
from typing import List, Dict, Tuple, Any, Optional
from typing import List, Dict, Tuple, Any

import pytest

Expand All @@ -12,16 +14,15 @@

async def run(
code: str,
globals_: Optional[Dict[str, Any]] = None,
chain: bool = False,
globals_: Dict[str, Any] | None = None,
react: bool = False,
cache: Optional[Dict[str, Any]] = None,
cache: Dict[str, Any] | None = None,
) -> Tuple[Any, List[str], bool, Dict[str, Any], Dict[str, Any]]:
if globals_ is None:
globals_ = {}
locals_: Dict[str, Any] = {}
result, traceback, interrupted = await execute(
code, globals_, locals_, chain=chain, react=react, cache=cache
code, globals_, locals_, react=react, cache=cache
)
if "__builtins__" in globals_:
del globals_["__builtins__"]
Expand Down
8 changes: 5 additions & 3 deletions akernel/traceback.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from __future__ import annotations

import sys
import types
from typing import Optional, cast
from typing import cast

from colorama import Fore, Style # type: ignore

Expand All @@ -13,7 +15,7 @@ def get_traceback(code: str, exception, execution_count: int = 0):
break
tb = tb.tb_next
stack = []
frame: Optional[types.FrameType] = tb.tb_frame
frame: types.FrameType | None = tb.tb_frame
while True:
assert frame is not None
stack.append(frame)
Expand All @@ -31,7 +33,7 @@ def get_traceback(code: str, exception, execution_count: int = 0):
with open(filename) as f:
code = f.read()
filename = f"{Fore.CYAN}File{Style.RESET_ALL} {Fore.GREEN}{filename}{Style.RESET_ALL}"
if frame.f_code.co_name == "__async_cell__":
if frame.f_code.co_name.startswith("__async_cell"):
name = "<module>"
else:
name = frame.f_code.co_name
Expand Down

0 comments on commit d267e3f

Please sign in to comment.