Skip to content

Commit

Permalink
Merge pull request #195 from milderhc/get-collection
Browse files Browse the repository at this point in the history
Add VectorStoreRecordCollectionOptions interface and make getCollecti…
  • Loading branch information
dsgrieve authored Aug 22, 2024
2 parents 9bf0833 + c213a08 commit 38c0b70
Show file tree
Hide file tree
Showing 18 changed files with 208 additions and 135 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import com.microsoft.semantickernel.connectors.data.jdbc.JDBCVectorStore;
import com.microsoft.semantickernel.connectors.data.jdbc.JDBCVectorStoreOptions;
import com.microsoft.semantickernel.connectors.data.jdbc.JDBCVectorStoreQueryProvider;
import com.microsoft.semantickernel.connectors.data.jdbc.JDBCVectorStoreRecordCollectionOptions;
import com.microsoft.semantickernel.connectors.data.mysql.MySQLVectorStoreQueryProvider;
import com.microsoft.semantickernel.connectors.data.postgres.PostgreSQLVectorStoreQueryProvider;
import com.microsoft.semantickernel.tests.connectors.memory.Hotel;
Expand Down Expand Up @@ -88,7 +89,10 @@ public void getCollectionNamesAsync(QueryProvider provider) {
List<String> collectionNames = Arrays.asList("collection1", "collection2", "collection3");

for (String collectionName : collectionNames) {
vectorStore.getCollection(collectionName, Hotel.class, null).createCollectionAsync().block();
vectorStore.getCollection(collectionName,
JDBCVectorStoreRecordCollectionOptions.<Hotel>builder()
.withRecordClass(Hotel.class)
.build()).createCollectionAsync().block();
}

List<String> retrievedCollectionNames = vectorStore.getCollectionNamesAsync().block();
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
package com.microsoft.semantickernel.tests.connectors.memory.redis;

import com.microsoft.semantickernel.connectors.data.redis.RedisHashSetVectorStoreRecordCollectionOptions;
import com.microsoft.semantickernel.connectors.data.redis.RedisJsonVectorStoreRecordCollectionOptions;
import com.microsoft.semantickernel.connectors.data.redis.RedisStorageType;
import com.microsoft.semantickernel.connectors.data.redis.RedisVectorStore;
import com.microsoft.semantickernel.connectors.data.redis.RedisVectorStoreOptions;
import com.microsoft.semantickernel.data.VectorStoreRecordCollectionOptions;
import com.microsoft.semantickernel.tests.connectors.memory.Hotel;
import com.redis.testcontainers.RedisContainer;
import org.junit.jupiter.api.BeforeAll;
Expand Down Expand Up @@ -36,6 +39,18 @@ public static JedisPooled buildClient(RedisStorageType storageType) {
}
}

private static VectorStoreRecordCollectionOptions<String, Hotel> getRecordCollectionOptions(RedisStorageType storageType) {
if (storageType == RedisStorageType.JSON) {
return RedisJsonVectorStoreRecordCollectionOptions.<Hotel>builder()
.withRecordClass(Hotel.class)
.build();
} else {
return RedisHashSetVectorStoreRecordCollectionOptions.<Hotel>builder()
.withRecordClass(Hotel.class)
.build();
}
}

@ParameterizedTest
@EnumSource(RedisStorageType.class)
public void getCollectionNamesAsync(RedisStorageType storageType) {
Expand All @@ -46,7 +61,7 @@ public void getCollectionNamesAsync(RedisStorageType storageType) {
List<String> collectionNames = Arrays.asList("collection1", "collection2", "collection3");

for (String collectionName : collectionNames) {
vectorStore.getCollection(collectionName, Hotel.class, null).createCollectionAsync().block();
vectorStore.getCollection(collectionName, getRecordCollectionOptions(storageType)).createCollectionAsync().block();
}

List<String> retrievedCollectionNames = vectorStore.getCollectionNamesAsync().block();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import com.microsoft.semantickernel.aiservices.openai.textembedding.OpenAITextEmbeddingGenerationService;
import com.microsoft.semantickernel.connectors.data.azureaisearch.AzureAISearchVectorStore;
import com.microsoft.semantickernel.connectors.data.azureaisearch.AzureAISearchVectorStoreOptions;
import com.microsoft.semantickernel.connectors.data.azureaisearch.AzureAISearchVectorStoreRecordCollectionOptions;
import com.microsoft.semantickernel.data.VectorStoreRecordCollection;
import com.microsoft.semantickernel.data.recordattributes.VectorStoreRecordDataAttribute;
import com.microsoft.semantickernel.data.recordattributes.VectorStoreRecordKeyAttribute;
Expand Down Expand Up @@ -125,8 +126,9 @@ public static void dataStorageWithAzureAISearch(
String collectionName = "skgithubfiles";
var collection = azureAISearchVectorStore.getCollection(
collectionName,
GitHubFile.class,
null);
AzureAISearchVectorStoreRecordCollectionOptions.<GitHubFile>builder()
.withRecordClass(GitHubFile.class)
.build());

// Create collection if it does not exist and store data
collection
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import com.microsoft.semantickernel.aiservices.openai.textembedding.OpenAITextEmbeddingGenerationService;
import com.microsoft.semantickernel.data.VectorStoreRecordCollection;
import com.microsoft.semantickernel.data.VolatileVectorStore;
import com.microsoft.semantickernel.data.VolatileVectorStoreRecordCollectionOptions;
import com.microsoft.semantickernel.data.recordattributes.VectorStoreRecordDataAttribute;
import com.microsoft.semantickernel.data.recordattributes.VectorStoreRecordKeyAttribute;
import com.microsoft.semantickernel.data.recordattributes.VectorStoreRecordVectorAttribute;
Expand Down Expand Up @@ -101,7 +102,10 @@ public static void inMemoryDataStorage(
var volatileVectorStore = new VolatileVectorStore();

String collectionName = "skgithubfiles";
var collection = volatileVectorStore.getCollection(collectionName, GitHubFile.class, null);
var collection = volatileVectorStore.getCollection(collectionName,
VolatileVectorStoreRecordCollectionOptions.<GitHubFile>builder()
.withRecordClass(GitHubFile.class)
.build());

// Create collection if it does not exist and store data
List<String> ids = collection
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import com.microsoft.semantickernel.aiservices.openai.textembedding.OpenAITextEmbeddingGenerationService;
import com.microsoft.semantickernel.connectors.data.jdbc.JDBCVectorStore;
import com.microsoft.semantickernel.connectors.data.jdbc.JDBCVectorStoreOptions;
import com.microsoft.semantickernel.connectors.data.jdbc.JDBCVectorStoreRecordCollectionOptions;
import com.microsoft.semantickernel.connectors.data.mysql.MySQLVectorStoreQueryProvider;
import com.microsoft.semantickernel.data.VectorStoreRecordCollection;
import com.microsoft.semantickernel.data.recordattributes.VectorStoreRecordDataAttribute;
Expand Down Expand Up @@ -132,8 +133,9 @@ public static void dataStorageWithMySQL(

String collectionName = "skgithubfiles";
var collection = jdbcVectorStore.getCollection(collectionName,
GitHubFile.class,
null);
JDBCVectorStoreRecordCollectionOptions.<GitHubFile>builder()
.withRecordClass(GitHubFile.class)
.build());

// Create collection if it does not exist and store data
List<String> ids = collection
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import com.azure.core.util.MetricsOptions;
import com.azure.core.util.TracingOptions;
import com.microsoft.semantickernel.aiservices.openai.textembedding.OpenAITextEmbeddingGenerationService;
import com.microsoft.semantickernel.connectors.data.redis.RedisJsonVectorStoreRecordCollectionOptions;
import com.microsoft.semantickernel.connectors.data.redis.RedisVectorStore;
import com.microsoft.semantickernel.connectors.data.redis.RedisVectorStoreOptions;
import com.microsoft.semantickernel.data.VectorStoreRecordCollection;
Expand Down Expand Up @@ -120,7 +121,10 @@ public static void dataStorageWithRedis(
.build();

String collectionName = "skgithubfiles";
var collection = vectorStore.getCollection(collectionName, GitHubFile.class, null);
var collection = vectorStore.getCollection(collectionName,
RedisJsonVectorStoreRecordCollectionOptions.<GitHubFile>builder()
.withRecordClass(GitHubFile.class)
.build());

// Create collection if it does not exist and store data
List<String> ids = collection
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@
import com.azure.search.documents.indexes.models.SearchIndex;
import com.microsoft.semantickernel.data.VectorStore;
import com.microsoft.semantickernel.data.VectorStoreRecordCollection;
import com.microsoft.semantickernel.data.VectorStoreRecordCollectionOptions;
import com.microsoft.semantickernel.data.recorddefinition.VectorStoreRecordDefinition;
import com.microsoft.semantickernel.exceptions.SKException;
import edu.umd.cs.findbugs.annotations.SuppressFBWarnings;
import java.util.List;
import javax.annotation.Nonnull;
Expand Down Expand Up @@ -34,44 +36,34 @@ public AzureAISearchVectorStore(@Nonnull SearchIndexAsyncClient client,
* Gets a new instance of {@link AzureAISearchVectorStoreRecordCollection}
*
* @param collectionName The name of the collection.
* @param recordClass The class type of the record.
* @param recordDefinition The record definition.
* @param options The options for the collection.
* @return The collection.
*/
@Override
public final <Key, Record> VectorStoreRecordCollection<Key, Record> getCollection(
@Nonnull String collectionName,
@Nonnull Class<Key> keyClass,
@Nonnull Class<Record> recordClass,
@Nullable VectorStoreRecordDefinition recordDefinition) {
if (!keyClass.equals(String.class)) {
throw new IllegalArgumentException("Azure AI Search only supports string keys");
@Nonnull VectorStoreRecordCollectionOptions<Key, Record> options) {
if (!options.getKeyClass().equals(String.class)) {
throw new SKException("Azure AI Search only supports string keys");
}
if (options.getRecordClass() == null) {
throw new SKException("Record class is required");
}

return (VectorStoreRecordCollection<Key, Record>) getCollection(
collectionName, recordClass, recordDefinition);
}

public <Record> AzureAISearchVectorStoreRecordCollection<Record> getCollection(
@Nonnull String collectionName,
@Nonnull Class<Record> recordClass,
@Nullable VectorStoreRecordDefinition recordDefinition) {
if (options.getVectorStoreRecordCollectionFactory() != null) {
return options.getVectorStoreRecordCollectionFactory()
if (this.options.getVectorStoreRecordCollectionFactory() != null) {
return (VectorStoreRecordCollection<Key, Record>) this.options
.getVectorStoreRecordCollectionFactory()
.createVectorStoreRecordCollection(
client,
collectionName,
recordClass,
recordDefinition);
options.getRecordClass(),
options.getRecordDefinition());
}

return new AzureAISearchVectorStoreRecordCollection<>(
return (VectorStoreRecordCollection<Key, Record>) new AzureAISearchVectorStoreRecordCollection<>(
client,
collectionName,
AzureAISearchVectorStoreRecordCollectionOptions.<Record>builder()
.withRecordClass(recordClass)
.withRecordDefinition(recordDefinition)
.build());
(AzureAISearchVectorStoreRecordCollectionOptions<Record>) options);
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
package com.microsoft.semantickernel.connectors.data.azureaisearch;

import com.azure.search.documents.SearchDocument;
import com.microsoft.semantickernel.data.VectorStoreRecordCollectionOptions;
import com.microsoft.semantickernel.data.VectorStoreRecordMapper;
import com.microsoft.semantickernel.data.recorddefinition.VectorStoreRecordDefinition;
import javax.annotation.Nonnull;
Expand All @@ -12,12 +13,11 @@
*
* @param <Record> the record type
*/
public class AzureAISearchVectorStoreRecordCollectionOptions<Record> {

public class AzureAISearchVectorStoreRecordCollectionOptions<Record>
implements VectorStoreRecordCollectionOptions<String, Record> {
private final Class<Record> recordClass;
@Nullable
private final VectorStoreRecordMapper<Record, SearchDocument> vectorStoreRecordMapper;

@Nullable
private final VectorStoreRecordDefinition recordDefinition;

Expand All @@ -31,6 +31,16 @@ public static <Record> Builder<Record> builder() {
return new Builder<>();
}

/**
* Gets the key class.
*
* @return the key class
*/
@Override
public Class<String> getKeyClass() {
return String.class;
}

/**
* Gets the record class.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@
package com.microsoft.semantickernel.connectors.data.jdbc;

import com.microsoft.semantickernel.data.VectorStoreRecordCollection;
import com.microsoft.semantickernel.data.VectorStoreRecordCollectionOptions;
import com.microsoft.semantickernel.data.recorddefinition.VectorStoreRecordDefinition;
import com.microsoft.semantickernel.exceptions.SKException;
import edu.umd.cs.findbugs.annotations.SuppressFBWarnings;
import reactor.core.publisher.Mono;
import reactor.core.scheduler.Schedulers;
Expand Down Expand Up @@ -55,53 +57,42 @@ public static Builder builder() {
* Gets a collection from the vector store.
*
* @param collectionName The name of the collection.
* @param recordClass The class type of the record.
* @param recordDefinition The record definition.
* @param options The options for the collection.
* @return The collection.
*/
@Override
public <Key, Record> VectorStoreRecordCollection<Key, Record> getCollection(
@Nonnull String collectionName, @Nonnull Class<Key> keyClass,
@Nonnull Class<Record> recordClass,
@Nullable VectorStoreRecordDefinition recordDefinition) {
if (keyClass != String.class) {
throw new IllegalArgumentException("Redis only supports string keys");
@Nonnull String collectionName,
@Nonnull VectorStoreRecordCollectionOptions<Key, Record> options) {
if (!options.getKeyClass().equals(String.class)) {
throw new SKException("JDBC only supports string keys");
}
if (options.getRecordClass() == null) {
throw new SKException("Record class is required");
}

return (VectorStoreRecordCollection<Key, Record>) getCollection(
collectionName,
recordClass,
recordDefinition);
}

/**
* Gets a collection from the vector store.
*
* @param collectionName The name of the collection.
* @param recordClass The class type of the record.
* @param recordDefinition The record definition.
* @return The collection.
*/
public <Record> JDBCVectorStoreRecordCollection<Record> getCollection(
@Nonnull String collectionName,
@Nonnull Class<Record> recordClass,
@Nullable VectorStoreRecordDefinition recordDefinition) {
if (this.options != null && this.options.getVectorStoreRecordCollectionFactory() != null) {
return this.options.getVectorStoreRecordCollectionFactory()
return (VectorStoreRecordCollection<Key, Record>) this.options
.getVectorStoreRecordCollectionFactory()
.createVectorStoreRecordCollection(
dataSource,
collectionName,
recordClass,
recordDefinition);
options.getRecordClass(),
options.getRecordDefinition());
}

return new JDBCVectorStoreRecordCollection<>(
JDBCVectorStoreRecordCollectionOptions<Record> jdbcOptions = (JDBCVectorStoreRecordCollectionOptions<Record>) options;
return (VectorStoreRecordCollection<Key, Record>) new JDBCVectorStoreRecordCollection<>(
dataSource,
collectionName,
JDBCVectorStoreRecordCollectionOptions.<Record>builder()
.withRecordClass(recordClass)
.withRecordDefinition(recordDefinition)
.withQueryProvider(this.queryProvider)
.withCollectionsTableName(jdbcOptions.getCollectionsTableName())
.withPrefixForCollectionTables(jdbcOptions.getPrefixForCollectionTables())
.withQueryProvider(jdbcOptions.getQueryProvider() == null ? queryProvider
: jdbcOptions.getQueryProvider())
.withRecordClass(jdbcOptions.getRecordClass())
.withRecordDefinition(jdbcOptions.getRecordDefinition())
.withVectorStoreRecordMapper(jdbcOptions.getVectorStoreRecordMapper())
.build());
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// Copyright (c) Microsoft. All rights reserved.
package com.microsoft.semantickernel.connectors.data.jdbc;

import com.microsoft.semantickernel.data.VectorStoreRecordCollectionOptions;
import com.microsoft.semantickernel.data.VectorStoreRecordMapper;
import com.microsoft.semantickernel.data.recorddefinition.VectorStoreRecordDefinition;
import edu.umd.cs.findbugs.annotations.SuppressFBWarnings;
Expand All @@ -11,7 +12,8 @@
import static com.microsoft.semantickernel.connectors.data.jdbc.JDBCVectorStoreQueryProvider.DEFAULT_COLLECTIONS_TABLE;
import static com.microsoft.semantickernel.connectors.data.jdbc.JDBCVectorStoreQueryProvider.DEFAULT_PREFIX_FOR_COLLECTION_TABLES;

public class JDBCVectorStoreRecordCollectionOptions<Record> {
public class JDBCVectorStoreRecordCollectionOptions<Record>
implements VectorStoreRecordCollectionOptions<String, Record> {
private final Class<Record> recordClass;
private final VectorStoreRecordMapper<Record, ResultSet> vectorStoreRecordMapper;
private final VectorStoreRecordDefinition recordDefinition;
Expand Down Expand Up @@ -43,6 +45,16 @@ public static <Record> Builder<Record> builder() {
return new Builder<>();
}

/**
* Gets the key class.
*
* @return the key class
*/
@Override
public Class<String> getKeyClass() {
return String.class;
}

/**
* Gets the record class.
* @return the record class
Expand Down
Loading

0 comments on commit 38c0b70

Please sign in to comment.