Skip to content

Commit

Permalink
Merge pull request #193 from milderhc/storage-name
Browse files Browse the repository at this point in the history
Add storage name support for all connectors
  • Loading branch information
milderhc authored Aug 22, 2024
2 parents c746f1f + 4d25116 commit 9bf0833
Show file tree
Hide file tree
Showing 26 changed files with 417 additions and 378 deletions.
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
package com.microsoft.semantickernel.tests.connectors.memory;

import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.microsoft.semantickernel.data.recordattributes.VectorStoreRecordDataAttribute;
import com.microsoft.semantickernel.data.recordattributes.VectorStoreRecordKeyAttribute;
import com.microsoft.semantickernel.data.recordattributes.VectorStoreRecordVectorAttribute;
Expand All @@ -13,8 +15,10 @@ public class Hotel {
private final String name;
@VectorStoreRecordDataAttribute
private final int code;
@JsonProperty("summary")
@VectorStoreRecordDataAttribute(hasEmbedding = true, embeddingFieldName = "descriptionEmbedding")
private final String description;
@JsonProperty("summaryEmbedding")
@VectorStoreRecordVectorAttribute(dimensions = 3)
private final List<Float> descriptionEmbedding;
@VectorStoreRecordDataAttribute
Expand All @@ -24,7 +28,14 @@ public Hotel() {
this(null, null, 0, null, null, 0.0);
}

public Hotel(String id, String name, int code, String description, List<Float> descriptionEmbedding, double rating) {
@JsonCreator
public Hotel(
@JsonProperty("id") String id,
@JsonProperty("name") String name,
@JsonProperty("code") int code,
@JsonProperty("summary") String description,
@JsonProperty("summaryVector") List<Float> descriptionEmbedding,
@JsonProperty("rating") double rating) {
this.id = id;
this.name = name;
this.code = code;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
import org.testcontainers.junit.jupiter.Testcontainers;
import org.testcontainers.utility.DockerImageName;

import javax.annotation.Nonnull;
import javax.sql.DataSource;
import java.util.Arrays;
import java.util.List;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ static void setup() {
List<VectorStoreRecordField> fields = new ArrayList<>();
fields.add(VectorStoreRecordKeyField.builder()
.withName("id")
.withFieldType(String.class)
.build());
fields.add(VectorStoreRecordDataField.builder()
.withName("name")
Expand All @@ -64,12 +65,15 @@ static void setup() {
.build());
fields.add(VectorStoreRecordDataField.builder()
.withName("description")
.withStorageName("summary")
.withFieldType(String.class)
.withHasEmbedding(true)
.withEmbeddingFieldName("descriptionEmbedding")
.build());
fields.add(VectorStoreRecordVectorField.builder()
.withName("descriptionEmbedding")
.withStorageName("summaryEmbedding")
.withFieldType(List.class)
.withDimensions(768)
.build());
fields.add(VectorStoreRecordDataField.builder()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ static void setup() {
List<VectorStoreRecordField> fields = new ArrayList<>();
fields.add(VectorStoreRecordKeyField.builder()
.withName("id")
.withFieldType(String.class)
.build());
fields.add(VectorStoreRecordDataField.builder()
.withName("name")
Expand All @@ -65,12 +66,15 @@ static void setup() {
.build());
fields.add(VectorStoreRecordDataField.builder()
.withName("description")
.withStorageName("summary")
.withFieldType(String.class)
.withHasEmbedding(true)
.withEmbeddingFieldName("descriptionEmbedding")
.build());
fields.add(VectorStoreRecordVectorField.builder()
.withName("descriptionEmbedding")
.withStorageName("summaryEmbedding")
.withFieldType(List.class)
.withDimensions(768)
.build());
fields.add(VectorStoreRecordDataField.builder()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
import com.azure.core.util.TracingOptions;
import com.azure.search.documents.indexes.SearchIndexAsyncClient;
import com.azure.search.documents.indexes.SearchIndexClientBuilder;
import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.microsoft.semantickernel.aiservices.openai.textembedding.OpenAITextEmbeddingGenerationService;
import com.microsoft.semantickernel.connectors.data.azureaisearch.AzureAISearchVectorStore;
import com.microsoft.semantickernel.connectors.data.azureaisearch.AzureAISearchVectorStoreOptions;
Expand Down Expand Up @@ -45,7 +47,7 @@ public class AzureAISearch_DataStorage {
private static final int EMBEDDING_DIMENSIONS = 1536;

static class GitHubFile {

@JsonProperty("fileId") // Set a different name for the storage field if needed
@VectorStoreRecordKeyAttribute()
private final String id;
@VectorStoreRecordDataAttribute(hasEmbedding = true, embeddingFieldName = "embedding")
Expand All @@ -60,10 +62,10 @@ public GitHubFile() {
}

public GitHubFile(
String id,
String description,
String link,
List<Float> embedding) {
@JsonProperty("fileId") String id,
@JsonProperty("description") String description,
@JsonProperty("link") String link,
@JsonProperty("embedding") List<Float> embedding) {
this.id = id;
this.description = description;
this.link = link;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,11 @@
public class AzureAISearchVectorStoreCollectionCreateMapping {

private static String getVectorSearchProfileName(VectorStoreRecordVectorField vectorField) {
return vectorField.getName() + "Profile";
return vectorField.getEffectiveStorageName() + "Profile";
}

private static String getAlgorithmConfigName(VectorStoreRecordVectorField vectorField) {
return vectorField.getName() + "AlgorithmConfig";
return vectorField.getEffectiveStorageName() + "AlgorithmConfig";
}

private static VectorSearchAlgorithmMetric getAlgorithmMetric(
Expand Down Expand Up @@ -68,24 +68,24 @@ private static VectorSearchAlgorithmConfiguration getAlgorithmConfig(
}

public static SearchField mapKeyField(VectorStoreRecordKeyField keyField) {
return new SearchField(keyField.getName(), SearchFieldDataType.STRING)
return new SearchField(keyField.getEffectiveStorageName(), SearchFieldDataType.STRING)
.setKey(true)
.setFilterable(true);
}

public static SearchField mapDataField(VectorStoreRecordDataField dataField) {
if (dataField.getFieldType() == null) {
throw new IllegalArgumentException(
"Field type is required: " + dataField.getName());
"Field type is required: " + dataField.getEffectiveStorageName());
}

return new SearchField(dataField.getName(),
return new SearchField(dataField.getEffectiveStorageName(),
getSearchFieldDataType(dataField.getFieldType()))
.setFilterable(dataField.isFilterable());
}

public static SearchField mapVectorField(VectorStoreRecordVectorField vectorField) {
return new SearchField(vectorField.getName(),
return new SearchField(vectorField.getEffectiveStorageName(),
SearchFieldDataType.collection(SearchFieldDataType.SINGLE))
.setSearchable(true)
.setVectorSearchDimensions(vectorField.getDimensions())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -90,20 +90,19 @@ public AzureAISearchVectorStoreRecordCollection(

// Validate supported types
VectorStoreRecordDefinition.validateSupportedTypes(
Collections
.singletonList(recordDefinition.getKeyDeclaredField(this.options.getRecordClass())),
Collections.singletonList(recordDefinition.getKeyField()),
supportedKeyTypes);
VectorStoreRecordDefinition.validateSupportedTypes(
recordDefinition.getDataDeclaredFields(this.options.getRecordClass()),
new ArrayList<>(recordDefinition.getDataFields()),
supportedDataTypes);
VectorStoreRecordDefinition.validateSupportedTypes(
recordDefinition.getVectorDeclaredFields(this.options.getRecordClass()),
new ArrayList<>(recordDefinition.getVectorFields()),
supportedVectorTypes);

// Add non-vector fields to the list
nonVectorFields.add(this.recordDefinition.getKeyField().getName());
nonVectorFields.add(this.recordDefinition.getKeyField().getEffectiveStorageName());
nonVectorFields.addAll(this.recordDefinition.getDataFields().stream()
.map(VectorStoreRecordDataField::getName)
.map(VectorStoreRecordDataField::getEffectiveStorageName)
.collect(Collectors.toList()));
}

Expand Down Expand Up @@ -256,7 +255,7 @@ public Mono<Void> deleteBatchAsync(List<String> keys, DeleteRecordOptions option

return client.deleteDocuments(keys.stream().map(key -> {
SearchDocument document = new SearchDocument();
document.put(this.recordDefinition.getKeyField().getName(), key);
document.put(this.recordDefinition.getKeyField().getEffectiveStorageName(), key);
return document;
}).collect(Collectors.toList())).then();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@

import javax.annotation.Nonnull;
import javax.sql.DataSource;
import java.lang.reflect.Field;
import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
Expand All @@ -30,9 +29,9 @@
public class JDBCVectorStoreDefaultQueryProvider
implements JDBCVectorStoreQueryProvider {

private Map<Class<?>, String> supportedKeyTypes;
private Map<Class<?>, String> supportedDataTypes;
private Map<Class<?>, String> supportedVectorTypes;
private final Map<Class<?>, String> supportedKeyTypes;
private final Map<Class<?>, String> supportedDataTypes;
private final Map<Class<?>, String> supportedVectorTypes;
private final DataSource dataSource;
private final String collectionsTable;
private final String prefixForCollectionTables;
Expand Down Expand Up @@ -88,13 +87,24 @@ protected String getWildcardString(int wildcards) {
.collect(Collectors.joining(", "));
}

/**
* Gets the key column name from a key field.
* @param keyField the key field
* @return the key column name
*/
protected String getKeyColumnName(VectorStoreRecordField keyField) {
return validateSQLidentifier(keyField.getEffectiveStorageName());
}

/**
* Formats the query columns from a record definition.
* @param fields the fields to get the columns from
* @return the formatted query columns
*/
protected String getQueryColumnsFromFields(List<VectorStoreRecordField> fields) {
return fields.stream().map(VectorStoreRecordField::getName)
return fields.stream()
.map(VectorStoreRecordField::getEffectiveStorageName)
.map(JDBCVectorStoreDefaultQueryProvider::validateSQLidentifier)
.collect(Collectors.joining(", "));
}

Expand All @@ -104,9 +114,11 @@ protected String getQueryColumnsFromFields(List<VectorStoreRecordField> fields)
* @param types the types
* @return the formatted column names and types
*/
protected String getColumnNamesAndTypes(List<Field> fields, Map<Class<?>, String> types) {
protected String getColumnNamesAndTypes(List<VectorStoreRecordField> fields,
Map<Class<?>, String> types) {
List<String> columns = fields.stream()
.map(field -> field.getName() + " " + types.get(field.getType()))
.map(field -> validateSQLidentifier(field.getEffectiveStorageName()) + " "
+ types.get(field.getFieldType()))
.collect(Collectors.toList());

return String.join(", ", columns);
Expand Down Expand Up @@ -169,20 +181,20 @@ public void prepareVectorStore() {
/**
* Checks if the types of the record class fields are supported.
*
* @param recordClass the record class
* @param recordDefinition the record definition
* @throws IllegalArgumentException if the types are not supported
*/
@Override
public void validateSupportedTypes(Class<?> recordClass,
VectorStoreRecordDefinition recordDefinition) {
public void validateSupportedTypes(VectorStoreRecordDefinition recordDefinition) {

VectorStoreRecordDefinition.validateSupportedTypes(
Collections.singletonList(recordDefinition.getKeyDeclaredField(recordClass)),
Collections.singletonList(recordDefinition.getKeyField()),
getSupportedKeyTypes().keySet());
VectorStoreRecordDefinition.validateSupportedTypes(
recordDefinition.getDataDeclaredFields(recordClass), getSupportedDataTypes().keySet());
new ArrayList<>(recordDefinition.getDataFields()),
getSupportedDataTypes().keySet());
VectorStoreRecordDefinition.validateSupportedTypes(
recordDefinition.getVectorDeclaredFields(recordClass),
new ArrayList<>(recordDefinition.getVectorFields()),
getSupportedVectorTypes().keySet());
}

Expand Down Expand Up @@ -212,23 +224,23 @@ public boolean collectionExists(String collectionName) {
* Creates a collection.
*
* @param collectionName the collection name
* @param recordClass the record class
* @param recordDefinition the record definition
* @throws SKException if an error occurs while creating the collection
*/
@Override
@SuppressFBWarnings("SQL_PREPARED_STATEMENT_GENERATED_FROM_NONCONSTANT_STRING") // SQL query is generated dynamically with valid identifiers
public void createCollection(String collectionName, Class<?> recordClass,
public void createCollection(String collectionName,
VectorStoreRecordDefinition recordDefinition) {
Field keyDeclaredField = recordDefinition.getKeyDeclaredField(recordClass);
List<Field> dataDeclaredFields = recordDefinition.getDataDeclaredFields(recordClass);
List<Field> vectorDeclaredFields = recordDefinition.getVectorDeclaredFields(recordClass);

String createStorageTable = "CREATE TABLE IF NOT EXISTS "
+ getCollectionTableName(collectionName)
+ " (" + keyDeclaredField.getName() + " VARCHAR(255) PRIMARY KEY, "
+ getColumnNamesAndTypes(dataDeclaredFields, getSupportedDataTypes()) + ", "
+ getColumnNamesAndTypes(vectorDeclaredFields, getSupportedVectorTypes()) + ");";
+ getCollectionTableName(collectionName) + " ("
+ getKeyColumnName(recordDefinition.getKeyField()) + " VARCHAR(255) PRIMARY KEY, "
+ getColumnNamesAndTypes(new ArrayList<>(recordDefinition.getDataFields()),
getSupportedDataTypes())
+ ", "
+ getColumnNamesAndTypes(new ArrayList<>(recordDefinition.getVectorFields()),
getSupportedVectorTypes())
+ ");";

String insertCollectionQuery = "INSERT INTO " + validateSQLidentifier(collectionsTable)
+ " (collectionId) VALUES (?)";
Expand Down Expand Up @@ -329,7 +341,7 @@ public <Record> List<Record> getRecords(String collectionName, List<String> keys

String query = "SELECT " + getQueryColumnsFromFields(fields)
+ " FROM " + getCollectionTableName(collectionName)
+ " WHERE " + recordDefinition.getKeyField().getName()
+ " WHERE " + getKeyColumnName(recordDefinition.getKeyField())
+ " IN (" + getWildcardString(keys.size()) + ")";

try (Connection connection = dataSource.getConnection();
Expand Down Expand Up @@ -371,7 +383,7 @@ public void upsertRecords(String collectionName, List<?> records,
public void deleteRecords(String collectionName, List<String> keys,
VectorStoreRecordDefinition recordDefinition, DeleteRecordOptions options) {
String query = "DELETE FROM " + getCollectionTableName(collectionName)
+ " WHERE " + recordDefinition.getKeyField().getName()
+ " WHERE " + getKeyColumnName(recordDefinition.getKeyField())
+ " IN (" + getWildcardString(keys.size()) + ")";

try (Connection connection = dataSource.getConnection();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,10 +57,9 @@ public interface JDBCVectorStoreQueryProvider {
/**
* Checks if the types of the record class fields are supported.
*
* @param recordClass the record class
* @param recordDefinition the record definition
*/
void validateSupportedTypes(Class<?> recordClass, VectorStoreRecordDefinition recordDefinition);
void validateSupportedTypes(VectorStoreRecordDefinition recordDefinition);

/**
* Checks if a collection exists.
Expand All @@ -74,11 +73,9 @@ public interface JDBCVectorStoreQueryProvider {
* Creates a collection.
*
* @param collectionName the collection name
* @param recordClass the record class
* @param recordDefinition the record definition
*/
void createCollection(String collectionName, Class<?> recordClass,
VectorStoreRecordDefinition recordDefinition);
void createCollection(String collectionName, VectorStoreRecordDefinition recordDefinition);

/**
* Deletes a collection.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ public JDBCVectorStoreRecordCollection(
}

// Check if the types are supported
queryProvider.validateSupportedTypes(options.getRecordClass(), recordDefinition);
queryProvider.validateSupportedTypes(recordDefinition);
}

/**
Expand Down Expand Up @@ -122,8 +122,7 @@ public Mono<Boolean> collectionExistsAsync() {
@Override
public Mono<VectorStoreRecordCollection<String, Record>> createCollectionAsync() {
return Mono.fromRunnable(
() -> queryProvider.createCollection(this.collectionName, options.getRecordClass(),
recordDefinition))
() -> queryProvider.createCollection(this.collectionName, recordDefinition))
.subscribeOn(Schedulers.boundedElastic())
.then(Mono.just(this));
}
Expand Down Expand Up @@ -200,7 +199,7 @@ public Mono<List<Record>> getBatchAsync(List<String> keys, GetRecordOptions opti
protected String getKeyFromRecord(Record data) {
try {
Field keyField = data.getClass()
.getDeclaredField(recordDefinition.getKeyField().getName());
.getDeclaredField(recordDefinition.getKeyField().getEffectiveStorageName());
keyField.setAccessible(true);
return (String) keyField.get(data);
} catch (NoSuchFieldException | IllegalAccessException e) {
Expand Down
Loading

0 comments on commit 9bf0833

Please sign in to comment.