Skip to content

Commit

Permalink
stubber: Refactor merge Annotation storage.
Browse files Browse the repository at this point in the history
Signed-off-by: Jos Verlinde <Jos.Verlinde@microsoft.com>
  • Loading branch information
Josverl committed Aug 21, 2024
1 parent c8f7f69 commit 3a38d4d
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 40 deletions.
36 changes: 14 additions & 22 deletions src/stubber/codemod/merge_docstub.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
MODULE_KEY,
StubTypingCollector,
TypeInfo,
AnnoValue,
update_def_docstr,
update_module_docstr,
)
Expand Down Expand Up @@ -72,7 +73,8 @@ def __init__(self, context: CodemodContext, docstub_file: Union[Path, str]) -> N
# store the annotations
self.annotations: Dict[
Tuple[str, ...], # key: tuple of canonical class/function name
Union[TypeInfo, str, List[TypeInfo]], # value: TypeInfo
AnnoValue,
# Union[TypeInfo, str, List[TypeInfo]], # value: TypeInfo
] = {}
self.comments: List[str] = []

Expand Down Expand Up @@ -150,7 +152,7 @@ def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> c

# update/replace module docstrings
# todo: or should we add / merge the docstrings?
docstub_docstr = self.annotations[MODULE_KEY]
docstub_docstr = self.annotations[MODULE_KEY].docstring
assert isinstance(docstub_docstr, str)
src_docstr = original_node.get_docstring() or ""
if src_docstr or docstub_docstr:
Expand All @@ -176,11 +178,11 @@ def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> c

# hack to 2nd foo annotation
# updated_node = updated_node.with_changes( children=updated_node.children.append(self.annotations[('foo',)][1]))

# Insert the new function at the end of the module
new_function = self.annotations[("foo",)][1].def_node
modified_body = tuple(list(updated_node.body) + [new_function])
updated_node = updated_node.with_changes(body=modified_body)
if "overload" in docstub_docstr.lower():
# Insert the new function at the end of the module
new_function = self.annotations[("foo",)].overloads[-1].def_node
modified_body = tuple(list(updated_node.body) + [new_function])
updated_node = updated_node.with_changes(body=modified_body)

return updated_node
# --------------------------------------------------------------------
Expand All @@ -200,12 +202,11 @@ def leave_ClassDef(
# no changes to the class
return updated_node
# update the firmware_stub from the doc_stub information
doc_stub = self.annotations[stack_id]
assert not isinstance(doc_stub, str)
doc_stub = self.annotations[stack_id].type_info
# first update the docstring
updated_node = update_def_docstr(updated_node, doc_stub.docstr_node)
# Sometimes the MCU stubs and the doc stubs have different types : FunctionDef / ClassDef
# we need to be carefull not to copy over all the annotations if the types are different
# we need to be careful not to copy over all the annotations if the types are different
if doc_stub.def_type == "classdef":
# Same type, we can copy over all the annotations
# combine the decorators from the doc-stub and the firmware stub
Expand All @@ -224,13 +225,6 @@ def leave_ClassDef(
# for now just return the updated node
return updated_node

# ------------------------------------------------------------------------
# def visit_Iterable(self, node) -> Optional[bool]:
# return True

# def leave_Iterable(self, node) -> Optional[bool]:
# return True

# ------------------------------------------------------------------------
def visit_FunctionDef(self, node: cst.FunctionDef) -> Optional[bool]:
self.stack.append(node.name.value)
Expand All @@ -246,11 +240,9 @@ def leave_FunctionDef(
# no changes to the function
return updated_node
# update the firmware_stub from the doc_stub information
if isinstance(self.annotations[stack_id], List) and self.annotations[stack_id]:
doc_stub = self.annotations[stack_id][0]
else:
doc_stub = self.annotations[stack_id]
doc_stub = self.annotations[stack_id].type_info
assert isinstance(doc_stub, TypeInfo)
assert doc_stub
# first update the docstring
updated_node = update_def_docstr(updated_node, doc_stub.docstr_node, doc_stub.def_node)
# Sometimes the MCU stubs and the doc stubs have different types : FunctionDef / ClassDef
Expand Down Expand Up @@ -293,7 +285,7 @@ def leave_FunctionDef(
elif doc_stub.def_type == "classdef":
# Different type: ClassDef != FuncDef ,
if doc_stub.def_node and self.replace_functiondef_with_classdef:
# replace the functiondef with the classdef from the stub file
# replace the functionDef with the classdef from the stub file
return doc_stub.def_node
# for now just return the updated node
return updated_node
Expand Down
40 changes: 22 additions & 18 deletions src/stubber/cst_transformer.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,30 @@
"""helper functions for stub transformations"""

# sourcery skip: snake-case-functions
from dataclasses import dataclass
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union

import libcst as cst


@dataclass
class TypeInfo:
"contains the functiondefs and classdefs info read from the stubs source"
"contains the functionDefs and classDefs info read from the stubs source"
name: str
decorators: Sequence[cst.Decorator]
params: Optional[cst.Parameters] = None
returns: Optional[cst.Annotation] = None
docstr_node: Optional[cst.SimpleStatementLine] = None
def_node: Optional[Union[cst.FunctionDef, cst.ClassDef]] = None
def_type: str = "?" # funcdef or classdef or module
def_type: str = "?" # funcDef or classDef or module


@dataclass
class AnnoValue:
"The different values for the annotations"
docstring: Optional[str] = "" # strings
type_info: Optional[TypeInfo] = None # simple type
overloads: List[TypeInfo] = field(default_factory=list) # store function / method overloads


class TransformError(Exception):
Expand All @@ -27,7 +35,7 @@ class TransformError(Exception):


MODULE_KEY = ("__module",)
MODDOC_KEY = ("__module_docstring",)
MOD_DOCSTR_KEY = ("__module_docstring",)

# debug helper
_m = cst.parse_module("")
Expand All @@ -44,10 +52,7 @@ def __init__(self):
# store the annotations
self.annotations: Dict[
Tuple[str, ...], # key: tuple of canonical class/function name
Union[TypeInfo,
str,
List[TypeInfo], # list of overloads
],
AnnoValue, # The TypeInfo or list of TypeInfo
] = {}
self.comments: List[str] = []

Expand All @@ -56,7 +61,7 @@ def visit_Module(self, node: cst.Module) -> bool:
"""Store the module docstring"""
docstr = node.get_docstring()
if docstr:
self.annotations[MODULE_KEY] = docstr
self.annotations[MODULE_KEY] = AnnoValue(docstring=docstr)
return True

def visit_Comment(self, node: cst.Comment) -> None:
Expand Down Expand Up @@ -86,7 +91,7 @@ def visit_ClassDef(self, node: cst.ClassDef) -> Optional[bool]:
def_type="classdef",
def_node=node,
)
self.annotations[tuple(self.stack)] = ti
self.annotations[tuple(self.stack)] = AnnoValue(type_info=ti)

def leave_ClassDef(self, original_node: cst.ClassDef) -> None:
"""remove the class name from the stack"""
Expand All @@ -110,14 +115,13 @@ def visit_FunctionDef(self, node: cst.FunctionDef) -> Optional[bool]:
def_node=node,
)
key = tuple(self.stack)
if key not in self.annotations:
self.annotations[key] = []
assert isinstance(self.annotations[key], list)
self.annotations[key].append(ti)
# if node.decorators[0].decorator.value == "overload":
# # overload functions are not stored in the annotations
# return False
# self.annotations[tuple(self.stack)] = ti
if not key in self.annotations:
# store the first function/method signature
self.annotations[key] = AnnoValue(type_info=ti)

if len(node.decorators) > 0 and node.decorators[0].decorator.value == "overload": # type: ignore
# and store the overloads
self.annotations[key].overloads.append(ti)

def update_append_first_node(self, node):
"""Store the function/method docstring or function/method sig"""
Expand Down

0 comments on commit 3a38d4d

Please sign in to comment.