Skip to content

Commit

Permalink
Add support for decorative partial functions
Browse files Browse the repository at this point in the history
  • Loading branch information
dafu-wu committed Nov 24, 2022
1 parent b496c55 commit ad066f5
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 1 deletion.
6 changes: 5 additions & 1 deletion src/decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
import inspect
import operator
import itertools
import functools
from contextlib import _GeneratorContextManager
from inspect import getfullargspec, iscoroutinefunction, isgeneratorfunction

Expand Down Expand Up @@ -71,7 +72,7 @@ def __init__(self, func=None, name=None, signature=None,
self.name = '_lambda_'
self.doc = func.__doc__
self.module = func.__module__
if inspect.isroutine(func):
if inspect.isroutine(func) or isinstance(func, functools.partial):
argspec = getfullargspec(func)
self.annotations = getattr(func, '__annotations__', {})
for a in ('args', 'varargs', 'varkw', 'defaults', 'kwonlyargs',
Expand Down Expand Up @@ -214,6 +215,8 @@ def decorate(func, caller, extras=(), kwsyntax=False):
does. By default kwsyntax is False and the the arguments are untouched.
"""
sig = inspect.signature(func)
if isinstance(func, functools.partial):
func = functools.update_wrapper(func, func.func)
if iscoroutinefunction(caller):
async def fun(*args, **kw):
if not kwsyntax:
Expand All @@ -230,6 +233,7 @@ def fun(*args, **kw):
if not kwsyntax:
args, kw = fix(args, kw, sig)
return caller(func, *(extras + args), **kw)

fun.__name__ = func.__name__
fun.__doc__ = func.__doc__
fun.__wrapped__ = func
Expand Down
16 changes: 16 additions & 0 deletions src/tests/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import unittest
import decimal
import inspect
import functools
from asyncio import get_event_loop
from collections import defaultdict, ChainMap, abc as c
from decorator import dispatch_on, contextmanager, decorator
Expand Down Expand Up @@ -509,5 +510,20 @@ def __len__(self):
h(u)


@decorator
def partial_before_after(func, *args, **kwargs):
return "<before>" + func(*args, **kwargs) + "<after>"


class PartialTestCase(unittest.TestCase):
def test_before_after(self):
def origin_func(x, y):
return x + y
_func = functools.partial(origin_func, "x")
partial_func = partial_before_after(_func)
out = partial_func("y")
self.assertEqual(out, '<before>xy<after>')


if __name__ == '__main__':
unittest.main()

0 comments on commit ad066f5

Please sign in to comment.