Skip to content

Commit

Permalink
feat(menta): support all annotation
Browse files Browse the repository at this point in the history
  • Loading branch information
Whth committed Jun 2, 2024
1 parent 052d3be commit 0680f92
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 13 deletions.
24 changes: 11 additions & 13 deletions src/mentabotix/modules/menta.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
Dict,
Any,
Set,
Type,
)

from .exceptions import BadSignatureError, RequirementError
Expand Down Expand Up @@ -171,7 +172,7 @@ def construct_inlined_function(
usages: List[SamplerUsage],
judging_source: List[SourceCode] | SourceCode,
extra_context: Dict[str, Any] = None,
return_type_varname: Optional[str] = None,
return_type: Optional[Type] = None,
return_raw: bool = False,
function_name: str = "_func",
) -> Callable[[], Any] | Tuple[str, Dict[str, Any]]:
Expand All @@ -194,14 +195,6 @@ def construct_inlined_function(
RequirementError: 如果返回类型变量名在extra_context中未定义,或者判断源中缺少必要的占位符。
RuntimeError: 如果遇到了不支持的采样器类型。
"""
if (
return_type_varname
and return_type_varname not in extra_context
and return_type_varname not in BUILTIN_TYPES
):
raise RequirementError(
f'The return_type_varname "{return_type_varname}" is not defined in extra_context: {extra_context} and is not in builtin types: {BUILTIN_TYPES}'
)

# 将judging_source统一处理为字符串格式
judging_source: str = judging_source if isinstance(judging_source, str) else "\n ".join(judging_source)
Expand Down Expand Up @@ -265,17 +258,22 @@ def construct_inlined_function(
)
_logger.debug(f"Created temp_var_source: {temp_var_source}")

# 构建完整的函数源码并编译执行
function_head: str = (
f"def {function_name}() -> {return_type_varname}:" if return_type_varname else f"def {function_name}():"
)
if return_type:

return_type_varname = f"{function_name}_return_type"
compile_context.update({return_type_varname: return_type})
# 构建完整的函数源码并编译执行
function_head: str = f"def {function_name}() -> {return_type_varname}:"
else:
function_head: str = f"def {function_name}():"
func_source = f"{function_head}\n" f" {temp_var_source}\n" f" {judging_source}\n" f" return {RET_IDENTIFIER}"
_logger.debug(f"Created func_source: {func_source}")
_logger.debug("Compiling func_source")

# 更新执行环境中的采样器和额外上下文信息

compile_context.update(extra_context) if extra_context else None

if return_raw:
return func_source, compile_context
exec(func_source, compile_context) # exec the source with the context
Expand Down
18 changes: 18 additions & 0 deletions tests/test_menta.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,24 @@ def test_function_name(self):
self.assertEqual(func.__name__, function_name)
print(func())

def test_func_ret_type(self):
from inspect import signature

sages = [
SamplerUsage(used_sampler_index=0, required_data_indexes=[0, 2]),
SamplerUsage(used_sampler_index=1, required_data_indexes=[5]),
SamplerUsage(used_sampler_index=2, required_data_indexes=[0, 1, 2]),
]
function_name = "func_a"
func = self.menta.construct_inlined_function(
usages=sages,
judging_source="ret=s0,s1,s2,s3,s4+s5",
return_raw=False,
return_type=int,
function_name=function_name,
)
self.assertEqual(signature(func).return_annotation, int)

def tearDown(self):
# 清理可能的副作用
pass
Expand Down

0 comments on commit 0680f92

Please sign in to comment.