diff --git a/libcst/metadata/__init__.py b/libcst/metadata/__init__.py index 75e382292..66e7e5251 100644 --- a/libcst/metadata/__init__.py +++ b/libcst/metadata/__init__.py @@ -5,6 +5,7 @@ from libcst._position import CodePosition, CodeRange +from libcst.metadata.accessor_provider import AccessorProvider from libcst.metadata.base_provider import ( BaseMetadataProvider, BatchableMetadataProvider, @@ -86,6 +87,7 @@ "Accesses", "TypeInferenceProvider", "FullRepoManager", + "AccessorProvider", # Experimental APIs: "ExperimentalReentrantCodegenProvider", "CodegenPartial", diff --git a/libcst/metadata/accessor_provider.py b/libcst/metadata/accessor_provider.py new file mode 100644 index 000000000..5d4f22e42 --- /dev/null +++ b/libcst/metadata/accessor_provider.py @@ -0,0 +1,19 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + + +import dataclasses + +import libcst as cst + +from libcst.metadata.base_provider import VisitorMetadataProvider + + +class AccessorProvider(VisitorMetadataProvider[str]): + def on_visit(self, node: cst.CSTNode) -> bool: + for f in dataclasses.fields(node): + child = getattr(node, f.name) + self.set_metadata(child, f.name) + return True diff --git a/libcst/metadata/tests/test_accessor_provider.py b/libcst/metadata/tests/test_accessor_provider.py new file mode 100644 index 000000000..6ccfad5ee --- /dev/null +++ b/libcst/metadata/tests/test_accessor_provider.py @@ -0,0 +1,68 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import dataclasses + +from textwrap import dedent + +import libcst as cst +from libcst.metadata import AccessorProvider, MetadataWrapper +from libcst.testing.utils import data_provider, UnitTest + + +class DependentVisitor(cst.CSTVisitor): + METADATA_DEPENDENCIES = (AccessorProvider,) + + def __init__(self, *, test: UnitTest) -> None: + self.test = test + + def on_visit(self, node: cst.CSTNode) -> bool: + for f in dataclasses.fields(node): + child = getattr(node, f.name) + if type(child) is cst.CSTNode: + accessor = self.get_metadata(AccessorProvider, child) + self.test.assertEqual(accessor, f.name) + + return True + + +class AccessorProviderTest(UnitTest): + @data_provider( + ( + ( + """ + foo = 'toplevel' + fn1(foo) + fn2(foo) + def fn_def(): + foo = 'shadow' + fn3(foo) + """, + ), + ( + """ + global_var = None + @cls_attr + class Cls(cls_attr, kwarg=cls_attr): + cls_attr = 5 + def f(): + pass + """, + ), + ( + """ + iterator = None + condition = None + [elt for target in iterator if condition] + {elt for target in iterator if condition} + {elt: target for target in iterator if condition} + (elt for target in iterator if condition) + """, + ), + ) + ) + def test_accessor_provier(self, code: str) -> None: + wrapper = MetadataWrapper(cst.parse_module(dedent(code))) + wrapper.visit(DependentVisitor(test=self))