Skip to content

Commit

Permalink
Bug(GraphQL/Schema): If a type had an embedding as well as a list, hm…
Browse files Browse the repository at this point in the history
…_distance field was getting clobbered (#77)

Description: 

```
querySimilar<Type> queries defined a new derived type "<Type>WithDistance" with a new field hm_distance.

However, if <Type> had any lists, hm_distance was getting clobbered by an "<ListFieldName>Aggregate" being added to <Type>

The fix does away with the derived type <Type>WithDistance as the result type of querySimilar<Type> queries. Instead, the hm_distance field is added directly to the <Type> definition itself (only if it is not already present), and the querySimilar<Type> queries now return [<Type>]. This would make it easy to add support for filters on embeddings.
```

Fixes: HYP-447
  • Loading branch information
sunilmujumdar authored and harshil-goel committed Apr 17, 2024
1 parent 46c7c0f commit 7b6080b
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 97 deletions.
30 changes: 15 additions & 15 deletions graphql/resolve/query_rewriter.go
Original file line number Diff line number Diff line change
Expand Up @@ -637,11 +637,11 @@ func rewriteAsGet(
// distance as math((v2 - v1) dot (v2 - v1))
// }
// querySimilarProductById(func: uid(distance), orderasc: val(distance)) {
// ProductWithDistance.id : Product.id
// ProductWithDistance.description : Product.description
// ProductWithDistance.title : Product.title
// ProductWithDistance.imageUrl : Product.imageUrl
// ProductWithDistance.hm_distance : val(distance)
// Product.id : Product.id
// Product.description : Product.description
// Product.title : Product.title
// Product.imageUrl : Product.imageUrl
// Product.hm_distance : val(distance)
// dgraph.uid : uid
// }
// }
Expand Down Expand Up @@ -746,11 +746,11 @@ func rewriteAsSimilarByIdQuery(

// Order the result by Euclidean distance. For example,
// querySimilarProductById(func: uid(distance), orderasc: val(distance)) {
// ProductWithDistance.id : Product.id
// ProductWithDistance.description : Product.description
// ProductWithDistance.title : Product.title
// ProductWithDistance.imageUrl : Product.imageUrl
// ProductWithDistance.hm_distance : val(distance)
// Product.id : Product.id
// Product.description : Product.description
// Product.title : Product.title
// Product.imageUrl : Product.imageUrl
// Product.hm_distance : val(distance)
// dgraph.uid : uid
// }
// }
Expand Down Expand Up @@ -779,11 +779,11 @@ func rewriteAsSimilarByIdQuery(
// distance as math((v2 - $search_vector) dot (v2 - $search_vector))
// }
// querySimilarProductById(func: uid(distance), orderasc: val(distance)) {
// ProductWithDistance.id : Product.id
// ProductWithDistance.description : Product.description
// ProductWithDistance.title : Product.title
// ProductWithDistance.imageUrl : Product.imageUrl
// ProductWithDistance.hm_distance : val(distance)
// Product.id : Product.id
// Product.description : Product.description
// Product.title : Product.title
// Product.imageUrl : Product.imageUrl
// Product.hm_distance : val(distance)
// dgraph.uid : uid
// }
// }
Expand Down
78 changes: 16 additions & 62 deletions graphql/schema/gqlschema.go
Original file line number Diff line number Diff line change
Expand Up @@ -2025,45 +2025,22 @@ func addGetQuery(schema *ast.Schema, defn *ast.Definition,
// defn - The type definition for the object for which this query will be added
func addSimilarByEmbeddingQuery(schema *ast.Schema, defn *ast.Definition) {

// Generate the new <Type> for similarity search results. This is done by
// adding a new field to
// the input type. The new field is "hm_distance". The name of the new type
// is <Type>WithDistance
fields := append(defn.Fields,
&ast.FieldDefinition{
Name: SimilarQueryDistanceFieldName,
Type: &ast.Type{NamedType: "Float"}})
// Add dgraph directive for the new type <Type>WithDistance.
// @dgraph(type: <Type>)
args := []*ast.Argument{}
args = append(args, &ast.Argument{
Name: "type",
Value: &ast.Value{Kind: ast.StringValue, Raw: defn.Name}})
dir := &ast.Directive{
Name: dgraphDirective,
Arguments: args,
}

// new type, <Type>WithDistance
resultTypeName := defn.Name + SimilarQueryResultTypeSuffix
resultType := &ast.Definition{
Kind: ast.Object,
Name: resultTypeName,
Fields: fields,
Directives: []*ast.Directive{dir},
}

// create the new query, querySimilar<Type>ByEmbedding
schema.Types[resultTypeName] = resultType
qry := &ast.FieldDefinition{
Name: SimilarQueryPrefix + defn.Name + SimilarByEmbeddingQuerySuffix,
Type: &ast.Type{
Elem: &ast.Type{
NamedType: resultTypeName,
NamedType: defn.Name,
},
},
}

// Add the new "hm_distance" field to the input type, unless it is already defined.
if defn.Fields.ForName(SimilarQueryDistanceFieldName) == nil {
defn.Fields = append(defn.Fields,
&ast.FieldDefinition{
Name: SimilarQueryDistanceFieldName,
Type: &ast.Type{NamedType: "Float"}})
}
// Define the enum used to select from among all predicates
// with "@hm_embedding" directives
enumName := defn.Name + EmbeddingEnumSuffix
Expand Down Expand Up @@ -2125,46 +2102,23 @@ func addSimilarByIdQuery(schema *ast.Schema, defn *ast.Definition,
return
}

// Generate the new <Type> for similarity search results. This is done by
// adding a new field to the input type. The new field is "hm_distance".
// The name of the new type is <Type>WithDistance
fields := append(defn.Fields,
&ast.FieldDefinition{
Name: SimilarQueryDistanceFieldName,
Type: &ast.Type{NamedType: "Float"}})

// Add dgraph directive for the new type <Type>WithDistance.
// @dgraph(type: <Type>)
args := []*ast.Argument{}
args = append(args,
&ast.Argument{
Name: "type",
Value: &ast.Value{Kind: ast.StringValue, Raw: defn.Name}})
dir := &ast.Directive{
Name: dgraphDirective,
Arguments: args,
}

// new type, <Type>WithDistance
resultTypeName := defn.Name + SimilarQueryResultTypeSuffix
resultType := &ast.Definition{
Kind: ast.Object,
Name: resultTypeName,
Fields: fields,
Directives: []*ast.Directive{dir},
}
schema.Types[resultTypeName] = resultType

// create the new query, querySimilar<Type>ById
qry := &ast.FieldDefinition{
Name: SimilarQueryPrefix + defn.Name + SimilarByIdQuerySuffix,
Type: &ast.Type{
Elem: &ast.Type{
NamedType: resultTypeName,
NamedType: defn.Name,
},
},
}

// Add the new "hm_distance" field to the input type, unless it is already defined.
if defn.Fields.ForName(SimilarQueryDistanceFieldName) == nil {
defn.Fields = append(defn.Fields,
&ast.FieldDefinition{
Name: SimilarQueryDistanceFieldName,
Type: &ast.Type{NamedType: "Float"}})
}
// If the defn specifies only one of the ID/XID fields, then that field is mandatory.
// If it specifies both, then they are optional.
if hasIDField {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ type Product {
title: String
imageUrl: String
product_vector: [Float!] @hm_embedding
hm_distance: Float
}

type Purchase @lambdaOnMutate(add: true) {
Expand All @@ -20,6 +21,7 @@ type User {
email: String! @id
purchase_history(filter: PurchaseFilter, order: PurchaseOrder, first: Int, offset: Int): [Purchase] @hasInverse(field: user)
user_vector: [Float!] @hm_embedding
hm_distance: Float
purchase_historyAggregate(filter: PurchaseFilter): PurchaseAggregateResult
}

Expand Down Expand Up @@ -336,15 +338,6 @@ type ProductAggregateResult {
imageUrlMax: String
}

type ProductWithDistance @dgraph(type: "Product") {
id: String! @id
description: String
title: String
imageUrl: String
product_vector: [Float!] @hm_embedding
hm_distance: Float
}

type PurchaseAggregateResult {
count: Int
dateMin: DateTime
Expand Down Expand Up @@ -372,13 +365,6 @@ type UserAggregateResult {
emailMax: String
}

type UserWithDistance @dgraph(type: "User") {
email: String! @id
purchase_history(filter: PurchaseFilter, order: PurchaseOrder, first: Int, offset: Int): [Purchase] @hasInverse(field: user)
user_vector: [Float!] @hm_embedding
purchase_historyAggregate(filter: PurchaseFilter): PurchaseAggregateResult
}

#######################
# Generated Enums
#######################
Expand All @@ -393,6 +379,7 @@ enum ProductHasFilter {
title
imageUrl
product_vector
hm_distance
}

enum ProductOrderable {
Expand Down Expand Up @@ -420,6 +407,7 @@ enum UserHasFilter {
email
purchase_history
user_vector
hm_distance
}

enum UserOrderable {
Expand Down Expand Up @@ -554,15 +542,15 @@ input UserRef {

type Query {
getProduct(id: String!): Product
querySimilarProductById(id: String!, by: ProductEmbedding!, topK: Int!): [ProductWithDistance]
querySimilarProductByEmbedding(by: ProductEmbedding!, topK: Int!, vector: [Float!]!): [ProductWithDistance]
querySimilarProductById(id: String!, by: ProductEmbedding!, topK: Int!): [Product]
querySimilarProductByEmbedding(by: ProductEmbedding!, topK: Int!, vector: [Float!]!): [Product]
queryProduct(filter: ProductFilter, order: ProductOrder, first: Int, offset: Int): [Product]
aggregateProduct(filter: ProductFilter): ProductAggregateResult
queryPurchase(filter: PurchaseFilter, order: PurchaseOrder, first: Int, offset: Int): [Purchase]
aggregatePurchase(filter: PurchaseFilter): PurchaseAggregateResult
getUser(email: String!): User
querySimilarUserById(email: String!, by: UserEmbedding!, topK: Int!): [UserWithDistance]
querySimilarUserByEmbedding(by: UserEmbedding!, topK: Int!, vector: [Float!]!): [UserWithDistance]
querySimilarUserById(email: String!, by: UserEmbedding!, topK: Int!): [User]
querySimilarUserByEmbedding(by: UserEmbedding!, topK: Int!, vector: [Float!]!): [User]
queryUser(filter: UserFilter, order: UserOrder, first: Int, offset: Int): [User]
aggregateUser(filter: UserFilter): UserAggregateResult
}
Expand Down

0 comments on commit 7b6080b

Please sign in to comment.