From 0680f92f23c5453eee691f0d71cdb84c787483ef Mon Sep 17 00:00:00 2001 From: Whth <88489697+Whth@users.noreply.github.com> Date: Sun, 2 Jun 2024 22:20:56 +0800 Subject: [PATCH] feat(menta): support all annotation --- src/mentabotix/modules/menta.py | 24 +++++++++++------------- tests/test_menta.py | 18 ++++++++++++++++++ 2 files changed, 29 insertions(+), 13 deletions(-) diff --git a/src/mentabotix/modules/menta.py b/src/mentabotix/modules/menta.py index 857951a..34c0dab 100644 --- a/src/mentabotix/modules/menta.py +++ b/src/mentabotix/modules/menta.py @@ -15,6 +15,7 @@ Dict, Any, Set, + Type, ) from .exceptions import BadSignatureError, RequirementError @@ -171,7 +172,7 @@ def construct_inlined_function( usages: List[SamplerUsage], judging_source: List[SourceCode] | SourceCode, extra_context: Dict[str, Any] = None, - return_type_varname: Optional[str] = None, + return_type: Optional[Type] = None, return_raw: bool = False, function_name: str = "_func", ) -> Callable[[], Any] | Tuple[str, Dict[str, Any]]: @@ -194,14 +195,6 @@ def construct_inlined_function( RequirementError: 如果返回类型变量名在extra_context中未定义,或者判断源中缺少必要的占位符。 RuntimeError: 如果遇到了不支持的采样器类型。 """ - if ( - return_type_varname - and return_type_varname not in extra_context - and return_type_varname not in BUILTIN_TYPES - ): - raise RequirementError( - f'The return_type_varname "{return_type_varname}" is not defined in extra_context: {extra_context} and is not in builtin types: {BUILTIN_TYPES}' - ) # 将judging_source统一处理为字符串格式 judging_source: str = judging_source if isinstance(judging_source, str) else "\n ".join(judging_source) @@ -265,10 +258,14 @@ def construct_inlined_function( ) _logger.debug(f"Created temp_var_source: {temp_var_source}") - # 构建完整的函数源码并编译执行 - function_head: str = ( - f"def {function_name}() -> {return_type_varname}:" if return_type_varname else f"def {function_name}():" - ) + if return_type: + + return_type_varname = f"{function_name}_return_type" + compile_context.update({return_type_varname: return_type}) + # 构建完整的函数源码并编译执行 + function_head: str = f"def {function_name}() -> {return_type_varname}:" + else: + function_head: str = f"def {function_name}():" func_source = f"{function_head}\n" f" {temp_var_source}\n" f" {judging_source}\n" f" return {RET_IDENTIFIER}" _logger.debug(f"Created func_source: {func_source}") _logger.debug("Compiling func_source") @@ -276,6 +273,7 @@ def construct_inlined_function( # 更新执行环境中的采样器和额外上下文信息 compile_context.update(extra_context) if extra_context else None + if return_raw: return func_source, compile_context exec(func_source, compile_context) # exec the source with the context diff --git a/tests/test_menta.py b/tests/test_menta.py index fd88a4c..ba1cdfd 100644 --- a/tests/test_menta.py +++ b/tests/test_menta.py @@ -268,6 +268,24 @@ def test_function_name(self): self.assertEqual(func.__name__, function_name) print(func()) + def test_func_ret_type(self): + from inspect import signature + + sages = [ + SamplerUsage(used_sampler_index=0, required_data_indexes=[0, 2]), + SamplerUsage(used_sampler_index=1, required_data_indexes=[5]), + SamplerUsage(used_sampler_index=2, required_data_indexes=[0, 1, 2]), + ] + function_name = "func_a" + func = self.menta.construct_inlined_function( + usages=sages, + judging_source="ret=s0,s1,s2,s3,s4+s5", + return_raw=False, + return_type=int, + function_name=function_name, + ) + self.assertEqual(signature(func).return_annotation, int) + def tearDown(self): # 清理可能的副作用 pass