Skip to content

Commit

Permalink
Add suggestions
Browse files Browse the repository at this point in the history
  • Loading branch information
milderhc committed Aug 21, 2024
1 parent 07e1f5b commit 4d25116
Show file tree
Hide file tree
Showing 3 changed files with 67 additions and 77 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -87,13 +87,24 @@ protected String getWildcardString(int wildcards) {
.collect(Collectors.joining(", "));
}

/**
* Gets the key column name from a key field.
* @param keyField the key field
* @return the key column name
*/
protected String getKeyColumnName(VectorStoreRecordField keyField) {
return validateSQLidentifier(keyField.getEffectiveStorageName());
}

/**
* Formats the query columns from a record definition.
* @param fields the fields to get the columns from
* @return the formatted query columns
*/
protected String getQueryColumnsFromFields(List<VectorStoreRecordField> fields) {
return fields.stream().map(VectorStoreRecordField::getEffectiveStorageName)
return fields.stream()
.map(VectorStoreRecordField::getEffectiveStorageName)
.map(JDBCVectorStoreDefaultQueryProvider::validateSQLidentifier)
.collect(Collectors.joining(", "));
}

Expand All @@ -106,7 +117,8 @@ protected String getQueryColumnsFromFields(List<VectorStoreRecordField> fields)
protected String getColumnNamesAndTypes(List<VectorStoreRecordField> fields,
Map<Class<?>, String> types) {
List<String> columns = fields.stream()
.map(field -> field.getEffectiveStorageName() + " " + types.get(field.getFieldType()))
.map(field -> validateSQLidentifier(field.getEffectiveStorageName()) + " "
+ types.get(field.getFieldType()))
.collect(Collectors.toList());

return String.join(", ", columns);
Expand Down Expand Up @@ -221,9 +233,8 @@ public void createCollection(String collectionName,
VectorStoreRecordDefinition recordDefinition) {

String createStorageTable = "CREATE TABLE IF NOT EXISTS "
+ getCollectionTableName(collectionName)
+ " (" + recordDefinition.getKeyField().getEffectiveStorageName()
+ " VARCHAR(255) PRIMARY KEY, "
+ getCollectionTableName(collectionName) + " ("
+ getKeyColumnName(recordDefinition.getKeyField()) + " VARCHAR(255) PRIMARY KEY, "
+ getColumnNamesAndTypes(new ArrayList<>(recordDefinition.getDataFields()),
getSupportedDataTypes())
+ ", "
Expand Down Expand Up @@ -330,7 +341,7 @@ public <Record> List<Record> getRecords(String collectionName, List<String> keys

String query = "SELECT " + getQueryColumnsFromFields(fields)
+ " FROM " + getCollectionTableName(collectionName)
+ " WHERE " + recordDefinition.getKeyField().getEffectiveStorageName()
+ " WHERE " + getKeyColumnName(recordDefinition.getKeyField())
+ " IN (" + getWildcardString(keys.size()) + ")";

try (Connection connection = dataSource.getConnection();
Expand Down Expand Up @@ -372,7 +383,7 @@ public void upsertRecords(String collectionName, List<?> records,
public void deleteRecords(String collectionName, List<String> keys,
VectorStoreRecordDefinition recordDefinition, DeleteRecordOptions options) {
String query = "DELETE FROM " + getCollectionTableName(collectionName)
+ " WHERE " + recordDefinition.getKeyField().getEffectiveStorageName()
+ " WHERE " + getKeyColumnName(recordDefinition.getKeyField())
+ " IN (" + getWildcardString(keys.size()) + ")";

try (Connection connection = dataSource.getConnection();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import java.sql.PreparedStatement;
import java.sql.SQLException;
import java.util.List;
import java.util.stream.Collectors;

public class MySQLVectorStoreQueryProvider extends
JDBCVectorStoreDefaultQueryProvider implements JDBCVectorStoreQueryProvider {
Expand All @@ -28,10 +29,10 @@ public class MySQLVectorStoreQueryProvider extends

@SuppressFBWarnings("EI_EXPOSE_REP2")
private MySQLVectorStoreQueryProvider(
@Nonnull DataSource dataSource,
@Nonnull String collectionsTable,
@Nonnull String prefixForCollectionTables,
@Nonnull ObjectMapper objectMapper) {
@Nonnull DataSource dataSource,
@Nonnull String collectionsTable,
@Nonnull String prefixForCollectionTables,
@Nonnull ObjectMapper objectMapper) {
super(dataSource, collectionsTable, prefixForCollectionTables);
this.dataSource = dataSource;
this.objectMapper = objectMapper;
Expand All @@ -45,7 +46,7 @@ public static Builder builder() {
return new Builder();
}

private void setStatementValues(PreparedStatement statement, Object record,
private void setUpsertStatementValues(PreparedStatement statement, Object record,
List<VectorStoreRecordField> fields) {
JsonNode jsonNode = objectMapper.valueToTree(record);

Expand Down Expand Up @@ -82,20 +83,12 @@ private void setStatementValues(PreparedStatement statement, Object record,
@SuppressFBWarnings("SQL_PREPARED_STATEMENT_GENERATED_FROM_NONCONSTANT_STRING") // SQL query is generated dynamically with valid identifiers
public void upsertRecords(String collectionName, List<?> records,
VectorStoreRecordDefinition recordDefinition, UpsertRecordOptions options) {
validateSQLidentifier(getCollectionTableName(collectionName));

List<VectorStoreRecordField> fields = recordDefinition.getAllFields();

StringBuilder onDuplicateKeyUpdate = new StringBuilder();
for (int i = 0; i < fields.size(); ++i) {
VectorStoreRecordField field = fields.get(i);
if (i > 0) {
onDuplicateKeyUpdate.append(", ");
}

onDuplicateKeyUpdate.append(field.getEffectiveStorageName()).append(" = VALUES(")
.append(field.getEffectiveStorageName()).append(")");
}
String onDuplicateKeyUpdate = fields.stream()
.map(field -> validateSQLidentifier(field.getEffectiveStorageName())
+ " = VALUES(" + validateSQLidentifier(field.getEffectiveStorageName()) + ")")
.collect(Collectors.joining(", "));

String query = "INSERT INTO " + getCollectionTableName(collectionName)
+ " (" + getQueryColumnsFromFields(fields) + ")"
Expand All @@ -105,7 +98,7 @@ public void upsertRecords(String collectionName, List<?> records,
try (Connection connection = dataSource.getConnection();
PreparedStatement statement = connection.prepareStatement(query)) {
for (Object record : records) {
setStatementValues(statement, record, recordDefinition.getAllFields());
setUpsertStatementValues(statement, record, recordDefinition.getAllFields());
statement.addBatch();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

public class PostgreSQLVectorStoreQueryProvider extends
JDBCVectorStoreDefaultQueryProvider implements JDBCVectorStoreQueryProvider {
Expand All @@ -40,10 +41,10 @@ public class PostgreSQLVectorStoreQueryProvider extends

@SuppressFBWarnings("EI_EXPOSE_REP2")
private PostgreSQLVectorStoreQueryProvider(
@Nonnull DataSource dataSource,
@Nonnull String collectionsTable,
@Nonnull String prefixForCollectionTables,
@Nonnull ObjectMapper objectMapper) {
@Nonnull DataSource dataSource,
@Nonnull String collectionsTable,
@Nonnull String prefixForCollectionTables,
@Nonnull ObjectMapper objectMapper) {
super(dataSource, collectionsTable, prefixForCollectionTables);
this.dataSource = dataSource;
this.collectionsTable = collectionsTable;
Expand Down Expand Up @@ -134,24 +135,19 @@ public void prepareVectorStore() {

private String getColumnNamesAndTypesForVectorFields(
List<VectorStoreRecordVectorField> fields) {
StringBuilder columnNames = new StringBuilder();
for (VectorStoreRecordVectorField field : fields) {
if (columnNames.length() > 0) {
columnNames.append(", ");
}

if (field.getFieldType().equals(String.class)) {
columnNames.append(field.getEffectiveStorageName()).append(" ")
.append(supportedVectorTypes.get(String.class));
} else {
// Get the vector type and dimensions
String type = String.format(supportedVectorTypes.get(field.getFieldType()),
field.getDimensions());
columnNames.append(field.getEffectiveStorageName()).append(" ").append(type);
}
}

return columnNames.toString();
return fields.stream()
.map(field -> {
String columnType;
if (field.getFieldType().equals(String.class)) {
columnType = supportedVectorTypes.get(String.class);
} else {
// Get the vector type and dimensions
columnType = String.format(supportedVectorTypes.get(field.getFieldType()),
field.getDimensions());
}
return validateSQLidentifier(field.getEffectiveStorageName()) + " " + columnType;
})
.collect(Collectors.joining(", "));
}

/**
Expand All @@ -167,8 +163,8 @@ public void createCollection(String collectionName,
VectorStoreRecordDefinition recordDefinition) {

String createStorageTable = "CREATE TABLE IF NOT EXISTS "
+ getCollectionTableName(collectionName)
+ " (" + recordDefinition.getKeyField().getStorageName() + " VARCHAR(255) PRIMARY KEY, "
+ getCollectionTableName(collectionName) + " ("
+ getKeyColumnName(recordDefinition.getKeyField()) + " VARCHAR(255) PRIMARY KEY, "
+ getColumnNamesAndTypes(new ArrayList<>(recordDefinition.getDataFields()),
supportedDataTypes)
+ ", "
Expand All @@ -194,7 +190,7 @@ public void createCollection(String collectionName,
}
}

private void setStatementValues(PreparedStatement statement, Object record,
private void setUpsertStatementValues(PreparedStatement statement, Object record,
List<VectorStoreRecordField> fields) {
JsonNode jsonNode = objectMapper.valueToTree(record);

Expand All @@ -220,19 +216,16 @@ private void setStatementValues(PreparedStatement statement, Object record,
}

private String getWildcardStringWithCast(List<VectorStoreRecordField> fields) {
StringBuilder wildcardString = new StringBuilder();
int wildcards = fields.size();
for (int i = 0; i < wildcards; ++i) {
if (i > 0) {
wildcardString.append(", ");
}
wildcardString.append("?");
// Add casting for vector fields
if (fields.get(i) instanceof VectorStoreRecordVectorField) {
wildcardString.append("::vector");
}
}
return wildcardString.toString();
return fields.stream()
.map(field -> {
String wildcard = "?";
// Add casting for vector fields
if (field instanceof VectorStoreRecordVectorField) {
wildcard += "::vector";
}
return wildcard;
})
.collect(Collectors.joining(", "));
}

/**
Expand All @@ -250,30 +243,23 @@ public void upsertRecords(String collectionName, List<?> records,
validateSQLidentifier(getCollectionTableName(collectionName));
List<VectorStoreRecordField> fields = recordDefinition.getAllFields();

StringBuilder onDuplicateKeyUpdate = new StringBuilder();
for (VectorStoreRecordField field : fields) {
if (field instanceof VectorStoreRecordKeyField) {
continue;
}
if (onDuplicateKeyUpdate.length() > 0) {
onDuplicateKeyUpdate.append(", ");
}
onDuplicateKeyUpdate.append(field.getEffectiveStorageName())
.append(" = EXCLUDED.")
.append(field.getEffectiveStorageName());
}
String onDuplicateKeyUpdate = fields.stream()
.filter(field -> !(field instanceof VectorStoreRecordKeyField)) // Exclude key fields
.map(field -> validateSQLidentifier(field.getEffectiveStorageName())
+ " = EXCLUDED." + validateSQLidentifier(field.getEffectiveStorageName()))
.collect(Collectors.joining(", "));

String query = "INSERT INTO " + getCollectionTableName(collectionName)
+ " (" + getQueryColumnsFromFields(fields) + ")"
+ " VALUES (" + getWildcardStringWithCast(fields) + ")"
+ " ON CONFLICT (" + recordDefinition.getKeyField().getEffectiveStorageName()
+ " ON CONFLICT (" + getKeyColumnName(recordDefinition.getKeyField())
+ ") DO UPDATE SET "
+ onDuplicateKeyUpdate;

try (Connection connection = dataSource.getConnection();
PreparedStatement statement = connection.prepareStatement(query)) {
for (Object record : records) {
setStatementValues(statement, record, recordDefinition.getAllFields());
setUpsertStatementValues(statement, record, recordDefinition.getAllFields());
statement.addBatch();
}

Expand Down

0 comments on commit 4d25116

Please sign in to comment.