Add index creation when creating collection.
lossyrob committed Sep 24, 2024
1 parent 3cbe91e commit 5189100
Showing 4 changed files with 106 additions and 8 deletions.
@@ -8,8 +8,10 @@
from semantic_kernel.connectors.memory.postgres.utils import (
convert_dict_to_row,
convert_row_to_dict,
get_vector_index_ops_str,
python_type_to_postgres,
)
from semantic_kernel.data.const import IndexKind

if sys.version_info >= (3, 12):
from typing import override # pragma: no cover
@@ -192,8 +194,10 @@ def _deserialize_store_models_to_dicts(self, records: Sequence[Any], **kwargs: A
async def create_collection(self, **kwargs: Any) -> None:
"""Create a PostgreSQL table based on a dictionary of VectorStoreRecordField.
:param table_name: Name of the table to be created
:param fields: A dictionary where keys are column names and values are VectorStoreRecordField instances
Args:
table_name: Name of the table to be created
fields: A dictionary where keys are column names and values are VectorStoreRecordField instances
**kwargs: Additional arguments
"""
column_definitions = []
table_name = self.collection_name
@@ -206,6 +210,8 @@ async def create_collection(self, **kwargs: Any) -> None:
property_type = python_type_to_postgres(field.property_type) or field.property_type.upper()

# For Vector fields with dimensions, use pgvector's VECTOR type
# Note that other vector types are supported in pgvector (e.g. halfvec),
# but would need to be created outside of this method.
if isinstance(field, VectorStoreRecordVectorField) and field.dimensions:
column_definitions.append(
sql.SQL("{} VECTOR({})").format(sql.Identifier(field_name), sql.Literal(field.dimensions))
@@ -221,23 +227,71 @@ async def create_collection(self, **kwargs: Any) -> None:

columns_str = sql.SQL(", ").join(column_definitions)

# Create the final CREATE TABLE statement
create_table_query = sql.SQL("CREATE TABLE {}.{} ({})").format(
sql.Identifier(self.db_schema), sql.Identifier(table_name), columns_str
)

try:
# Establish the database connection using psycopg3
with self.connection_pool.connection() as conn, conn.cursor() as cur:
# Execute the CREATE TABLE query
cur.execute(create_table_query)
conn.commit()

logger.info(f"Postgres table '{table_name}' created successfully.")

# If the vector field defines an index, apply it
for vector_field in self.data_model_definition.vector_fields:
if vector_field.index_kind:
await self._create_index(table_name, vector_field)

except DatabaseError as error:
raise MemoryConnectorException(f"Error creating table: {error}") from error

async def _create_index(self, table_name: str, vector_field: VectorStoreRecordVectorField) -> None:
"""Create an index on a column in the table.
Args:
table_name: The name of the table.
vector_field: The vector field definition that the index is based on.
"""
column_name = vector_field.name
index_name = f"{table_name}_{column_name}_idx"

# Only support creating HNSW indexes through the vector store
if vector_field.index_kind != IndexKind.HNSW:
raise MemoryConnectorException(
f"Unsupported index kind: {vector_field.index_kind}. "
"If you need to create an index of this type, please do so manually. "
"Only HNSW indexes are supported through the vector store."
)

# Require the distance function to be set for HNSW indexes
if not vector_field.distance_function:
raise MemoryConnectorException(
"Distance function must be set for HNSW indexes. "
"Please set the distance function in the vector field definition."
)

ops_str = get_vector_index_ops_str(vector_field.distance_function)

try:
with self.connection_pool.connection() as conn, conn.cursor() as cur:
cur.execute(
sql.SQL("CREATE INDEX {} ON {}.{} USING {} ({} {})").format(
sql.Identifier(index_name),
sql.Identifier(self.db_schema),
sql.Identifier(table_name),
sql.SQL(vector_field.index_kind),
sql.Identifier(column_name),
sql.SQL(ops_str),
)
)
conn.commit()

logger.info(f"Index '{index_name}' created successfully on column '{column_name}'.")

except DatabaseError as error:
raise MemoryConnectorException(f"Error creating index: {error}") from error

@override
async def does_collection_exist(self, **kwargs: Any) -> bool:
"""Check if the collection exists."""
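The comment in create_collection above notes that pgvector also supports other vector column types (e.g. halfvec) and other index kinds, which have to be created outside this method. As a hedged illustration (not part of this commit), here is a minimal psycopg sketch of doing that manually; the connection string, table name, and column names are placeholder assumptions:

import psycopg
from psycopg import sql

# Sketch only: create a halfvec (16-bit float) column and an IVFFlat index by hand,
# since create_collection/_create_index only emit VECTOR columns and HNSW indexes.
# "dbname=sk_demo", "manual_collection", and "embedding" are illustrative placeholders.
with psycopg.connect("dbname=sk_demo") as conn, conn.cursor() as cur:
    cur.execute(
        sql.SQL("CREATE TABLE {}.{} ({} INTEGER PRIMARY KEY, {} HALFVEC(1536))").format(
            sql.Identifier("public"),
            sql.Identifier("manual_collection"),
            sql.Identifier("id"),
            sql.Identifier("embedding"),
        )
    )
    cur.execute(
        sql.SQL("CREATE INDEX {} ON {}.{} USING ivfflat ({} halfvec_cosine_ops) WITH (lists = 100)").format(
            sql.Identifier("manual_collection_embedding_idx"),
            sql.Identifier("public"),
            sql.Identifier("manual_collection"),
            sql.Identifier("embedding"),
        )
    )
    conn.commit()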
26 changes: 26 additions & 0 deletions python/semantic_kernel/connectors/memory/postgres/utils.py
@@ -8,6 +8,7 @@

from psycopg_pool import ConnectionPool

from semantic_kernel.data.const import DistanceFunction
from semantic_kernel.data.vector_store_record_fields import VectorStoreRecordField, VectorStoreRecordVectorField


@@ -101,3 +102,28 @@ def _convert(v: Any | None) -> Any | None:
return v

return tuple(_convert(record.get(field.name)) for _, field in fields)


def get_vector_index_ops_str(distance_function: DistanceFunction) -> str:
"""Get the PostgreSQL ops string for creating an index for a given distance function.
Args:
distance_function: The distance function the index is created for.
Returns:
The PostgreSQL ops string for the given distance function.
Examples:
>>> get_vector_index_ops_str(DistanceFunction.COSINE)
'vector_cosine_ops'
"""
if distance_function == DistanceFunction.COSINE:
return "vector_cosine_ops"
if distance_function == DistanceFunction.DOT_PROD:
return "vector_ip_ops"
if distance_function == DistanceFunction.EUCLIDEAN:
return "vector_l2_ops"
if distance_function == DistanceFunction.MANHATTAN:
return "vector_l1_ops"

raise ValueError(f"Unsupported distance function: {distance_function}")
7 changes: 7 additions & 0 deletions python/tests/integration/connectors/memory/test_postgres.py
@@ -136,6 +136,13 @@ async def test_upsert_get_and_delete(simple_collection: VectorStoreRecordCollect
assert result.embedding == record.embedding
assert result.data == record.data

# Check that the table has an index
connection_pool = simple_collection.connection_pool
with connection_pool.connection() as conn, conn.cursor() as cur:
cur.execute("SELECT indexname FROM pg_indexes WHERE tablename = %s", (simple_collection.collection_name,))
index_names = [index[0] for index in cur.fetchall()]
assert any("embedding_idx" in index_name for index_name in index_names)

await simple_collection.delete(1)
result_after_delete = await simple_collection.get(1)
assert result_after_delete is None
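A possible follow-up check, sketched here only and not part of this commit, is to EXPLAIN a cosine-distance query against the same fixture and inspect whether the planner can use the new HNSW index; the "public" schema, the "embedding" column name, and the 1536-dimension query vector are assumptions, and a small test table may still get a sequential scan:

from psycopg import sql

# Sketch only: EXPLAIN a nearest-neighbor query against the collection's table.
query_vector = "[" + ",".join(["0.1"] * 1536) + "]"
with simple_collection.connection_pool.connection() as conn, conn.cursor() as cur:
    cur.execute(
        sql.SQL("EXPLAIN SELECT * FROM {}.{} ORDER BY {} <=> %s::vector LIMIT 5").format(
            sql.Identifier("public"),
            sql.Identifier(simple_collection.collection_name),
            sql.Identifier("embedding"),
        ),
        (query_vector,),
    )
    print("\n".join(row[0] for row in cur.fetchall()))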
17 changes: 14 additions & 3 deletions python/tests/unit/connectors/memory/test_postgres.py
@@ -137,15 +137,26 @@ async def test_create_collection_simple_model(vector_store: PostgresStore, mock_
collection = vector_store.get_collection("test_collection", SimpleDataModel)
await collection.create_collection()

assert mock_cursor.execute.call_count == 1
execute_args, _ = mock_cursor.execute.call_args
# 2 calls, once for the table creation and once for the index creation
assert mock_cursor.execute.call_count == 2

# Check the table creation statement
execute_args, _ = mock_cursor.execute.call_args_list[0]
statement = execute_args[0]
statement_str = statement.as_string()

assert statement_str == (
'CREATE TABLE "public"."test_collection" ("id" INTEGER PRIMARY KEY, "embedding" VECTOR(1536), "data" JSONB)'
)

# Check the index creation statement
execute_args, _ = mock_cursor.execute.call_args_list[1]
statement = execute_args[0]
statement_str = statement.as_string()
assert statement_str == (
'CREATE INDEX "test_collection_embedding_idx" ON "public"."test_collection" '
'USING hnsw ("embedding" vector_cosine_ops)'
)


@mark.asyncio
async def test_create_collection_model_with_python_types(vector_store: PostgresStore, mock_cursor: Mock) -> None:
