Skip to content

Commit

Permalink
add test for dotproduct and cosine index and fix failing tests
Browse files Browse the repository at this point in the history
  • Loading branch information
shivaji-dgraph committed May 15, 2024
1 parent 437ff80 commit c430009
Show file tree
Hide file tree
Showing 3 changed files with 124 additions and 59 deletions.
3 changes: 2 additions & 1 deletion graphql/resolve/query_rewriter.go
Original file line number Diff line number Diff line change
Expand Up @@ -824,7 +824,8 @@ func rewriteAsSimilarByEmbeddingQuery(
if metric == schema.SimilarSearchMetricDotProduct {
distanceFormula = "math(( 1.0 - (($search_vector) dot v2)) /2.0)"
} else if metric == schema.SimilarSearchMetricCosine {
distanceFormula = "math((1.0 - ( (($search_vector) dot v2) / sqrt( (($search_vector) dot ($search_vector)) * (v2 dot v2) ) )) / 2.0)"
distanceFormula = "math((1.0 - ( (($search_vector) dot v2) / sqrt( (($search_vector) dot ($search_vector))" +
" * (v2 dot v2) ) )) / 2.0)"
}

// Save vectorString as a query variable, $search_vector
Expand Down
12 changes: 6 additions & 6 deletions graphql/resolve/query_test.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3367,7 +3367,7 @@
query querySimilarProductByEmbedding($search_vector: float32vector = "[0.1,0.2,0.3,0.4,0.5]") {
var(func: similar_to(Product.productVector, 1, $search_vector)) @filter(type(Product)) {
v2 as Product.productVector
distance as math((v2 - $search_vector) dot (v2 - $search_vector))
distance as math(sqrt((v2 - $search_vector) dot (v2 - $search_vector)))
}
querySimilarProductByEmbedding(func: uid(distance), orderasc: val(distance)) {
Product.id : Product.id
Expand Down Expand Up @@ -3397,7 +3397,7 @@
}
var(func: similar_to(Product.productVector, 3, val(v1))) {
v2 as Product.productVector
distance as math((v2 - v1) dot (v2 - v1))
distance as math(sqrt((v2 - v1) dot (v2 - v1)))
}
querySimilarProductById(func: uid(distance), orderasc: val(distance)) {
Product.id : Product.id
Expand Down Expand Up @@ -3428,7 +3428,7 @@
}
var(func: similar_to(ProjectCosine.description_v, 3, val(v1))) {
v2 as ProjectCosine.description_v
distance as math((v1 dot v2) / ((v1 dot v1) * (v2 dot v2)))
distance as math((1.0 - ((v1 dot v2) / sqrt( (v1 dot v1) * (v2 dot v2) ) )) / 2.0)
}
querySimilarProjectCosineById(func: uid(distance), orderasc: val(distance)) {
ProjectCosine.id : ProjectCosine.id
Expand All @@ -3453,7 +3453,7 @@
query querySimilarProjectCosineByEmbedding($search_vector: float32vector = "[0.1,0.2,0.3,0.4,0.5]") {
var(func: similar_to(ProjectCosine.description_v, 1, $search_vector)) @filter(type(ProjectCosine)) {
v2 as ProjectCosine.description_v
distance as math(($search_vector dot v2) / (($search_vector dot $search_vector) * (v2 dot v2)))
distance as math((1.0 - ( (($search_vector) dot v2) / sqrt( (($search_vector) dot ($search_vector)) * (v2 dot v2) ) )) / 2.0)
}
querySimilarProjectCosineByEmbedding(func: uid(distance), orderasc: val(distance)) {
ProjectCosine.id : ProjectCosine.id
Expand Down Expand Up @@ -3483,7 +3483,7 @@
}
var(func: similar_to(ProjectDotProduct.description_v, 3, val(v1))) {
v2 as ProjectDotProduct.description_v
distance as math(v1 dot v2)
distance as math((1.0 - (v1 dot v2)) /2.0)
}
querySimilarProjectDotProductById(func: uid(distance), orderasc: val(distance)) {
ProjectDotProduct.id : ProjectDotProduct.id
Expand All @@ -3508,7 +3508,7 @@
query querySimilarProjectDotProductByEmbedding($search_vector: float32vector = "[0.1,0.2,0.3,0.4,0.5]") {
var(func: similar_to(ProjectDotProduct.description_v, 1, $search_vector)) @filter(type(ProjectDotProduct)) {
v2 as ProjectDotProduct.description_v
distance as math($search_vector dot v2)
distance as math(( 1.0 - (($search_vector) dot v2)) /2.0)
}
querySimilarProjectDotProductByEmbedding(func: uid(distance), orderasc: val(distance)) {
ProjectDotProduct.id : ProjectDotProduct.id
Expand Down
168 changes: 116 additions & 52 deletions query/vector/vector_graphql_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ package query

import (
"encoding/json"
"fmt"
"math/rand"
"testing"

"github.com/dgraph-io/dgraph/dgraphtest"
Expand All @@ -36,29 +38,56 @@ const (
type Project {
id: ID!
title: String! @search(by: [exact])
title_v: [Float!] @embedding @search(by: ["hnsw(metric: euclidian, exponent: 4)"])
}
`
title_v: [Float!] @embedding @search(by: ["hnsw(metric: %v, exponent: 4)"])
} `
)

var (
projects = []ProjectInput{ProjectInput{
Title: "iCreate with a Mini iPad",
TitleV: []float32{0.7, 0.8, 0.9, 0.1, 0.2},
}, ProjectInput{
Title: "Resistive Touchscreen",
TitleV: []float32{0.7, 0.8, 0.9, 0.1, 0.2},
}, ProjectInput{
Title: "Fitness Band",
TitleV: []float32{0.7, 0.8, 0.9, 0.1, 0.2},
}, ProjectInput{
Title: "Smart Watch",
TitleV: []float32{0.7, 0.8, 0.9, 0.1, 0.2},
}, ProjectInput{
Title: "Smart Ring",
TitleV: []float32{0.7, 0.8, 0.9, 0.1, 0.2},
}}
)
// generateProjects builds count ProjectInput values. Each one gets a title
// from generateUniqueRandomTitle (checked against the projects generated so
// far) and a vector from generateRandomTitleV.
func generateProjects(count int) []ProjectInput {
	var generated []ProjectInput
	for len(generated) < count {
		generated = append(generated, ProjectInput{
			// The title is produced before the vector, matching the order in
			// which the random source is consumed.
			Title: generateUniqueRandomTitle(generated),
			// The vector size is assumed to be fixed at 5.
			TitleV: generateRandomTitleV(5),
		})
	}
	return generated
}

// isTitleExists reports whether any entry of existingTitles has a Title equal
// to title.
func isTitleExists(title string, existingTitles []ProjectInput) bool {
	for idx := range existingTitles {
		if existingTitles[idx].Title == title {
			return true
		}
	}
	return false
}

// generateUniqueRandomTitle returns a random 10-character alphanumeric title
// that is not already used as a Title in existingTitles. It keeps drawing new
// candidates until isTitleExists reports no collision.
func generateUniqueRandomTitle(existingTitles []ProjectInput) string {
	const (
		alphabet = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
		length   = 10
	)
	// One scratch buffer is reused across retries; every position is
	// overwritten on each attempt.
	var buf [length]byte
	for {
		for pos := 0; pos < length; pos++ {
			buf[pos] = alphabet[rand.Intn(len(alphabet))]
		}
		if candidate := string(buf[:]); !isTitleExists(candidate, existingTitles) {
			return candidate
		}
	}
}

// generateRandomTitleV returns a slice holding size pseudo-random float32
// values, each obtained from rand.Float32 (i.e. in the half-open interval
// [0.0, 1.0)). A non-positive size yields a nil slice.
func generateRandomTitleV(size int) []float32 {
	if size <= 0 {
		return nil
	}
	// The final length is known up front, so allocate once instead of
	// growing the slice through repeated appends.
	titleV := make([]float32, size)
	for i := range titleV {
		titleV[i] = rand.Float32()
	}
	return titleV
}

func addProject(t *testing.T, hc *dgraphtest.HTTPClient, project ProjectInput) {
query := `
Expand All @@ -79,6 +108,7 @@ func addProject(t *testing.T, hc *dgraphtest.HTTPClient, project ProjectInput) {
_, err := hc.RunGraphqlQuery(params, false)
require.NoError(t, err)
}

func queryProjectUsingTitle(t *testing.T, hc *dgraphtest.HTTPClient, title string) ProjectInput {
query := ` query QueryProject($title: String!) {
queryProject(filter: { title: { eq: $title } }) {
Expand All @@ -96,19 +126,17 @@ func queryProjectUsingTitle(t *testing.T, hc *dgraphtest.HTTPClient, title strin
type QueryResult struct {
QueryProject []ProjectInput `json:"queryProject"`
}

var resp QueryResult
err = json.Unmarshal([]byte(string(response)), &resp)
require.NoError(t, err)

return resp.QueryProject[0]
}

func queryProjectsSimilarByEmbedding(t *testing.T, hc *dgraphtest.HTTPClient, vector []float32) []ProjectInput {
func queryProjectsSimilarByEmbedding(t *testing.T, hc *dgraphtest.HTTPClient, vector []float32, topk int) []ProjectInput {
// query similar project by embedding
queryProduct := `query QuerySimilarProjectByEmbedding($by: ProjectEmbedding!, $topK: Int!, $vector: [Float!]!) {
querySimilarProjectByEmbedding(by: $by, topK: $topK, vector: $vector) {
id
title
title_v
}
Expand All @@ -120,13 +148,13 @@ func queryProjectsSimilarByEmbedding(t *testing.T, hc *dgraphtest.HTTPClient, ve
Query: queryProduct,
Variables: map[string]interface{}{
"by": "title_v",
"topK": 3,
"topK": topk,
"vector": vector,
}}
response, err := hc.RunGraphqlQuery(params, false)
require.NoError(t, err)
type QueryResult struct {
QueryProject []ProjectInput `json:"queryProject"`
QueryProject []ProjectInput `json:"querySimilarProjectByEmbedding"`
}
var resp QueryResult
err = json.Unmarshal([]byte(string(response)), &resp)
Expand All @@ -143,64 +171,100 @@ func TestVectorGraphQLAddVectorPredicate(t *testing.T) {
require.NoError(t, err)
hc.LoginIntoNamespace("groot", "password", 0)
// add schema
require.NoError(t, hc.UpdateGQLSchema(graphQLVectorSchema))
require.NoError(t, hc.UpdateGQLSchema(fmt.Sprintf(graphQLVectorSchema, "euclidean")))
}

func TestVectorGraphQlMutationAndQuery(t *testing.T) {
func TestVectorSchema(t *testing.T) {
require.NoError(t, client.DropAll())

hc, err := dc.HTTPClient()
require.NoError(t, err)
hc.LoginIntoNamespace("groot", "password", 0)

schema := `type Project {
id: ID!
title: String! @search(by: [exact])
title_v: [Float!]
}`

// add schema
require.NoError(t, hc.UpdateGQLSchema(schema))
require.Error(t, hc.UpdateGQLSchema(fmt.Sprintf(graphQLVectorSchema, "euclidean")))
require.NoError(t, client.DropAll())
require.Error(t, hc.UpdateGQLSchema(fmt.Sprintf(graphQLVectorSchema, "dotproduct")))
require.NoError(t, client.DropAll())
require.Error(t, hc.UpdateGQLSchema(fmt.Sprintf(graphQLVectorSchema, "cosine")))
}

func TestVectorGraphQlEuclidianIndexMutationAndQuery(t *testing.T) {
require.NoError(t, client.DropAll())
hc, err := dc.HTTPClient()
require.NoError(t, err)
hc.LoginIntoNamespace("groot", "password", 0)

schema := fmt.Sprintf(graphQLVectorSchema, "euclidean")
// add schema
require.NoError(t, hc.UpdateGQLSchema(graphQLVectorSchema))
require.NoError(t, hc.UpdateGQLSchema(schema))
testVectorGraphQlMutationAndQuery(t, hc)

}

// add project
// TestVectorGraphQlCosineIndexMutationAndQuery installs the vector GraphQL
// schema with the "cosine" metric substituted into its index directive and
// then runs the shared mutation-and-query checks against it.
func TestVectorGraphQlCosineIndexMutationAndQuery(t *testing.T) {
	require.NoError(t, client.DropAll())
	hc, err := dc.HTTPClient()
	require.NoError(t, err)
	// Check the login result so an authentication failure is reported here
	// rather than as a confusing schema or query error further down.
	require.NoError(t, hc.LoginIntoNamespace("groot", "password", 0))

	schema := fmt.Sprintf(graphQLVectorSchema, "cosine")
	// add schema
	require.NoError(t, hc.UpdateGQLSchema(schema))
	testVectorGraphQlMutationAndQuery(t, hc)
}

// TestVectorGraphQlDotProductIndexMutationAndQuery installs the vector GraphQL
// schema with the "dotproduct" metric substituted into its index directive and
// then runs the shared mutation-and-query checks against it.
func TestVectorGraphQlDotProductIndexMutationAndQuery(t *testing.T) {
	require.NoError(t, client.DropAll())
	hc, err := dc.HTTPClient()
	require.NoError(t, err)
	// Check the login result so an authentication failure is reported here
	// rather than as a confusing schema or query error further down.
	require.NoError(t, hc.LoginIntoNamespace("groot", "password", 0))

	schema := fmt.Sprintf(graphQLVectorSchema, "dotproduct")
	// add schema
	require.NoError(t, hc.UpdateGQLSchema(schema))
	testVectorGraphQlMutationAndQuery(t, hc)
}

func testVectorGraphQlMutationAndQuery(t *testing.T, hc *dgraphtest.HTTPClient) {
var vectors [][]float32
numProjects := 100
projects := generateProjects(numProjects)
fmt.Println("projects", len(projects))
for _, project := range projects {
vectors = append(vectors, project.TitleV)
addProject(t, hc, project)
}

for _, project := range projects {
p := queryProjectUsingTitle(t, hc, project.Title)
fmt.Println("p", p)
require.Equal(t, project.Title, p.Title)
require.Equal(t, project.TitleV, p.TitleV)
}

for _, project := range projects {
p := queryProjectUsingTitle(t, hc, project.Title)
fmt.Println("p1", p)

require.Equal(t, project.Title, p.Title)
require.Equal(t, project.TitleV, p.TitleV)
}

// query similar project by embedding
for _, project := range projects {
similarProjects := queryProjectsSimilarByEmbedding(t, hc, project.TitleV)

similarProjects := queryProjectsSimilarByEmbedding(t, hc, project.TitleV, numProjects)
for _, similarVec := range similarProjects {
require.Contains(t, vectors, similarVec.TitleV)
}
}
}

func TestVectorSchema(t *testing.T) {
require.NoError(t, client.DropAll())

hc, err := dc.HTTPClient()
require.NoError(t, err)
hc.LoginIntoNamespace("groot", "password", 0)

schema := `type Project {
id: ID!
title: String! @search(by: [exact])
title_v: [Float!]
}`

// add schema
require.NoError(t, hc.UpdateGQLSchema(schema))
require.Error(t, hc.UpdateGQLSchema(graphQLVectorSchema))
require.NoError(t, client.DropAll())
require.NoError(t, hc.UpdateGQLSchema(graphQLVectorSchema))
}

0 comments on commit c430009

Please sign in to comment.