Skip to content

Commit

Permalink
Feat(GraphQL): First iteration of graphQL support for similarity sear…
Browse files Browse the repository at this point in the history
…ch based on vector indexes (#48)

- [Float] data type with an @Embedding directive in GraphQL is mapped to
"VFLOAT" data type in dgraph
- @Embedding directive is allowed only for [Float] data type
- Add querySimilar<Type> query that performs similarity search based on
HNSW (vector) indexes for types with fields having @Embedding directive
- querySimilar<Type> accepts 3 arguments

1. id - unique id for the object to run similarity search on.
2. <predicate> - An enum of embedding predicates defined in the type to
base the search on
3. <topK> - number of nearest neighbors to return
  • Loading branch information
sunilmujumdar authored and Harshil Goel committed Mar 12, 2024
1 parent 69dc735 commit 2825fa5
Show file tree
Hide file tree
Showing 69 changed files with 740 additions and 30 deletions.
33 changes: 31 additions & 2 deletions graphql/dgraph/graphquery.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,16 @@ import (
// validate query, and so doesn't return an error if query is 'malformed' - it might
// just write something that wouldn't parse as a Dgraph query.
func AsString(queries []*dql.GraphQuery) string {
if queries == nil {
if len(queries) == 0 {
return ""
}

var b strings.Builder
x.Check2(b.WriteString("query {\n"))
queryName := queries[len(queries)-1].Attr
x.Check2(b.WriteString("query "))
addQueryVars(&b, queryName, queries[0].Args)
x.Check2(b.WriteString("{\n"))

numRewrittenQueries := 0
for _, q := range queries {
if q == nil {
Expand All @@ -54,6 +58,24 @@ func AsString(queries []*dql.GraphQuery) string {
return b.String()
}

// addQueryVars writes the query-variable header of a DQL query to b.
//
// Only entries of args whose key starts with "$" are emitted. The first such
// entry is preceded by queryName and an opening parenthesis; later entries are
// separated by ", ". Each entry is rendered as "<key>: <value>". If at least
// one variable was written, the list is closed with ") ". When args holds no
// "$" keys nothing is written at all, not even queryName.
//
// The entries are visited by ranging over a map, so when several "$" keys are
// present their relative order in the output is not fixed.
func addQueryVars(b *strings.Builder, queryName string, args map[string]string) {
	wroteAny := false
	for key, decl := range args {
		if !strings.HasPrefix(key, "$") {
			continue
		}

		// The very first variable opens the list; all others are comma-separated.
		lead := ", "
		if !wroteAny {
			lead = queryName + "("
			wroteAny = true
		}
		x.Check2(b.WriteString(lead))
		x.Check2(b.WriteString(key + ": " + decl))
	}

	if !wroteAny {
		return
	}
	x.Check2(b.WriteString(") "))
}

func writeQuery(b *strings.Builder, query *dql.GraphQuery, prefix string) {
if query.Var != "" || query.Alias != "" || query.Attr != "" {
x.Check2(b.WriteString(prefix))
Expand Down Expand Up @@ -145,6 +167,9 @@ func writeRoot(b *strings.Builder, q *dql.GraphQuery) {
}

switch {
// TODO: Instead of the hard-coded strings "uid", "type", etc., use the
// pre-defined constants in dql/parser.go such as dql.uidFunc, dql.typFunc,
// etc. This of course will require that we make these constants public.
case q.Func.Name == "uid":
x.Check2(b.WriteString("(func: "))
writeUIDFunc(b, q.Func.UID, q.Func.Args)
Expand All @@ -154,6 +179,10 @@ func writeRoot(b *strings.Builder, q *dql.GraphQuery) {
x.Check2(b.WriteString("(func: eq("))
writeFilterArguments(b, q.Func.Args)
x.Check2(b.WriteRune(')'))
case q.Func.Name == "similar_to":
x.Check2(b.WriteString("(func: similar_to("))
writeFilterArguments(b, q.Func.Args)
x.Check2(b.WriteRune(')'))
}
writeOrderAndPage(b, q, true)
x.Check2(b.WriteRune(')'))
Expand Down
1 change: 1 addition & 0 deletions graphql/e2e/schema/apollo_service_response.graphql
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,7 @@ input GenerateMutationParams {

directive @hasInverse(field: String!) on FIELD_DEFINITION
directive @search(by: [DgraphIndex!]) on FIELD_DEFINITION
directive @hm_embedding on FIELD_DEFINITION
directive @dgraph(type: String, pred: String) on OBJECT | INTERFACE | FIELD_DEFINITION
directive @id(interface: Boolean) on FIELD_DEFINITION
directive @withSubscription on OBJECT | INTERFACE | FIELD_DEFINITION
Expand Down
1 change: 1 addition & 0 deletions graphql/e2e/schema/generatedSchema.graphql
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,7 @@ input GenerateMutationParams {

directive @hasInverse(field: String!) on FIELD_DEFINITION
directive @search(by: [DgraphIndex!]) on FIELD_DEFINITION
directive @hm_embedding on FIELD_DEFINITION
directive @dgraph(type: String, pred: String) on OBJECT | INTERFACE | FIELD_DEFINITION
directive @id(interface: Boolean) on FIELD_DEFINITION
directive @withSubscription on OBJECT | INTERFACE | FIELD_DEFINITION
Expand Down
6 changes: 6 additions & 0 deletions graphql/resolve/mutation_rewriter.go
Original file line number Diff line number Diff line change
Expand Up @@ -1675,6 +1675,12 @@ func rewriteObject(
fieldName = fieldName[1 : len(fieldName)-1]
}

if fieldDef.HasEmbeddingDirective() {
// embedding is a JSON array of numbers. Rewrite it as a string, for now
var valBytes []byte
valBytes, _ = json.Marshal(val)
val = string(valBytes)
}
// TODO: Write a function for aggregating data of fragment from child nodes.
switch val := val.(type) {
case map[string]interface{}:
Expand Down
272 changes: 271 additions & 1 deletion graphql/resolve/query_rewriter.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package resolve
import (
"bytes"
"context"
"encoding/json"
"fmt"
"sort"
"strconv"
Expand Down Expand Up @@ -147,7 +148,14 @@ func (qr *queryRewriter) Rewrite(

dgQuery := rewriteAsGet(gqlQuery, uid, xid, authRw)
return dgQuery, nil

case schema.SimilarByIdQuery:
xid, uid, err := gqlQuery.IDArgValue()
if err != nil {
return nil, err
}
return rewriteAsSimilarByIdQuery(gqlQuery, uid, xid, authRw), nil
case schema.SimilarByEmbeddingQuery:
return rewriteAsSimilarByEmbeddingQuery(gqlQuery, authRw), nil
case schema.FilterQuery:
return rewriteAsQuery(gqlQuery, authRw), nil
case schema.PasswordQuery:
Expand Down Expand Up @@ -612,6 +620,268 @@ func rewriteAsGet(
return dgQuery
}

// rewriteAsSimilarByIdQuery rewrites a SimilarById GraphQL query into a
// sequence of DQL query blocks.
//
// Example rewritten query:
//
//	query {
//	  var(func: eq(Product.id, "0528012398")) @filter(type(Product)) {
//	    vec as Product.embedding
//	  }
//	  var() {
//	    v1 as max(val(vec))
//	  }
//	  var(func: similar_to(Product.embedding, 8, val(v1))) {
//	    v2 as Product.embedding
//	    distance as math((v2 - v1) dot (v2 - v1))
//	  }
//	  querySimilarProductById(func: uid(distance), orderasc: val(distance)) {
//	    ProductWithDistance.id : Product.id
//	    ProductWithDistance.description : Product.description
//	    ProductWithDistance.title : Product.title
//	    ProductWithDistance.imageUrl : Product.imageUrl
//	    ProductWithDistance.hm_distance : val(distance)
//	    dgraph.uid : uid
//	  }
//	}
//
// uid and xidArgToVal identify the object whose embedding is used as the
// search vector; they are forwarded unchanged to rewriteAsGet together with
// auth. The returned slice is whatever rewriteAsGet produced (with its last
// block turned into a "var" block) followed by the aggregation block, the
// similar_to block and the final, distance-ordered result block.
//
// The function panics if the "by" argument is not a string or the "topK"
// argument is not an int64, because both are read with unchecked type
// assertions.
func rewriteAsSimilarByIdQuery(
	query schema.Query,
	uid uint64,
	xidArgToVal map[string]string,
	auth *authRewriter) []*dql.GraphQuery {

	// Read the GraphQL arguments.
	typ := query.Type()
	pred := typ.DgraphPredicate(query.ArgValue(schema.SimilarByArgName).(string))
	topK := query.ArgValue(schema.SimilarTopKArgName).(int64)

	// Alias under which the computed distance is exposed in the result.
	distanceAlias := typ.Name() + "." + schema.SimilarQueryDistanceFieldName

	// First generate the query that looks up the object for the given id.
	// For example:
	//   var(func: eq(Product.id, "0528012398")) @filter(type(Product)) {
	//     vec as Product.embedding
	//   }
	dgQuery := rewriteAsGet(query, uid, xidArgToVal, auth)
	lastQuery := dgQuery[len(dgQuery)-1]
	// Turn the root query into a "var" block.
	lastQuery.Attr = "var"
	// Keep the originally selected fields; they become the children of the
	// final result block (sortQuery).
	result := lastQuery.Children

	// The lookup block now only binds the search vector to the variable "vec".
	lastQuery.Children = []*dql.GraphQuery{{
		Attr: pred,
		Var:  "vec",
	}}

	// Turn the value variable into a single "constant" by taking its max.
	// The original author notes the lookup returns exactly one uid anyway.
	// For example:
	//   var() {
	//     v1 as max(val(vec))
	//   }
	aggQuery := &dql.GraphQuery{
		Attr: "var()",
		Children: []*dql.GraphQuery{
			{
				Var:  "v1",
				Attr: "max(val(vec))",
			},
		},
	}

	// similar_to block; it also computes the distance used to order the
	// result later. For example:
	//   var(func: similar_to(Product.embedding, 8, val(v1))) {
	//     v2 as Product.embedding
	//     distance as math((v2 - v1) dot (v2 - v1))
	//   }
	similarQuery := &dql.GraphQuery{
		Attr: "var",
		Children: []*dql.GraphQuery{
			{
				Var:  "v2",
				Attr: pred,
			},
			{
				Var:  "distance",
				Attr: "math((v2 - v1) dot (v2 - v1))",
			},
		},
		Func: &dql.Function{
			Name: "similar_to",
			Args: []dql.Arg{
				{Value: pred},
				{Value: strconv.FormatInt(topK, 10)},
				{Value: "val(v1)"},
			},
		},
	}

	// If the caller already selected the distance field, point it at the
	// computed value; otherwise add it to the result.
	found := false
	for _, child := range result {
		if child.Alias == distanceAlias {
			child.Attr = "val(distance)"
			found = true
			break
		}
	}
	if !found {
		result = append(result, &dql.GraphQuery{
			Alias: distanceAlias,
			Attr:  "val(distance)",
		})
	}

	// Order the result by the computed distance. For example:
	//   querySimilarProductById(func: uid(distance), orderasc: val(distance)) {
	//     ProductWithDistance.id : Product.id
	//     ProductWithDistance.description : Product.description
	//     ProductWithDistance.title : Product.title
	//     ProductWithDistance.imageUrl : Product.imageUrl
	//     ProductWithDistance.hm_distance : val(distance)
	//     dgraph.uid : uid
	//   }
	sortQuery := &dql.GraphQuery{
		Attr:     query.DgraphAlias(),
		Children: result,
		Func: &dql.Function{
			Name: "uid",
			Args: []dql.Arg{{Value: "distance"}},
		},
		Order: []*pb.Order{{Attr: "val(distance)", Desc: false}},
	}

	return append(dgQuery, aggQuery, similarQuery, sortQuery)
}

// rewriteAsSimilarByEmbeddingQuery rewrites a SimilarByEmbedding GraphQL
// query into a sequence of DQL query blocks.
//
// Example rewritten query:
//
//	query gQLTodQL($search_vector: vfloat = "<json array of float>") {
//	  var(func: similar_to(Product.embedding, 8, $search_vector)) {
//	    v2 as Product.embedding
//	    distance as math((v2 - $search_vector) dot (v2 - $search_vector))
//	  }
//	  querySimilarProductById(func: uid(distance), orderasc: val(distance)) {
//	    ProductWithDistance.id : Product.id
//	    ProductWithDistance.description : Product.description
//	    ProductWithDistance.title : Product.title
//	    ProductWithDistance.imageUrl : Product.imageUrl
//	    ProductWithDistance.hm_distance : val(distance)
//	    dgraph.uid : uid
//	  }
//	}
//
// The query is first rewritten with rewriteAsQuery. Its first block is then
// turned into a "var" block rooted at similar_to (the previous root function
// is moved into the filter tree), and a final block ordered by the computed
// distance is appended.
//
// The function panics if the "by" argument is not a string, the "topK"
// argument is not an int64, or the vector argument is not a []interface{},
// because all three are read with unchecked type assertions.
func rewriteAsSimilarByEmbeddingQuery(
	query schema.Query, auth *authRewriter) []*dql.GraphQuery {

	dgQuery := rewriteAsQuery(query, auth)
	root := dgQuery[0]

	// Remember the selected fields of the first block; they become the
	// children of the last block in the rewritten query.
	result := root.Children
	typ := query.Type()

	// Alias under which the computed distance is exposed in the result.
	distanceAlias := typ.Name() + "." + schema.SimilarQueryDistanceFieldName

	// Read the GraphQL arguments.
	pred := typ.DgraphPredicate(query.ArgValue(schema.SimilarByArgName).(string))
	topK := query.ArgValue(schema.SimilarTopKArgName).(int64)
	vec := query.ArgValue(schema.SimilarVectorArgName).([]interface{})
	// The marshal error is dropped because this function has no error return.
	// NOTE(review): confirm the vector is validated before it gets here.
	vecStr, _ := json.Marshal(vec)

	// Save the vector as the query variable $search_vector. The map is
	// created on the root query itself when missing, so that the variable is
	// not lost when rewriteAsQuery returned a block without Args.
	if root.Args == nil {
		root.Args = make(map[string]string)
	}
	root.Args["$search_vector"] = " vfloat = \"" + string(vecStr) + "\""

	// Move the existing root function into the filter tree before it is
	// replaced by similar_to below.
	addToFilterTree(root, &dql.FilterTree{Func: root.Func})

	// Make similar_to the root function, passing $search_vector as the
	// search vector.
	root.Attr = "var"
	root.Func = &dql.Function{
		Name: "similar_to",
		Args: []dql.Arg{
			{Value: pred},
			{Value: strconv.FormatInt(topK, 10)},
			{Value: "$search_vector"},
		},
	}

	// Compute the distance between each neighbor and the search vector.
	root.Children = []*dql.GraphQuery{
		{
			Var:  "v2",
			Attr: pred,
		},
		{
			Var: "distance",
			// TODO: generate different math formula based on index type
			Attr: "math((v2 - $search_vector) dot (v2 - $search_vector))",
		},
	}

	// If the caller already selected the distance field, point it at the
	// computed value; otherwise add it to the result.
	found := false
	for _, child := range result {
		if child.Alias == distanceAlias {
			child.Attr = "val(distance)"
			found = true
			break
		}
	}
	if !found {
		result = append(result, &dql.GraphQuery{
			Alias: distanceAlias,
			Attr:  "val(distance)",
		})
	}

	// Final block: return the neighbors ordered by distance.
	sortQuery := &dql.GraphQuery{
		Attr:     query.DgraphAlias(),
		Children: result,
		Func: &dql.Function{
			Name: "uid",
			Args: []dql.Arg{{Value: "distance"}},
		},
		Order: []*pb.Order{{Attr: "val(distance)", Desc: false}},
	}

	return append(dgQuery, sortQuery)
}

// Adds common RBAC and UID, Type rules to DQL query.
// This function is used by rewriteAsQuery and aggregateQuery functions
func addCommonRules(
Expand Down
2 changes: 2 additions & 0 deletions graphql/resolve/resolver.go
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,8 @@ func (rf *resolverFactory) WithConventionResolvers(
s schema.Schema, fns *ResolverFns) ResolverFactory {

queries := append(s.Queries(schema.GetQuery), s.Queries(schema.FilterQuery)...)
queries = append(queries, s.Queries(schema.SimilarByIdQuery)...)
queries = append(queries, s.Queries(schema.SimilarByEmbeddingQuery)...)
queries = append(queries, s.Queries(schema.PasswordQuery)...)
queries = append(queries, s.Queries(schema.AggregateQuery)...)
for _, q := range queries {
Expand Down
Loading

0 comments on commit 2825fa5

Please sign in to comment.