Skip to content

Commit

Permalink
gazelle: Populate plugins attributes with annotation processors (#276)
Browse files Browse the repository at this point in the history
  • Loading branch information
illicitonion committed May 10, 2024
1 parent 01812dd commit f6b2319
Show file tree
Hide file tree
Showing 21 changed files with 644 additions and 98 deletions.
17 changes: 17 additions & 0 deletions java/gazelle/configure.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"github.com/bazel-contrib/rules_jvm/java/gazelle/javaconfig"
"github.com/bazel-contrib/rules_jvm/java/gazelle/private/javaparser"
"github.com/bazel-contrib/rules_jvm/java/gazelle/private/maven"
"github.com/bazel-contrib/rules_jvm/java/gazelle/private/types"
"github.com/bazelbuild/bazel-gazelle/config"
"github.com/bazelbuild/bazel-gazelle/rule"
bzl "github.com/bazelbuild/buildtools/build"
Expand Down Expand Up @@ -64,6 +65,7 @@ func (jc *Configurer) KnownDirectives() []string {
javaconfig.JavaTestMode,
javaconfig.JavaGenerateProto,
javaconfig.JavaMavenRepositoryName,
javaconfig.JavaAnnotationProcessorPlugin,
}
}

Expand Down Expand Up @@ -129,6 +131,21 @@ func (jc *Configurer) Configure(c *config.Config, rel string, f *rule.File) {
jc.lang.logger.Fatal().Msgf("invalid value for directive %q: %s: possible values are true/false",
javaconfig.JavaGenerateProto, d.Value)
}
case javaconfig.JavaAnnotationProcessorPlugin:
// Format: # gazelle:java_annotation_processor_plugin com.example.AnnotationName com.example.AnnotationProcessorImpl
parts := strings.Split(d.Value, " ")
if len(parts) != 2 {
jc.lang.logger.Fatal().Msgf("invalid value for directive %q: %s: expected an annotation class-name followed by a processor class-name", javaconfig.JavaAnnotationProcessorPlugin, d.Value)
}
annotationClassName, err := types.ParseClassName(parts[0])
if err != nil {
jc.lang.logger.Fatal().Msgf("invalid value for directive %q: %q: couldn't parse annotation processor annotation class-name: %v", javaconfig.JavaAnnotationProcessorPlugin, parts[0], err)
}
processorClassName, err := types.ParseClassName(parts[1])
if err != nil {
jc.lang.logger.Fatal().Msgf("invalid value for directive %q: %q: couldn't parse annotation processor class-name: %v", javaconfig.JavaAnnotationProcessorPlugin, parts[1], err)
}
cfg.AddAnnotationProcessorPlugin(*annotationClassName, *processorClassName)
}
}
}
Expand Down
32 changes: 22 additions & 10 deletions java/gazelle/generate.go
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,8 @@ func (l javaLang) GenerateRules(args language.GenerateArgs) language.GenerateRes
// All java packages present in this bazel package.
allPackageNames := sorted_set.NewSortedSetFn([]types.PackageName{}, types.PackageNameLess)

annotationProcessorClasses := sorted_set.NewSortedSetFn(nil, types.ClassNameLess)

if isModule {
for mRel, mJavaPkg := range l.javaPackageCache {
if !strings.HasPrefix(mRel, args.Rel) {
Expand All @@ -152,6 +154,9 @@ func (l javaLang) GenerateRules(args language.GenerateArgs) language.GenerateRes
accumulateJavaFile(cfg, testJavaFiles, testHelperJavaFiles, separateTestJavaFiles, file, mJavaPkg.PerClassMetadata, log)
}
}
for _, annotationClass := range mJavaPkg.AllAnnotations().SortedSlice() {
annotationProcessorClasses.AddAll(cfg.GetAnnotationProcessorPluginClasses(annotationClass))
}
}
} else {
allPackageNames.Add(javaPkg.Name)
Expand All @@ -174,6 +179,9 @@ func (l javaLang) GenerateRules(args language.GenerateArgs) language.GenerateRes
productionJavaFiles.Add(path)
}
}
for _, annotationClass := range javaPkg.AllAnnotations().SortedSlice() {
annotationProcessorClasses.AddAll(cfg.GetAnnotationProcessorPluginClasses(annotationClass))
}
}

allPackageNamesSlice := allPackageNames.SortedSlice()
Expand All @@ -192,7 +200,7 @@ func (l javaLang) GenerateRules(args language.GenerateArgs) language.GenerateRes
}

if productionJavaFiles.Len() > 0 {
l.generateJavaLibrary(args.File, args.Rel, filepath.Base(args.Rel), productionJavaFiles.SortedSlice(), allPackageNames, nonLocalProductionJavaImports, nonLocalJavaExports, false, javaLibraryKind, &res)
l.generateJavaLibrary(args.File, args.Rel, filepath.Base(args.Rel), productionJavaFiles.SortedSlice(), allPackageNames, nonLocalProductionJavaImports, nonLocalJavaExports, annotationProcessorClasses, false, javaLibraryKind, &res)
}

var testHelperJavaClasses *sorted_set.SortedSet[types.ClassName]
Expand Down Expand Up @@ -228,7 +236,7 @@ func (l javaLang) GenerateRules(args language.GenerateArgs) language.GenerateRes
testJavaImportsWithHelpers.Add(tf.pkg)
srcs = append(srcs, tf.pathRelativeToBazelWorkspaceRoot)
}
l.generateJavaLibrary(args.File, args.Rel, filepath.Base(args.Rel), srcs, packages, testJavaImports, nonLocalJavaExports, true, javaLibraryKind, &res)
l.generateJavaLibrary(args.File, args.Rel, filepath.Base(args.Rel), srcs, packages, testJavaImports, nonLocalJavaExports, annotationProcessorClasses, true, javaLibraryKind, &res)
}
}

Expand All @@ -240,7 +248,7 @@ func (l javaLang) GenerateRules(args language.GenerateArgs) language.GenerateRes
case "file":
for _, tf := range testJavaFiles.SortedSlice() {
separateJavaTestReasons := separateTestJavaFiles[tf]
l.generateJavaTest(args.File, args.Rel, cfg.MavenRepositoryName(), tf, isModule, testJavaImportsWithHelpers, nil, separateJavaTestReasons.wrapper, separateJavaTestReasons.attributes, &res)
l.generateJavaTest(args.File, args.Rel, cfg.MavenRepositoryName(), tf, isModule, testJavaImportsWithHelpers, annotationProcessorClasses, nil, separateJavaTestReasons.wrapper, separateJavaTestReasons.attributes, &res)
}

case "suite":
Expand Down Expand Up @@ -268,6 +276,7 @@ func (l javaLang) GenerateRules(args language.GenerateArgs) language.GenerateRes
packageNames,
cfg.MavenRepositoryName(),
testJavaImportsWithHelpers,
annotationProcessorClasses,
cfg.GetCustomJavaTestFileSuffixes(),
testHelperJavaFiles.Len() > 0,
&res,
Expand All @@ -284,7 +293,7 @@ func (l javaLang) GenerateRules(args language.GenerateArgs) language.GenerateRes
testHelperDep = ptr(testHelperLibname(suiteName))
}
separateJavaTestReasons := separateTestJavaFiles[src]
l.generateJavaTest(args.File, args.Rel, cfg.MavenRepositoryName(), src, isModule, testJavaImportsWithHelpers, testHelperDep, separateJavaTestReasons.wrapper, separateJavaTestReasons.attributes, &res)
l.generateJavaTest(args.File, args.Rel, cfg.MavenRepositoryName(), src, isModule, testJavaImportsWithHelpers, annotationProcessorClasses, testHelperDep, separateJavaTestReasons.wrapper, separateJavaTestReasons.attributes, &res)
}
}
}
Expand Down Expand Up @@ -415,7 +424,7 @@ func addFilteringOutOwnPackage(to *sorted_set.SortedSet[types.PackageName], from

func accumulateJavaFile(cfg *javaconfig.Config, testJavaFiles, testHelperJavaFiles *sorted_set.SortedSet[javaFile], separateTestJavaFiles map[javaFile]separateJavaTestReasons, file javaFile, perClassMetadata map[string]java.PerClassMetadata, log zerolog.Logger) {
if cfg.IsJavaTestFile(filepath.Base(file.pathRelativeToBazelWorkspaceRoot)) {
annotationClassNames := sorted_set.NewSortedSet[string](nil)
annotationClassNames := sorted_set.NewSortedSetFn[types.ClassName](nil, types.ClassNameLess)
metadataForClass := perClassMetadata[file.ClassName().FullyQualifiedClassName()]
annotationClassNames.AddAll(metadataForClass.AnnotationClassNames)
for _, key := range metadataForClass.MethodAnnotationClassNames.Keys() {
Expand All @@ -425,15 +434,15 @@ func accumulateJavaFile(cfg *javaconfig.Config, testJavaFiles, testHelperJavaFil
perFileAttrs := make(map[string]bzl.Expr)
wrapper := ""
for _, annotationClassName := range annotationClassNames.SortedSlice() {
if attrs, ok := cfg.AttributesForAnnotation(annotationClassName); ok {
if attrs, ok := cfg.AttributesForAnnotation(annotationClassName.FullyQualifiedClassName()); ok {
for k, v := range attrs {
if old, ok := perFileAttrs[k]; ok {
log.Error().Str("file", file.pathRelativeToBazelWorkspaceRoot).Msgf("Saw conflicting attr overrides from annotations for attribute %v: %v and %v. Picking one at random.", k, old, v)
}
perFileAttrs[k] = v
}
}
newWrapper, ok := cfg.WrapperForAnnotation(annotationClassName)
newWrapper, ok := cfg.WrapperForAnnotation(annotationClassName.FullyQualifiedClassName())
if ok {
if wrapper != "" {
log.Error().Str("file", file.pathRelativeToBazelWorkspaceRoot).Msgf("Saw conflicting wrappers from annotations: %v and %v. Picking one at random.", wrapper, newWrapper)
Expand All @@ -453,7 +462,7 @@ func accumulateJavaFile(cfg *javaconfig.Config, testJavaFiles, testHelperJavaFil
}
}

func (l javaLang) generateJavaLibrary(file *rule.File, pathToPackageRelativeToBazelWorkspace string, name string, srcsRelativeToBazelWorkspace []string, packages, imports *sorted_set.SortedSet[types.PackageName], exports *sorted_set.SortedSet[types.PackageName], testonly bool, javaLibraryRuleKind string, res *language.GenerateResult) {
func (l javaLang) generateJavaLibrary(file *rule.File, pathToPackageRelativeToBazelWorkspace string, name string, srcsRelativeToBazelWorkspace []string, packages, imports *sorted_set.SortedSet[types.PackageName], exports *sorted_set.SortedSet[types.PackageName], annotationProcessorClasses *sorted_set.SortedSet[types.ClassName], testonly bool, javaLibraryRuleKind string, res *language.GenerateResult) {
const ruleKind = "java_library"
r := rule.NewRule(ruleKind, name)

Expand Down Expand Up @@ -487,6 +496,7 @@ func (l javaLang) generateJavaLibrary(file *rule.File, pathToPackageRelativeToBa
PackageNames: packages,
ImportedPackageNames: imports,
ExportedPackageNames: exports,
AnnotationProcessors: annotationProcessorClasses,
}
res.Imports = append(res.Imports, resolveInput)
}
Expand All @@ -511,7 +521,7 @@ func (l javaLang) generateJavaBinary(file *rule.File, m types.ClassName, libName
})
}

func (l javaLang) generateJavaTest(file *rule.File, pathToPackageRelativeToBazelWorkspace string, mavenRepositoryName string, f javaFile, includePackageInName bool, imports *sorted_set.SortedSet[types.PackageName], depOnTestHelpers *string, wrapper string, extraAttributes map[string]bzl.Expr, res *language.GenerateResult) {
func (l javaLang) generateJavaTest(file *rule.File, pathToPackageRelativeToBazelWorkspace string, mavenRepositoryName string, f javaFile, includePackageInName bool, imports *sorted_set.SortedSet[types.PackageName], annotationProcessorClasses *sorted_set.SortedSet[types.ClassName], depOnTestHelpers *string, wrapper string, extraAttributes map[string]bzl.Expr, res *language.GenerateResult) {
className := f.ClassName()
fullyQualifiedTestClass := className.FullyQualifiedClassName()
var testName string
Expand Down Expand Up @@ -571,6 +581,7 @@ func (l javaLang) generateJavaTest(file *rule.File, pathToPackageRelativeToBazel
resolveInput := types.ResolveInput{
PackageNames: sorted_set.NewSortedSetFn([]types.PackageName{f.pkg}, types.PackageNameLess),
ImportedPackageNames: testImports,
AnnotationProcessors: annotationProcessorClasses,
}
res.Imports = append(res.Imports, resolveInput)
}
Expand Down Expand Up @@ -598,7 +609,7 @@ var junit5RuntimeDeps = []string{
"org.junit.platform:junit-platform-reporting",
}

func (l javaLang) generateJavaTestSuite(file *rule.File, name string, srcs []string, packageNames *sorted_set.SortedSet[types.PackageName], mavenRepositoryName string, imports *sorted_set.SortedSet[types.PackageName], customTestSuffixes *[]string, hasHelpers bool, res *language.GenerateResult) {
func (l javaLang) generateJavaTestSuite(file *rule.File, name string, srcs []string, packageNames *sorted_set.SortedSet[types.PackageName], mavenRepositoryName string, imports *sorted_set.SortedSet[types.PackageName], annotationProcessorClasses *sorted_set.SortedSet[types.ClassName], customTestSuffixes *[]string, hasHelpers bool, res *language.GenerateResult) {
const ruleKind = "java_test_suite"
r := rule.NewRule(ruleKind, name)
r.SetAttr("srcs", srcs)
Expand Down Expand Up @@ -636,6 +647,7 @@ func (l javaLang) generateJavaTestSuite(file *rule.File, name string, srcs []str
resolveInput := types.ResolveInput{
PackageNames: packageNames,
ImportedPackageNames: suiteImports,
AnnotationProcessors: annotationProcessorClasses,
}
res.Imports = append(res.Imports, resolveInput)
}
Expand Down
4 changes: 2 additions & 2 deletions java/gazelle/generate_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ func TestSingleJavaTestFile(t *testing.T) {
var res language.GenerateResult

l := newTestJavaLang(t)
l.generateJavaTest(nil, "", "maven", f, tc.includePackageInName, stringsToPackageNames(tc.importedPackages), nil, tc.wrapper, nil, &res)
l.generateJavaTest(nil, "", "maven", f, tc.includePackageInName, stringsToPackageNames(tc.importedPackages), nil, nil, tc.wrapper, nil, &res)

require.Len(t, res.Gen, 1, "want 1 generated rule")

Expand Down Expand Up @@ -252,7 +252,7 @@ func TestSuite(t *testing.T) {
var res language.GenerateResult

l := newTestJavaLang(t)
l.generateJavaTestSuite(nil, "blah", []string{src}, stringsToPackageNames([]string{pkg}), "maven", stringsToPackageNames(tc.importedPackages), nil, false, &res)
l.generateJavaTestSuite(nil, "blah", []string{src}, stringsToPackageNames([]string{pkg}), "maven", stringsToPackageNames(tc.importedPackages), nil, nil, false, &res)

require.Len(t, res.Gen, 1, "want 1 generated rule")

Expand Down
2 changes: 2 additions & 0 deletions java/gazelle/javaconfig/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ go_library(
importpath = "github.com/bazel-contrib/rules_jvm/java/gazelle/javaconfig",
visibility = ["//visibility:public"],
deps = [
"//java/gazelle/private/sorted_set",
"//java/gazelle/private/types",
"@com_github_bazelbuild_buildtools//build",
],
)
Expand Down
49 changes: 37 additions & 12 deletions java/gazelle/javaconfig/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ import (
"path/filepath"
"strings"

"github.com/bazel-contrib/rules_jvm/java/gazelle/private/sorted_set"
"github.com/bazel-contrib/rules_jvm/java/gazelle/private/types"
bzl "github.com/bazelbuild/buildtools/build"
)

Expand Down Expand Up @@ -47,6 +49,10 @@ const (
// JavaMavenRepositoryName tells the code generator what the repository name that contains all maven dependencies is.
// Defaults to "maven"
JavaMavenRepositoryName = "java_maven_repository_name"

// JavaAnnotationProcessorPlugin tells the code generator about specific java_plugin targets needed to process
// specific annotations.
JavaAnnotationProcessorPlugin = "java_annotation_processor_plugin"
)

// Configs is an extension of map[string]*Config. It provides finding methods
Expand All @@ -60,6 +66,10 @@ func (c *Config) NewChild() *Config {
for key, value := range c.excludedArtifacts {
clonedExcludedArtifacts[key] = value
}
annotationProcessorFullQualifiedClassToPluginClass := make(map[string]*sorted_set.SortedSet[types.ClassName])
for key, value := range c.annotationProcessorFullQualifiedClassToPluginClass {
annotationProcessorFullQualifiedClassToPluginClass[key] = value.Clone()
}
return &Config{
parent: c,
extensionEnabled: c.extensionEnabled,
Expand All @@ -74,6 +84,7 @@ func (c *Config) NewChild() *Config {
annotationToWrapper: c.annotationToWrapper,
excludedArtifacts: clonedExcludedArtifacts,
mavenRepositoryName: c.mavenRepositoryName,
annotationProcessorFullQualifiedClassToPluginClass: annotationProcessorFullQualifiedClassToPluginClass,
}
}

Expand All @@ -91,18 +102,19 @@ func (c *Configs) ParentForPackage(pkg string) *Config {
type Config struct {
parent *Config

extensionEnabled bool
isModuleRoot bool
generateProto bool
mavenInstallFile string
moduleGranularity string
repoRoot string
testMode string
customTestFileSuffixes *[]string
excludedArtifacts map[string]struct{}
annotationToAttribute map[string]map[string]bzl.Expr
annotationToWrapper map[string]string
mavenRepositoryName string
extensionEnabled bool
isModuleRoot bool
generateProto bool
mavenInstallFile string
moduleGranularity string
repoRoot string
testMode string
customTestFileSuffixes *[]string
excludedArtifacts map[string]struct{}
annotationToAttribute map[string]map[string]bzl.Expr
annotationToWrapper map[string]string
mavenRepositoryName string
annotationProcessorFullQualifiedClassToPluginClass map[string]*sorted_set.SortedSet[types.ClassName]
}

type LoadInfo struct {
Expand All @@ -125,6 +137,7 @@ func New(repoRoot string) *Config {
annotationToAttribute: make(map[string]map[string]bzl.Expr),
annotationToWrapper: make(map[string]string),
mavenRepositoryName: "maven",
annotationProcessorFullQualifiedClassToPluginClass: make(map[string]*sorted_set.SortedSet[types.ClassName]),
}
}

Expand Down Expand Up @@ -269,6 +282,18 @@ func (c *Config) IsTestRule(ruleKind string) bool {
return false
}

func (c *Config) GetAnnotationProcessorPluginClasses(annotationClass types.ClassName) *sorted_set.SortedSet[types.ClassName] {
return c.annotationProcessorFullQualifiedClassToPluginClass[annotationClass.FullyQualifiedClassName()]
}

func (c *Config) AddAnnotationProcessorPlugin(annotationClass types.ClassName, processorClass types.ClassName) {
fullyQualifiedAnnotationClass := annotationClass.FullyQualifiedClassName()
if _, ok := c.annotationProcessorFullQualifiedClassToPluginClass[fullyQualifiedAnnotationClass]; !ok {
c.annotationProcessorFullQualifiedClassToPluginClass[fullyQualifiedAnnotationClass] = sorted_set.NewSortedSetFn[types.ClassName](nil, types.ClassNameLess)
}
c.annotationProcessorFullQualifiedClassToPluginClass[fullyQualifiedAnnotationClass].Add(processorClass)
}

func equalStringSlices(l, r []string) bool {
if len(l) != len(r) {
return false
Expand Down
19 changes: 17 additions & 2 deletions java/gazelle/private/java/package.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,22 @@ type Package struct {
PerClassMetadata map[string]PerClassMetadata
}

func (p *Package) AllAnnotations() *sorted_set.SortedSet[types.ClassName] {
annotations := sorted_set.NewSortedSetFn(nil, types.ClassNameLess)
for _, pcm := range p.PerClassMetadata {
annotations.AddAll(pcm.AnnotationClassNames)
for _, method := range pcm.MethodAnnotationClassNames.Keys() {
annotations.AddAll(pcm.MethodAnnotationClassNames.Values(method))
}
for _, field := range pcm.FieldAnnotationClassNames.Keys() {
annotations.AddAll(pcm.FieldAnnotationClassNames.Values(field))
}
}
return annotations
}

type PerClassMetadata struct {
AnnotationClassNames *sorted_set.SortedSet[string]
MethodAnnotationClassNames *sorted_multiset.SortedMultiSet[string, string]
AnnotationClassNames *sorted_set.SortedSet[types.ClassName]
MethodAnnotationClassNames *sorted_multiset.SortedMultiSet[string, types.ClassName]
FieldAnnotationClassNames *sorted_multiset.SortedMultiSet[string, types.ClassName]
}
Loading

0 comments on commit f6b2319

Please sign in to comment.