Skip to content

Commit

Permalink
feat: add update_corpus method for vertex rag
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 683215818
  • Loading branch information
vertex-sdk-bot authored and copybara-github committed Oct 8, 2024
1 parent 1e49799 commit 739d92c
Show file tree
Hide file tree
Showing 5 changed files with 238 additions and 2 deletions.
2 changes: 2 additions & 0 deletions tests/unit/vertex_rag/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,8 @@ def rag_data_client_mock_exception():
api_client_mock = mock.Mock(spec=VertexRagDataServiceClient)
# create_rag_corpus
api_client_mock.create_rag_corpus.side_effect = Exception
# update_rag_corpus
api_client_mock.update_rag_corpus.side_effect = Exception
# get_rag_corpus
api_client_mock.get_rag_corpus.side_effect = Exception
# list_rag_corpora
Expand Down
141 changes: 141 additions & 0 deletions tests/unit/vertex_rag/test_rag_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,70 @@ def create_rag_corpus_mock_pinecone():
yield create_rag_corpus_mock_pinecone


@pytest.fixture
def update_rag_corpus_mock_weaviate():
with mock.patch.object(
VertexRagDataServiceClient,
"update_rag_corpus",
) as update_rag_corpus_mock_weaviate:
update_rag_corpus_lro_mock = mock.Mock(ga_operation.Operation)
update_rag_corpus_lro_mock.done.return_value = True
update_rag_corpus_lro_mock.result.return_value = (
tc.TEST_GAPIC_RAG_CORPUS_WEAVIATE
)
update_rag_corpus_mock_weaviate.return_value = update_rag_corpus_lro_mock
yield update_rag_corpus_mock_weaviate


@pytest.fixture
def update_rag_corpus_mock_vertex_feature_store():
with mock.patch.object(
VertexRagDataServiceClient,
"update_rag_corpus",
) as update_rag_corpus_mock_vertex_feature_store:
update_rag_corpus_lro_mock = mock.Mock(ga_operation.Operation)
update_rag_corpus_lro_mock.done.return_value = True
update_rag_corpus_lro_mock.result.return_value = (
tc.TEST_GAPIC_RAG_CORPUS_VERTEX_FEATURE_STORE
)
update_rag_corpus_mock_vertex_feature_store.return_value = (
update_rag_corpus_lro_mock
)
yield update_rag_corpus_mock_vertex_feature_store


@pytest.fixture
def update_rag_corpus_mock_vertex_vector_search():
with mock.patch.object(
VertexRagDataServiceClient,
"update_rag_corpus",
) as update_rag_corpus_mock_vertex_vector_search:
update_rag_corpus_lro_mock = mock.Mock(ga_operation.Operation)
update_rag_corpus_lro_mock.done.return_value = True
update_rag_corpus_lro_mock.result.return_value = (
tc.TEST_GAPIC_RAG_CORPUS_VERTEX_VECTOR_SEARCH
)
update_rag_corpus_mock_vertex_vector_search.return_value = (
update_rag_corpus_lro_mock
)
yield update_rag_corpus_mock_vertex_vector_search


@pytest.fixture
def update_rag_corpus_mock_pinecone():
with mock.patch.object(
VertexRagDataServiceClient,
"update_rag_corpus",
) as update_rag_corpus_mock_pinecone:
update_rag_corpus_lro_mock = mock.Mock(ga_operation.Operation)
update_rag_corpus_lro_mock.done.return_value = True
update_rag_corpus_lro_mock.result.return_value = (
tc.TEST_GAPIC_RAG_CORPUS_PINECONE
)
update_rag_corpus_mock_pinecone.return_value = update_rag_corpus_lro_mock
yield update_rag_corpus_mock_pinecone


@pytest.fixture
def list_rag_corpora_pager_mock():
with mock.patch.object(
Expand Down Expand Up @@ -298,6 +362,83 @@ def test_create_corpus_failure(self):
rag.create_corpus(display_name=tc.TEST_CORPUS_DISPLAY_NAME)
e.match("Failed in RagCorpus creation due to")

@pytest.mark.usefixtures("update_rag_corpus_mock_weaviate")
def test_update_corpus_weaviate_success(self):
rag_corpus = rag.update_corpus(
corpus_name=tc.TEST_RAG_CORPUS_RESOURCE_NAME,
display_name=tc.TEST_CORPUS_DISPLAY_NAME,
vector_db=tc.TEST_WEAVIATE_CONFIG,
)

rag_corpus_eq(rag_corpus, tc.TEST_RAG_CORPUS_WEAVIATE)

@pytest.mark.usefixtures("update_rag_corpus_mock_weaviate")
def test_update_corpus_weaviate_no_display_name_success(self):
rag_corpus = rag.update_corpus(
corpus_name=tc.TEST_RAG_CORPUS_RESOURCE_NAME,
vector_db=tc.TEST_WEAVIATE_CONFIG,
)

rag_corpus_eq(rag_corpus, tc.TEST_RAG_CORPUS_WEAVIATE)

@pytest.mark.usefixtures("update_rag_corpus_mock_weaviate")
def test_update_corpus_weaviate_with_description_success(self):
rag_corpus = rag.update_corpus(
corpus_name=tc.TEST_RAG_CORPUS_RESOURCE_NAME,
description=tc.TEST_CORPUS_DISCRIPTION,
vector_db=tc.TEST_WEAVIATE_CONFIG,
)

rag_corpus_eq(rag_corpus, tc.TEST_RAG_CORPUS_WEAVIATE)

@pytest.mark.usefixtures("update_rag_corpus_mock_weaviate")
def test_update_corpus_weaviate_with_description_and_display_name_success(self):
rag_corpus = rag.update_corpus(
corpus_name=tc.TEST_RAG_CORPUS_RESOURCE_NAME,
description=tc.TEST_CORPUS_DISCRIPTION,
display_name=tc.TEST_CORPUS_DISPLAY_NAME,
vector_db=tc.TEST_WEAVIATE_CONFIG,
)

rag_corpus_eq(rag_corpus, tc.TEST_RAG_CORPUS_WEAVIATE)

@pytest.mark.usefixtures("update_rag_corpus_mock_vertex_feature_store")
def test_update_corpus_vertex_feature_store_success(self):
rag_corpus = rag.update_corpus(
corpus_name=tc.TEST_RAG_CORPUS_RESOURCE_NAME,
display_name=tc.TEST_CORPUS_DISPLAY_NAME,
vector_db=tc.TEST_VERTEX_FEATURE_STORE_CONFIG,
)

rag_corpus_eq(rag_corpus, tc.TEST_RAG_CORPUS_VERTEX_FEATURE_STORE)

@pytest.mark.usefixtures("update_rag_corpus_mock_vertex_vector_search")
def test_update_corpus_vertex_vector_search_success(self):
rag_corpus = rag.update_corpus(
corpus_name=tc.TEST_RAG_CORPUS_RESOURCE_NAME,
display_name=tc.TEST_CORPUS_DISPLAY_NAME,
vector_db=tc.TEST_VERTEX_VECTOR_SEARCH_CONFIG,
)
rag_corpus_eq(rag_corpus, tc.TEST_RAG_CORPUS_VERTEX_VECTOR_SEARCH)

@pytest.mark.usefixtures("update_rag_corpus_mock_pinecone")
def test_update_corpus_pinecone_success(self):
rag_corpus = rag.update_corpus(
corpus_name=tc.TEST_RAG_CORPUS_RESOURCE_NAME,
display_name=tc.TEST_CORPUS_DISPLAY_NAME,
vector_db=tc.TEST_PINECONE_CONFIG,
)
rag_corpus_eq(rag_corpus, tc.TEST_RAG_CORPUS_PINECONE)

@pytest.mark.usefixtures("rag_data_client_mock_exception")
def test_update_corpus_failure(self):
with pytest.raises(RuntimeError) as e:
rag.update_corpus(
corpus_name=tc.TEST_RAG_CORPUS_RESOURCE_NAME,
display_name=tc.TEST_CORPUS_DISPLAY_NAME,
)
e.match("Failed in RagCorpus update due to")

@pytest.mark.usefixtures("rag_data_client_mock")
def test_get_corpus_success(self):
rag_corpus = rag.get_corpus(tc.TEST_RAG_CORPUS_RESOURCE_NAME)
Expand Down
2 changes: 2 additions & 0 deletions vertexai/preview/rag/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

from vertexai.preview.rag.rag_data import (
create_corpus,
update_corpus,
list_corpora,
get_corpus,
delete_corpus,
Expand Down Expand Up @@ -84,4 +85,5 @@
"list_files",
"retrieval_query",
"upload_file",
"update_corpus",
)
82 changes: 80 additions & 2 deletions vertexai/preview/rag/rag_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
#
"""RAG data management SDK."""

from typing import Optional, Union, Sequence
from typing import Optional, Sequence, Union
from google import auth
from google.api_core import operation_async
from google.auth.transport import requests as google_auth_requests
Expand All @@ -33,8 +33,8 @@
ListRagCorporaRequest,
ListRagFilesRequest,
RagCorpus as GapicRagCorpus,
UpdateRagCorpusRequest,
)

from google.cloud.aiplatform_v1beta1.services.vertex_rag_data_service.pagers import (
ListRagCorporaPager,
ListRagFilesPager,
Expand Down Expand Up @@ -121,6 +121,84 @@ def create_corpus(
return _gapic_utils.convert_gapic_to_rag_corpus(response.result(timeout=600))


def update_corpus(
corpus_name: str,
display_name: Optional[str] = None,
description: Optional[str] = None,
vector_db: Optional[
Union[
Weaviate,
VertexFeatureStore,
VertexVectorSearch,
Pinecone,
RagManagedDb,
]
] = None,
) -> RagCorpus:
"""Updates a RagCorpus resource.
Example usage:
```
import vertexai
from vertexai.preview import rag
vertexai.init(project="my-project")
rag_corpus = rag.update_corpus(
corpus_name="projects/my-project/locations/us-central1/ragCorpora/my-corpus-1",
display_name="my-corpus-1",
)
```
Args:
corpus_name: The name of the RagCorpus resource to update. Format:
``projects/{project}/locations/{location}/ragCorpora/{rag_corpus}`` or
``{rag_corpus}``.
display_name: If not provided, the display name will not be updated. The
display name of the RagCorpus. The name can be up to 128 characters long
and can consist of any UTF-8 characters.
description: The description of the RagCorpus. If not provided, the
description will not be updated.
vector_db: The vector db config of the RagCorpus. If not provided, the
vector db will not be updated.
Returns:
RagCorpus.
Raises:
RuntimeError: Failed in RagCorpus update due to exception.
RuntimeError: Failed in RagCorpus update due to operation error.
"""
corpus_name = _gapic_utils.get_corpus_name(corpus_name)
if display_name and description:
rag_corpus = GapicRagCorpus(
name=corpus_name, display_name=display_name, description=description
)
elif display_name:
rag_corpus = GapicRagCorpus(name=corpus_name, display_name=display_name)
elif description:
rag_corpus = GapicRagCorpus(name=corpus_name, description=description)
else:
rag_corpus = GapicRagCorpus(name=corpus_name)

_gapic_utils.set_vector_db(
vector_db=vector_db,
rag_corpus=rag_corpus,
)

request = UpdateRagCorpusRequest(
rag_corpus=rag_corpus,
)
client = _gapic_utils.create_rag_data_service_client()

try:
response = client.update_rag_corpus(request=request)
except Exception as e:
raise RuntimeError("Failed in RagCorpus update due to: ", e) from e
return _gapic_utils.convert_gapic_to_rag_corpus_no_embedding_model_config(
response.result(timeout=600)
)


def get_corpus(name: str) -> RagCorpus:
"""
Get an existing RagCorpus.
Expand Down
13 changes: 13 additions & 0 deletions vertexai/preview/rag/utils/_gapic_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,19 @@ def convert_gapic_to_rag_corpus(gapic_rag_corpus: GapicRagCorpus) -> RagCorpus:
return rag_corpus


def convert_gapic_to_rag_corpus_no_embedding_model_config(
gapic_rag_corpus: GapicRagCorpus,
) -> RagCorpus:
"""Convert GapicRagCorpus without embedding model config (for UpdateRagCorpus) to RagCorpus."""
rag_corpus = RagCorpus(
name=gapic_rag_corpus.name,
display_name=gapic_rag_corpus.display_name,
description=gapic_rag_corpus.description,
vector_db=convert_gapic_to_vector_db(gapic_rag_corpus.rag_vector_db_config),
)
return rag_corpus


def convert_gapic_to_rag_file(gapic_rag_file: GapicRagFile) -> RagFile:
"""Convert GapicRagFile to RagFile."""
rag_file = RagFile(
Expand Down

0 comments on commit 739d92c

Please sign in to comment.