diff --git a/python/semantic_kernel/connectors/memory/postgres/postgres_collection.py b/python/semantic_kernel/connectors/memory/postgres/postgres_collection.py
index 53a40eaf8cd2..7698264799bb 100644
--- a/python/semantic_kernel/connectors/memory/postgres/postgres_collection.py
+++ b/python/semantic_kernel/connectors/memory/postgres/postgres_collection.py
@@ -8,8 +8,10 @@
 from semantic_kernel.connectors.memory.postgres.utils import (
     convert_dict_to_row,
     convert_row_to_dict,
+    get_vector_index_ops_str,
     python_type_to_postgres,
 )
+from semantic_kernel.data.const import IndexKind
 
 if sys.version_info >= (3, 12):
     from typing import override  # pragma: no cover
@@ -192,8 +194,10 @@ def _deserialize_store_models_to_dicts(self, records: Sequence[Any], **kwargs: A
     async def create_collection(self, **kwargs: Any) -> None:
         """Create a PostgreSQL table based on a dictionary of VectorStoreRecordField.
 
-        :param table_name: Name of the table to be created
-        :param fields: A dictionary where keys are column names and values are VectorStoreRecordField instances
+        Args:
+            table_name: Name of the table to be created
+            fields: A dictionary where keys are column names and values are VectorStoreRecordField instances
+            **kwargs: Additional arguments
         """
         column_definitions = []
         table_name = self.collection_name
@@ -206,6 +210,8 @@ async def create_collection(self, **kwargs: Any) -> None:
             property_type = python_type_to_postgres(field.property_type) or field.property_type.upper()
 
             # For Vector fields with dimensions, use pgvector's VECTOR type
+            # Note that other vector types are supported in pgvector (e.g. halfvec),
+            # but would need to be created outside of this method.
             if isinstance(field, VectorStoreRecordVectorField) and field.dimensions:
                 column_definitions.append(
                     sql.SQL("{} VECTOR({})").format(sql.Identifier(field_name), sql.Literal(field.dimensions))
@@ -221,23 +227,71 @@ async def create_collection(self, **kwargs: Any) -> None:
 
         columns_str = sql.SQL(", ").join(column_definitions)
 
-        # Create the final CREATE TABLE statement
         create_table_query = sql.SQL("CREATE TABLE {}.{} ({})").format(
             sql.Identifier(self.db_schema), sql.Identifier(table_name), columns_str
         )
 
         try:
-            # Establish the database connection using psycopg3
             with self.connection_pool.connection() as conn, conn.cursor() as cur:
-                # Execute the CREATE TABLE query
                 cur.execute(create_table_query)
                 conn.commit()
 
             logger.info(f"Postgres table '{table_name}' created successfully.")
 
+            # If the vector field defines an index, apply it
+            for vector_field in self.data_model_definition.vector_fields:
+                if vector_field.index_kind:
+                    await self._create_index(table_name, vector_field)
+
         except DatabaseError as error:
             raise MemoryConnectorException(f"Error creating table: {error}") from error
 
+    async def _create_index(self, table_name: str, vector_field: VectorStoreRecordVectorField) -> None:
+        """Create an index on a column in the table.
+
+        Args:
+            table_name: The name of the table.
+            vector_field: The vector field definition that the index is based on.
+        """
+        column_name = vector_field.name
+        index_name = f"{table_name}_{column_name}_idx"
+
+        # Only support creating HNSW indexes through the vector store
+        if vector_field.index_kind != IndexKind.HNSW:
+            raise MemoryConnectorException(
+                f"Unsupported index kind: {vector_field.index_kind}. "
+                "If you need to create an index of this type, please do so manually. "
+                "Only HNSW indexes are supported through the vector store."
+            )
+
+        # Require the distance function to be set for HNSW indexes
+        if not vector_field.distance_function:
+            raise MemoryConnectorException(
+                "Distance function must be set for HNSW indexes. "
+                "Please set the distance function in the vector field definition."
+            )
+
+        ops_str = get_vector_index_ops_str(vector_field.distance_function)
+
+        try:
+            with self.connection_pool.connection() as conn, conn.cursor() as cur:
+                cur.execute(
+                    sql.SQL("CREATE INDEX {} ON {}.{} USING {} ({} {})").format(
+                        sql.Identifier(index_name),
+                        sql.Identifier(self.db_schema),
+                        sql.Identifier(table_name),
+                        sql.SQL(vector_field.index_kind),
+                        sql.Identifier(column_name),
+                        sql.SQL(ops_str),
+                    )
+                )
+                conn.commit()
+
+            logger.info(f"Index '{index_name}' created successfully on column '{column_name}'.")
+
+        except DatabaseError as error:
+            raise MemoryConnectorException(f"Error creating index: {error}") from error
+
     @override
     async def does_collection_exist(self, **kwargs: Any) -> bool:
         """Check if the collection exists."""
diff --git a/python/semantic_kernel/connectors/memory/postgres/utils.py b/python/semantic_kernel/connectors/memory/postgres/utils.py
index 3325ca165a0f..71b0371ef8a2 100644
--- a/python/semantic_kernel/connectors/memory/postgres/utils.py
+++ b/python/semantic_kernel/connectors/memory/postgres/utils.py
@@ -8,6 +8,7 @@
 from psycopg_pool import ConnectionPool
 
+from semantic_kernel.data.const import DistanceFunction
 from semantic_kernel.data.vector_store_record_fields import VectorStoreRecordField, VectorStoreRecordVectorField
 
@@ -101,3 +102,28 @@ def _convert(v: Any | None) -> Any | None:
             return v
 
     return tuple(_convert(record.get(field.name)) for _, field in fields)
+
+
+def get_vector_index_ops_str(distance_function: DistanceFunction) -> str:
+    """Get the PostgreSQL ops string for creating an index for a given distance function.
+
+    Args:
+        distance_function: The distance function the index is created for.
+
+    Returns:
+        The PostgreSQL ops string for the given distance function.
+
+    Examples:
+        >>> get_vector_index_ops_str(DistanceFunction.COSINE)
+        'vector_cosine_ops'
+    """
+    if distance_function == DistanceFunction.COSINE:
+        return "vector_cosine_ops"
+    if distance_function == DistanceFunction.DOT_PROD:
+        return "vector_ip_ops"
+    if distance_function == DistanceFunction.EUCLIDEAN:
+        return "vector_l2_ops"
+    if distance_function == DistanceFunction.MANHATTAN:
+        return "vector_l1_ops"
+
+    raise ValueError(f"Unsupported distance function: {distance_function}")
diff --git a/python/tests/integration/connectors/memory/test_postgres.py b/python/tests/integration/connectors/memory/test_postgres.py
index 59f0626500e4..a5d4dcea2bec 100644
--- a/python/tests/integration/connectors/memory/test_postgres.py
+++ b/python/tests/integration/connectors/memory/test_postgres.py
@@ -136,6 +136,13 @@ async def test_upsert_get_and_delete(simple_collection: VectorStoreRecordCollect
     assert result.embedding == record.embedding
     assert result.data == record.data
 
+    # Check that the table has an index
+    connection_pool = simple_collection.connection_pool
+    with connection_pool.connection() as conn, conn.cursor() as cur:
+        cur.execute("SELECT indexname FROM pg_indexes WHERE tablename = %s", (simple_collection.collection_name,))
+        index_names = [index[0] for index in cur.fetchall()]
+        assert any("embedding_idx" in index_name for index_name in index_names)
+
     await simple_collection.delete(1)
     result_after_delete = await simple_collection.get(1)
     assert result_after_delete is None
diff --git a/python/tests/unit/connectors/memory/test_postgres.py b/python/tests/unit/connectors/memory/test_postgres.py
index 19c16596ad37..0fc1d7dfbe70 100644
--- a/python/tests/unit/connectors/memory/test_postgres.py
+++ b/python/tests/unit/connectors/memory/test_postgres.py
@@ -137,15 +137,26 @@ async def test_create_collection_simple_model(vector_store: PostgresStore, mock_
     collection = vector_store.get_collection("test_collection", SimpleDataModel)
     await collection.create_collection()
 
-    assert mock_cursor.execute.call_count == 1
-    execute_args, _ = mock_cursor.execute.call_args
+    # 2 calls, once for the table creation and once for the index creation
+    assert mock_cursor.execute.call_count == 2
+
+    # Check the table creation statement
+    execute_args, _ = mock_cursor.execute.call_args_list[0]
     statement = execute_args[0]
     statement_str = statement.as_string()
-
     assert statement_str == (
         'CREATE TABLE "public"."test_collection" ("id" INTEGER PRIMARY KEY, "embedding" VECTOR(1536), "data" JSONB)'
    )
 
+    # Check the index creation statement
+    execute_args, _ = mock_cursor.execute.call_args_list[1]
+    statement = execute_args[0]
+    statement_str = statement.as_string()
+    assert statement_str == (
+        'CREATE INDEX "test_collection_embedding_idx" ON "public"."test_collection" '
+        'USING hnsw ("embedding" vector_cosine_ops)'
+    )
+
 
 @mark.asyncio
 async def test_create_collection_model_with_python_types(vector_store: PostgresStore, mock_cursor: Mock) -> None:
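For reviewers, a rough sketch (not part of the diff) of how a caller would opt into the new index creation. IndexKind, DistanceFunction, and the field attributes are taken from the code above; the constructor keywords and the "embedding"/1536 values are assumptions chosen to mirror the SimpleDataModel fixture asserted in the unit test.

from semantic_kernel.data.const import DistanceFunction, IndexKind
from semantic_kernel.data.vector_store_record_fields import VectorStoreRecordVectorField

# Hypothetical vector field definition for a record model.
embedding_field = VectorStoreRecordVectorField(
    name="embedding",
    dimensions=1536,                            # column is created as VECTOR(1536)
    index_kind=IndexKind.HNSW,                  # the only kind _create_index accepts
    distance_function=DistanceFunction.COSINE,  # get_vector_index_ops_str -> "vector_cosine_ops"
)

# With this field on the data model, create_collection() now executes two statements:
#   CREATE TABLE "public"."test_collection" (..., "embedding" VECTOR(1536), ...)
#   CREATE INDEX "test_collection_embedding_idx" ON "public"."test_collection"
#       USING hnsw ("embedding" vector_cosine_ops)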