diff --git a/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/JDBCVectorStoreDefaultQueryProvider.java b/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/JDBCVectorStoreDefaultQueryProvider.java index 455aed49..e21146eb 100644 --- a/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/JDBCVectorStoreDefaultQueryProvider.java +++ b/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/JDBCVectorStoreDefaultQueryProvider.java @@ -87,13 +87,24 @@ protected String getWildcardString(int wildcards) { .collect(Collectors.joining(", ")); } + /** + * Gets the key column name from a key field. + * @param keyField the key field + * @return the key column name + */ + protected String getKeyColumnName(VectorStoreRecordField keyField) { + return validateSQLidentifier(keyField.getEffectiveStorageName()); + } + /** * Formats the query columns from a record definition. * @param fields the fields to get the columns from * @return the formatted query columns */ protected String getQueryColumnsFromFields(List fields) { - return fields.stream().map(VectorStoreRecordField::getEffectiveStorageName) + return fields.stream() + .map(VectorStoreRecordField::getEffectiveStorageName) + .map(JDBCVectorStoreDefaultQueryProvider::validateSQLidentifier) .collect(Collectors.joining(", ")); } @@ -106,7 +117,8 @@ protected String getQueryColumnsFromFields(List fields) protected String getColumnNamesAndTypes(List fields, Map, String> types) { List columns = fields.stream() - .map(field -> field.getEffectiveStorageName() + " " + types.get(field.getFieldType())) + .map(field -> validateSQLidentifier(field.getEffectiveStorageName()) + " " + + types.get(field.getFieldType())) .collect(Collectors.toList()); return String.join(", ", columns); @@ -221,9 +233,8 @@ public void createCollection(String collectionName, VectorStoreRecordDefinition recordDefinition) { String createStorageTable = "CREATE TABLE IF NOT EXISTS " - + getCollectionTableName(collectionName) - + " (" + recordDefinition.getKeyField().getEffectiveStorageName() - + " VARCHAR(255) PRIMARY KEY, " + + getCollectionTableName(collectionName) + " (" + + getKeyColumnName(recordDefinition.getKeyField()) + " VARCHAR(255) PRIMARY KEY, " + getColumnNamesAndTypes(new ArrayList<>(recordDefinition.getDataFields()), getSupportedDataTypes()) + ", " @@ -330,7 +341,7 @@ public List getRecords(String collectionName, List keys String query = "SELECT " + getQueryColumnsFromFields(fields) + " FROM " + getCollectionTableName(collectionName) - + " WHERE " + recordDefinition.getKeyField().getEffectiveStorageName() + + " WHERE " + getKeyColumnName(recordDefinition.getKeyField()) + " IN (" + getWildcardString(keys.size()) + ")"; try (Connection connection = dataSource.getConnection(); @@ -372,7 +383,7 @@ public void upsertRecords(String collectionName, List records, public void deleteRecords(String collectionName, List keys, VectorStoreRecordDefinition recordDefinition, DeleteRecordOptions options) { String query = "DELETE FROM " + getCollectionTableName(collectionName) - + " WHERE " + recordDefinition.getKeyField().getEffectiveStorageName() + + " WHERE " + getKeyColumnName(recordDefinition.getKeyField()) + " IN (" + getWildcardString(keys.size()) + ")"; try (Connection connection = dataSource.getConnection(); diff --git a/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/mysql/MySQLVectorStoreQueryProvider.java b/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/mysql/MySQLVectorStoreQueryProvider.java index 6ea52c95..c6440bad 100644 --- a/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/mysql/MySQLVectorStoreQueryProvider.java +++ b/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/mysql/MySQLVectorStoreQueryProvider.java @@ -19,6 +19,7 @@ import java.sql.PreparedStatement; import java.sql.SQLException; import java.util.List; +import java.util.stream.Collectors; public class MySQLVectorStoreQueryProvider extends JDBCVectorStoreDefaultQueryProvider implements JDBCVectorStoreQueryProvider { @@ -28,10 +29,10 @@ public class MySQLVectorStoreQueryProvider extends @SuppressFBWarnings("EI_EXPOSE_REP2") private MySQLVectorStoreQueryProvider( - @Nonnull DataSource dataSource, - @Nonnull String collectionsTable, - @Nonnull String prefixForCollectionTables, - @Nonnull ObjectMapper objectMapper) { + @Nonnull DataSource dataSource, + @Nonnull String collectionsTable, + @Nonnull String prefixForCollectionTables, + @Nonnull ObjectMapper objectMapper) { super(dataSource, collectionsTable, prefixForCollectionTables); this.dataSource = dataSource; this.objectMapper = objectMapper; @@ -45,7 +46,7 @@ public static Builder builder() { return new Builder(); } - private void setStatementValues(PreparedStatement statement, Object record, + private void setUpsertStatementValues(PreparedStatement statement, Object record, List fields) { JsonNode jsonNode = objectMapper.valueToTree(record); @@ -82,20 +83,12 @@ private void setStatementValues(PreparedStatement statement, Object record, @SuppressFBWarnings("SQL_PREPARED_STATEMENT_GENERATED_FROM_NONCONSTANT_STRING") // SQL query is generated dynamically with valid identifiers public void upsertRecords(String collectionName, List records, VectorStoreRecordDefinition recordDefinition, UpsertRecordOptions options) { - validateSQLidentifier(getCollectionTableName(collectionName)); - List fields = recordDefinition.getAllFields(); - StringBuilder onDuplicateKeyUpdate = new StringBuilder(); - for (int i = 0; i < fields.size(); ++i) { - VectorStoreRecordField field = fields.get(i); - if (i > 0) { - onDuplicateKeyUpdate.append(", "); - } - - onDuplicateKeyUpdate.append(field.getEffectiveStorageName()).append(" = VALUES(") - .append(field.getEffectiveStorageName()).append(")"); - } + String onDuplicateKeyUpdate = fields.stream() + .map(field -> validateSQLidentifier(field.getEffectiveStorageName()) + + " = VALUES(" + validateSQLidentifier(field.getEffectiveStorageName()) + ")") + .collect(Collectors.joining(", ")); String query = "INSERT INTO " + getCollectionTableName(collectionName) + " (" + getQueryColumnsFromFields(fields) + ")" @@ -105,7 +98,7 @@ public void upsertRecords(String collectionName, List records, try (Connection connection = dataSource.getConnection(); PreparedStatement statement = connection.prepareStatement(query)) { for (Object record : records) { - setStatementValues(statement, record, recordDefinition.getAllFields()); + setUpsertStatementValues(statement, record, recordDefinition.getAllFields()); statement.addBatch(); } diff --git a/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/postgres/PostgreSQLVectorStoreQueryProvider.java b/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/postgres/PostgreSQLVectorStoreQueryProvider.java index 3d7a47a9..d061254d 100644 --- a/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/postgres/PostgreSQLVectorStoreQueryProvider.java +++ b/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/postgres/PostgreSQLVectorStoreQueryProvider.java @@ -25,6 +25,7 @@ import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.stream.Collectors; public class PostgreSQLVectorStoreQueryProvider extends JDBCVectorStoreDefaultQueryProvider implements JDBCVectorStoreQueryProvider { @@ -40,10 +41,10 @@ public class PostgreSQLVectorStoreQueryProvider extends @SuppressFBWarnings("EI_EXPOSE_REP2") private PostgreSQLVectorStoreQueryProvider( - @Nonnull DataSource dataSource, - @Nonnull String collectionsTable, - @Nonnull String prefixForCollectionTables, - @Nonnull ObjectMapper objectMapper) { + @Nonnull DataSource dataSource, + @Nonnull String collectionsTable, + @Nonnull String prefixForCollectionTables, + @Nonnull ObjectMapper objectMapper) { super(dataSource, collectionsTable, prefixForCollectionTables); this.dataSource = dataSource; this.collectionsTable = collectionsTable; @@ -134,24 +135,19 @@ public void prepareVectorStore() { private String getColumnNamesAndTypesForVectorFields( List fields) { - StringBuilder columnNames = new StringBuilder(); - for (VectorStoreRecordVectorField field : fields) { - if (columnNames.length() > 0) { - columnNames.append(", "); - } - - if (field.getFieldType().equals(String.class)) { - columnNames.append(field.getEffectiveStorageName()).append(" ") - .append(supportedVectorTypes.get(String.class)); - } else { - // Get the vector type and dimensions - String type = String.format(supportedVectorTypes.get(field.getFieldType()), - field.getDimensions()); - columnNames.append(field.getEffectiveStorageName()).append(" ").append(type); - } - } - - return columnNames.toString(); + return fields.stream() + .map(field -> { + String columnType; + if (field.getFieldType().equals(String.class)) { + columnType = supportedVectorTypes.get(String.class); + } else { + // Get the vector type and dimensions + columnType = String.format(supportedVectorTypes.get(field.getFieldType()), + field.getDimensions()); + } + return validateSQLidentifier(field.getEffectiveStorageName()) + " " + columnType; + }) + .collect(Collectors.joining(", ")); } /** @@ -167,8 +163,8 @@ public void createCollection(String collectionName, VectorStoreRecordDefinition recordDefinition) { String createStorageTable = "CREATE TABLE IF NOT EXISTS " - + getCollectionTableName(collectionName) - + " (" + recordDefinition.getKeyField().getStorageName() + " VARCHAR(255) PRIMARY KEY, " + + getCollectionTableName(collectionName) + " (" + + getKeyColumnName(recordDefinition.getKeyField()) + " VARCHAR(255) PRIMARY KEY, " + getColumnNamesAndTypes(new ArrayList<>(recordDefinition.getDataFields()), supportedDataTypes) + ", " @@ -194,7 +190,7 @@ public void createCollection(String collectionName, } } - private void setStatementValues(PreparedStatement statement, Object record, + private void setUpsertStatementValues(PreparedStatement statement, Object record, List fields) { JsonNode jsonNode = objectMapper.valueToTree(record); @@ -220,19 +216,16 @@ private void setStatementValues(PreparedStatement statement, Object record, } private String getWildcardStringWithCast(List fields) { - StringBuilder wildcardString = new StringBuilder(); - int wildcards = fields.size(); - for (int i = 0; i < wildcards; ++i) { - if (i > 0) { - wildcardString.append(", "); - } - wildcardString.append("?"); - // Add casting for vector fields - if (fields.get(i) instanceof VectorStoreRecordVectorField) { - wildcardString.append("::vector"); - } - } - return wildcardString.toString(); + return fields.stream() + .map(field -> { + String wildcard = "?"; + // Add casting for vector fields + if (field instanceof VectorStoreRecordVectorField) { + wildcard += "::vector"; + } + return wildcard; + }) + .collect(Collectors.joining(", ")); } /** @@ -250,30 +243,23 @@ public void upsertRecords(String collectionName, List records, validateSQLidentifier(getCollectionTableName(collectionName)); List fields = recordDefinition.getAllFields(); - StringBuilder onDuplicateKeyUpdate = new StringBuilder(); - for (VectorStoreRecordField field : fields) { - if (field instanceof VectorStoreRecordKeyField) { - continue; - } - if (onDuplicateKeyUpdate.length() > 0) { - onDuplicateKeyUpdate.append(", "); - } - onDuplicateKeyUpdate.append(field.getEffectiveStorageName()) - .append(" = EXCLUDED.") - .append(field.getEffectiveStorageName()); - } + String onDuplicateKeyUpdate = fields.stream() + .filter(field -> !(field instanceof VectorStoreRecordKeyField)) // Exclude key fields + .map(field -> validateSQLidentifier(field.getEffectiveStorageName()) + + " = EXCLUDED." + validateSQLidentifier(field.getEffectiveStorageName())) + .collect(Collectors.joining(", ")); String query = "INSERT INTO " + getCollectionTableName(collectionName) + " (" + getQueryColumnsFromFields(fields) + ")" + " VALUES (" + getWildcardStringWithCast(fields) + ")" - + " ON CONFLICT (" + recordDefinition.getKeyField().getEffectiveStorageName() + + " ON CONFLICT (" + getKeyColumnName(recordDefinition.getKeyField()) + ") DO UPDATE SET " + onDuplicateKeyUpdate; try (Connection connection = dataSource.getConnection(); PreparedStatement statement = connection.prepareStatement(query)) { for (Object record : records) { - setStatementValues(statement, record, recordDefinition.getAllFields()); + setUpsertStatementValues(statement, record, recordDefinition.getAllFields()); statement.addBatch(); }