Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add storage name support for all connectors #193

Merged
merged 5 commits into from
Aug 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
package com.microsoft.semantickernel.tests.connectors.memory;

import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.microsoft.semantickernel.data.recordattributes.VectorStoreRecordDataAttribute;
import com.microsoft.semantickernel.data.recordattributes.VectorStoreRecordKeyAttribute;
import com.microsoft.semantickernel.data.recordattributes.VectorStoreRecordVectorAttribute;
Expand All @@ -13,8 +15,10 @@ public class Hotel {
private final String name;
@VectorStoreRecordDataAttribute
private final int code;
@JsonProperty("summary")
@VectorStoreRecordDataAttribute(hasEmbedding = true, embeddingFieldName = "descriptionEmbedding")
private final String description;
@JsonProperty("summaryEmbedding")
@VectorStoreRecordVectorAttribute(dimensions = 3)
private final List<Float> descriptionEmbedding;
@VectorStoreRecordDataAttribute
Expand All @@ -24,7 +28,14 @@ public Hotel() {
this(null, null, 0, null, null, 0.0);
}

public Hotel(String id, String name, int code, String description, List<Float> descriptionEmbedding, double rating) {
@JsonCreator
public Hotel(
@JsonProperty("id") String id,
@JsonProperty("name") String name,
@JsonProperty("code") int code,
@JsonProperty("summary") String description,
@JsonProperty("summaryVector") List<Float> descriptionEmbedding,
@JsonProperty("rating") double rating) {
this.id = id;
this.name = name;
this.code = code;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
import org.testcontainers.junit.jupiter.Testcontainers;
import org.testcontainers.utility.DockerImageName;

import javax.annotation.Nonnull;
import javax.sql.DataSource;
import java.util.Arrays;
import java.util.List;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ static void setup() {
List<VectorStoreRecordField> fields = new ArrayList<>();
fields.add(VectorStoreRecordKeyField.builder()
.withName("id")
.withFieldType(String.class)
.build());
fields.add(VectorStoreRecordDataField.builder()
.withName("name")
Expand All @@ -64,12 +65,15 @@ static void setup() {
.build());
fields.add(VectorStoreRecordDataField.builder()
.withName("description")
.withStorageName("summary")
.withFieldType(String.class)
.withHasEmbedding(true)
.withEmbeddingFieldName("descriptionEmbedding")
.build());
fields.add(VectorStoreRecordVectorField.builder()
.withName("descriptionEmbedding")
.withStorageName("summaryEmbedding")
.withFieldType(List.class)
.withDimensions(768)
.build());
fields.add(VectorStoreRecordDataField.builder()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ static void setup() {
List<VectorStoreRecordField> fields = new ArrayList<>();
fields.add(VectorStoreRecordKeyField.builder()
.withName("id")
.withFieldType(String.class)
.build());
fields.add(VectorStoreRecordDataField.builder()
.withName("name")
Expand All @@ -65,12 +66,15 @@ static void setup() {
.build());
fields.add(VectorStoreRecordDataField.builder()
.withName("description")
.withStorageName("summary")
.withFieldType(String.class)
.withHasEmbedding(true)
.withEmbeddingFieldName("descriptionEmbedding")
.build());
fields.add(VectorStoreRecordVectorField.builder()
.withName("descriptionEmbedding")
.withStorageName("summaryEmbedding")
.withFieldType(List.class)
.withDimensions(768)
.build());
fields.add(VectorStoreRecordDataField.builder()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
import com.azure.core.util.TracingOptions;
import com.azure.search.documents.indexes.SearchIndexAsyncClient;
import com.azure.search.documents.indexes.SearchIndexClientBuilder;
import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.microsoft.semantickernel.aiservices.openai.textembedding.OpenAITextEmbeddingGenerationService;
import com.microsoft.semantickernel.connectors.data.azureaisearch.AzureAISearchVectorStore;
import com.microsoft.semantickernel.connectors.data.azureaisearch.AzureAISearchVectorStoreOptions;
Expand Down Expand Up @@ -45,7 +47,7 @@ public class AzureAISearch_DataStorage {
private static final int EMBEDDING_DIMENSIONS = 1536;

static class GitHubFile {

@JsonProperty("fileId") // Set a different name for the storage field if needed
@VectorStoreRecordKeyAttribute()
private final String id;
@VectorStoreRecordDataAttribute(hasEmbedding = true, embeddingFieldName = "embedding")
Expand All @@ -60,10 +62,10 @@ public GitHubFile() {
}

public GitHubFile(
String id,
String description,
String link,
List<Float> embedding) {
@JsonProperty("fileId") String id,
@JsonProperty("description") String description,
@JsonProperty("link") String link,
@JsonProperty("embedding") List<Float> embedding) {
this.id = id;
this.description = description;
this.link = link;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,11 @@
public class AzureAISearchVectorStoreCollectionCreateMapping {

/**
 * Builds the name of the vector search profile for a vector field.
 *
 * @param vectorField the vector field the profile belongs to
 * @return the field's effective storage name followed by {@code "Profile"}
 */
private static String getVectorSearchProfileName(VectorStoreRecordVectorField vectorField) {
    // The profile is named after the storage name (not the record field name),
    // matching the name given to the index field in mapVectorField.
    return vectorField.getEffectiveStorageName() + "Profile";
}

/**
 * Builds the name of the vector search algorithm configuration for a vector field.
 *
 * @param vectorField the vector field the configuration belongs to
 * @return the field's effective storage name followed by {@code "AlgorithmConfig"}
 */
private static String getAlgorithmConfigName(VectorStoreRecordVectorField vectorField) {
    // Named after the storage name (not the record field name), matching the
    // name given to the index field in mapVectorField.
    return vectorField.getEffectiveStorageName() + "AlgorithmConfig";
}

private static VectorSearchAlgorithmMetric getAlgorithmMetric(
Expand Down Expand Up @@ -68,24 +68,24 @@ private static VectorSearchAlgorithmConfiguration getAlgorithmConfig(
}

public static SearchField mapKeyField(VectorStoreRecordKeyField keyField) {
return new SearchField(keyField.getName(), SearchFieldDataType.STRING)
return new SearchField(keyField.getEffectiveStorageName(), SearchFieldDataType.STRING)
.setKey(true)
.setFilterable(true);
}

public static SearchField mapDataField(VectorStoreRecordDataField dataField) {
if (dataField.getFieldType() == null) {
throw new IllegalArgumentException(
"Field type is required: " + dataField.getName());
"Field type is required: " + dataField.getEffectiveStorageName());
}

return new SearchField(dataField.getName(),
return new SearchField(dataField.getEffectiveStorageName(),
getSearchFieldDataType(dataField.getFieldType()))
.setFilterable(dataField.isFilterable());
}

public static SearchField mapVectorField(VectorStoreRecordVectorField vectorField) {
return new SearchField(vectorField.getName(),
return new SearchField(vectorField.getEffectiveStorageName(),
SearchFieldDataType.collection(SearchFieldDataType.SINGLE))
.setSearchable(true)
.setVectorSearchDimensions(vectorField.getDimensions())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -90,20 +90,19 @@ public AzureAISearchVectorStoreRecordCollection(

// Validate supported types
VectorStoreRecordDefinition.validateSupportedTypes(
Collections
.singletonList(recordDefinition.getKeyDeclaredField(this.options.getRecordClass())),
Collections.singletonList(recordDefinition.getKeyField()),
supportedKeyTypes);
VectorStoreRecordDefinition.validateSupportedTypes(
recordDefinition.getDataDeclaredFields(this.options.getRecordClass()),
new ArrayList<>(recordDefinition.getDataFields()),
supportedDataTypes);
VectorStoreRecordDefinition.validateSupportedTypes(
recordDefinition.getVectorDeclaredFields(this.options.getRecordClass()),
new ArrayList<>(recordDefinition.getVectorFields()),
supportedVectorTypes);

// Add non-vector fields to the list
nonVectorFields.add(this.recordDefinition.getKeyField().getName());
nonVectorFields.add(this.recordDefinition.getKeyField().getEffectiveStorageName());
nonVectorFields.addAll(this.recordDefinition.getDataFields().stream()
.map(VectorStoreRecordDataField::getName)
.map(VectorStoreRecordDataField::getEffectiveStorageName)
.collect(Collectors.toList()));
}

Expand Down Expand Up @@ -256,7 +255,7 @@ public Mono<Void> deleteBatchAsync(List<String> keys, DeleteRecordOptions option

return client.deleteDocuments(keys.stream().map(key -> {
SearchDocument document = new SearchDocument();
document.put(this.recordDefinition.getKeyField().getName(), key);
document.put(this.recordDefinition.getKeyField().getEffectiveStorageName(), key);
return document;
}).collect(Collectors.toList())).then();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@

import javax.annotation.Nonnull;
import javax.sql.DataSource;
import java.lang.reflect.Field;
import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
Expand All @@ -30,9 +29,9 @@
public class JDBCVectorStoreDefaultQueryProvider
implements JDBCVectorStoreQueryProvider {

private Map<Class<?>, String> supportedKeyTypes;
private Map<Class<?>, String> supportedDataTypes;
private Map<Class<?>, String> supportedVectorTypes;
private final Map<Class<?>, String> supportedKeyTypes;
private final Map<Class<?>, String> supportedDataTypes;
private final Map<Class<?>, String> supportedVectorTypes;
private final DataSource dataSource;
private final String collectionsTable;
private final String prefixForCollectionTables;
Expand Down Expand Up @@ -88,13 +87,24 @@ protected String getWildcardString(int wildcards) {
.collect(Collectors.joining(", "));
}

/**
 * Resolves the column name used for the key of a collection table.
 *
 * @param keyField the key field
 * @return the key field's effective storage name, passed through
 *         {@code validateSQLidentifier}
 */
protected String getKeyColumnName(VectorStoreRecordField keyField) {
    String storageName = keyField.getEffectiveStorageName();
    return validateSQLidentifier(storageName);
}

/**
* Formats the query columns from a record definition.
* @param fields the fields to get the columns from
* @return the formatted query columns
*/
protected String getQueryColumnsFromFields(List<VectorStoreRecordField> fields) {
return fields.stream().map(VectorStoreRecordField::getName)
return fields.stream()
.map(VectorStoreRecordField::getEffectiveStorageName)
.map(JDBCVectorStoreDefaultQueryProvider::validateSQLidentifier)
.collect(Collectors.joining(", "));
}

Expand All @@ -104,9 +114,11 @@ protected String getQueryColumnsFromFields(List<VectorStoreRecordField> fields)
* @param types the types
* @return the formatted column names and types
*/
protected String getColumnNamesAndTypes(List<Field> fields, Map<Class<?>, String> types) {
protected String getColumnNamesAndTypes(List<VectorStoreRecordField> fields,
Map<Class<?>, String> types) {
List<String> columns = fields.stream()
.map(field -> field.getName() + " " + types.get(field.getType()))
.map(field -> validateSQLidentifier(field.getEffectiveStorageName()) + " "
+ types.get(field.getFieldType()))
.collect(Collectors.toList());

return String.join(", ", columns);
Expand Down Expand Up @@ -169,20 +181,20 @@ public void prepareVectorStore() {
/**
* Checks if the types of the record class fields are supported.
*
* @param recordClass the record class
* @param recordDefinition the record definition
* @throws IllegalArgumentException if the types are not supported
*/
@Override
public void validateSupportedTypes(Class<?> recordClass,
VectorStoreRecordDefinition recordDefinition) {
public void validateSupportedTypes(VectorStoreRecordDefinition recordDefinition) {

VectorStoreRecordDefinition.validateSupportedTypes(
Collections.singletonList(recordDefinition.getKeyDeclaredField(recordClass)),
Collections.singletonList(recordDefinition.getKeyField()),
getSupportedKeyTypes().keySet());
VectorStoreRecordDefinition.validateSupportedTypes(
recordDefinition.getDataDeclaredFields(recordClass), getSupportedDataTypes().keySet());
new ArrayList<>(recordDefinition.getDataFields()),
getSupportedDataTypes().keySet());
VectorStoreRecordDefinition.validateSupportedTypes(
recordDefinition.getVectorDeclaredFields(recordClass),
new ArrayList<>(recordDefinition.getVectorFields()),
getSupportedVectorTypes().keySet());
}

Expand Down Expand Up @@ -212,23 +224,23 @@ public boolean collectionExists(String collectionName) {
* Creates a collection.
*
* @param collectionName the collection name
* @param recordClass the record class
* @param recordDefinition the record definition
* @throws SKException if an error occurs while creating the collection
*/
@Override
@SuppressFBWarnings("SQL_PREPARED_STATEMENT_GENERATED_FROM_NONCONSTANT_STRING") // SQL query is generated dynamically with valid identifiers
public void createCollection(String collectionName, Class<?> recordClass,
public void createCollection(String collectionName,
VectorStoreRecordDefinition recordDefinition) {
Field keyDeclaredField = recordDefinition.getKeyDeclaredField(recordClass);
List<Field> dataDeclaredFields = recordDefinition.getDataDeclaredFields(recordClass);
List<Field> vectorDeclaredFields = recordDefinition.getVectorDeclaredFields(recordClass);

String createStorageTable = "CREATE TABLE IF NOT EXISTS "
+ getCollectionTableName(collectionName)
+ " (" + keyDeclaredField.getName() + " VARCHAR(255) PRIMARY KEY, "
+ getColumnNamesAndTypes(dataDeclaredFields, getSupportedDataTypes()) + ", "
+ getColumnNamesAndTypes(vectorDeclaredFields, getSupportedVectorTypes()) + ");";
+ getCollectionTableName(collectionName) + " ("
+ getKeyColumnName(recordDefinition.getKeyField()) + " VARCHAR(255) PRIMARY KEY, "
+ getColumnNamesAndTypes(new ArrayList<>(recordDefinition.getDataFields()),
getSupportedDataTypes())
+ ", "
+ getColumnNamesAndTypes(new ArrayList<>(recordDefinition.getVectorFields()),
getSupportedVectorTypes())
+ ");";

String insertCollectionQuery = "INSERT INTO " + validateSQLidentifier(collectionsTable)
+ " (collectionId) VALUES (?)";
Expand Down Expand Up @@ -329,7 +341,7 @@ public <Record> List<Record> getRecords(String collectionName, List<String> keys

String query = "SELECT " + getQueryColumnsFromFields(fields)
+ " FROM " + getCollectionTableName(collectionName)
+ " WHERE " + recordDefinition.getKeyField().getName()
+ " WHERE " + getKeyColumnName(recordDefinition.getKeyField())
+ " IN (" + getWildcardString(keys.size()) + ")";

try (Connection connection = dataSource.getConnection();
Expand Down Expand Up @@ -371,7 +383,7 @@ public void upsertRecords(String collectionName, List<?> records,
public void deleteRecords(String collectionName, List<String> keys,
VectorStoreRecordDefinition recordDefinition, DeleteRecordOptions options) {
String query = "DELETE FROM " + getCollectionTableName(collectionName)
+ " WHERE " + recordDefinition.getKeyField().getName()
+ " WHERE " + getKeyColumnName(recordDefinition.getKeyField())
+ " IN (" + getWildcardString(keys.size()) + ")";

try (Connection connection = dataSource.getConnection();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,10 +57,9 @@ public interface JDBCVectorStoreQueryProvider {
/**
 * Checks if the types of the record definition fields are supported.
 *
 * @param recordDefinition the record definition
 */
void validateSupportedTypes(VectorStoreRecordDefinition recordDefinition);

/**
* Checks if a collection exists.
Expand All @@ -74,11 +73,9 @@ public interface JDBCVectorStoreQueryProvider {
* Creates a collection.
*
* @param collectionName the collection name
* @param recordClass the record class
* @param recordDefinition the record definition
*/
void createCollection(String collectionName, Class<?> recordClass,
VectorStoreRecordDefinition recordDefinition);
void createCollection(String collectionName, VectorStoreRecordDefinition recordDefinition);

/**
* Deletes a collection.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ public JDBCVectorStoreRecordCollection(
}

// Check if the types are supported
queryProvider.validateSupportedTypes(options.getRecordClass(), recordDefinition);
queryProvider.validateSupportedTypes(recordDefinition);
}

/**
Expand Down Expand Up @@ -122,8 +122,7 @@ public Mono<Boolean> collectionExistsAsync() {
@Override
public Mono<VectorStoreRecordCollection<String, Record>> createCollectionAsync() {
return Mono.fromRunnable(
() -> queryProvider.createCollection(this.collectionName, options.getRecordClass(),
recordDefinition))
() -> queryProvider.createCollection(this.collectionName, recordDefinition))
.subscribeOn(Schedulers.boundedElastic())
.then(Mono.just(this));
}
Expand Down Expand Up @@ -200,7 +199,7 @@ public Mono<List<Record>> getBatchAsync(List<String> keys, GetRecordOptions opti
protected String getKeyFromRecord(Record data) {
try {
Field keyField = data.getClass()
.getDeclaredField(recordDefinition.getKeyField().getName());
.getDeclaredField(recordDefinition.getKeyField().getEffectiveStorageName());
keyField.setAccessible(true);
return (String) keyField.get(data);
} catch (NoSuchFieldException | IllegalAccessException e) {
Expand Down
Loading
Loading