diff --git a/.github/dependabot.yml b/.github/dependabot.yml index 7730b455516d..9b83abf2739e 100644 --- a/.github/dependabot.yml +++ b/.github/dependabot.yml @@ -5,29 +5,57 @@ version: 2 updates: + # Maintain dependencies for nuget + - package-ecosystem: "nuget" + directory: "samples/apps/copilot-chat-app/webapi" + schedule: + interval: "weekly" + day: "monday" + labels: + - "copilot chat" + - "dependencies" + # Maintain dependencies for nuget - package-ecosystem: "nuget" directory: "dotnet/" schedule: interval: "weekly" + day: "monday" + ignore: + # For all System.* and Microsoft.Extensions/Bcl.* packages, ignore all major version updates + - dependency-name: "System.*" + update-types: ["version-update:semver-major"] + - dependency-name: "Microsoft.Extensions.*" + update-types: ["version-update:semver-major"] + - dependency-name: "Microsoft.Bcl.*" + update-types: ["version-update:semver-major"] + labels: + - ".NET" + - "dependencies" # Maintain dependencies for nuget - package-ecosystem: "nuget" - directory: "samples/" + directory: "samples/dotnet" schedule: interval: "weekly" - + day: "monday" + # Maintain dependencies for npm - package-ecosystem: "npm" directory: "samples/apps" schedule: interval: "weekly" + day: "monday" # Maintain dependencies for pip - package-ecosystem: "pip" directory: "python/" schedule: interval: "weekly" + day: "monday" + labels: + - "python" + - "dependencies" # Maintain dependencies for github-actions - package-ecosystem: "github-actions" @@ -36,3 +64,4 @@ updates: directory: "/" schedule: interval: "weekly" + day: "monday" diff --git a/.github/labeler.yml b/.github/labeler.yml index ea26d78e57cf..64d48e798940 100644 --- a/.github/labeler.yml +++ b/.github/labeler.yml @@ -1,10 +1,10 @@ # Add 'kernel' label to any change within Connectors, Extensions, Skills, and tests directories kernel: - - "dotnet/src/Connectors/**/*" - - "dotnet/src/Extensions/**/*" - - "dotnet/src/Skills/**/*" - - "dotnet/src/IntegrationTests/**/*" - - "dotnet/src/SemanticKernel.UnitTests/**/*" + - dotnet/src/Connectors/**/* + - dotnet/src/Extensions/**/* + - dotnet/src/Skills/**/* + - dotnet/src/IntegrationTests/**/* + - dotnet/src/SemanticKernel.UnitTests/**/* # Add 'kernel.core' label to any change within the 'SemanticKernel', 'SemanticKernel.Abstractions', or 'SemanticKernel.MetaPackage' directories kernel.core: @@ -16,11 +16,28 @@ kernel.core: python: - python/**/* +# Add 'java' label to any change within the 'java' directory +java: + - java/**/* + # Add 'samples' label to any change within the 'samples' directory samples: - samples/**/* # Add '.NET' label to any change within samples or kernel 'dotnet' directories. .NET: - - dotnet/src/**/* - - samples/**/dotnet/**/* + - dotnet/**/* + +# Add 'copilot chat' label to any change within the 'samples/apps/copilot-chat-app' directory +copilot chat: + - samples/apps/copilot-chat-app/**/* + +# Add 'documentation' label to any change within the 'docs' directory, or any '.md' files +documentation: + - docs/**/* + - '**/*.md' + +# Add 'memory' label to any memory connectors in dotnet/ or python/ +memory: + - dotnet/src/Connectors/Connectors.Memory.*/**/* + - python/semantic_kernel/connectors/memory/**/* diff --git a/.github/workflows/dotnet-pr.yml b/.github/workflows/dotnet-pr.yml index 05efaa0b30a0..688a01aa1e0d 100644 --- a/.github/workflows/dotnet-pr.yml +++ b/.github/workflows/dotnet-pr.yml @@ -11,8 +11,6 @@ on: paths: - 'dotnet/**' - 'samples/dotnet/**' - - '**.cs' - - '**.csproj' concurrency: group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} diff --git a/.github/workflows/java-format.yml b/.github/workflows/java-format.yml index 1507f330845c..5cd460b67632 100644 --- a/.github/workflows/java-format.yml +++ b/.github/workflows/java-format.yml @@ -19,6 +19,14 @@ jobs: pull-requests: write steps: + - name: Check for command + id: command + uses: xt0rted/slash-command-action@v2 + continue-on-error: true + with: + command: spotless + reaction-type: "eyes" + - name: Get command env: BODY: ${{ github.event.comment.body }} @@ -77,4 +85,4 @@ jobs: " gh pr comment $NUMBER --body "$body" fi - working-directory: java \ No newline at end of file + working-directory: java diff --git a/.github/workflows/label-issues.yml b/.github/workflows/label-issues.yml new file mode 100644 index 000000000000..946b65da21cb --- /dev/null +++ b/.github/workflows/label-issues.yml @@ -0,0 +1,44 @@ +name: Label issues +on: + issues: + types: + - reopened + - opened + +jobs: + label_issues: + name: "Issue: add labels" + if: ${{ github.event.action == 'opened' || github.event.action == 'reopened' }} + runs-on: ubuntu-latest + permissions: + issues: write + steps: + - uses: actions/github-script@v6 + with: + script: | + // Get the issue body and title + const body = context.payload.issue.body + let title = context.payload.issue.title + + // Define the labels array + let labels = ["triage"] + + // Check if the body or the title contains the word 'python' (case-insensitive) + if (body.match(/python/i) || title.match(/python/i)) { + // Add the 'python' label to the array + labels.push("python") + } + + // Check if the body or the title contains the word 'java' (case-insensitive) + if (body.match(/java/i) || title.match(/java/i)) { + // Add the 'java' label to the array + labels.push("java") + } + + // Add the labels to the issue + github.rest.issues.addLabels({ + issue_number: context.issue.number, + owner: context.repo.owner, + repo: context.repo.repo, + labels: labels + }); diff --git a/.github/workflows/label-title-prefix.yml b/.github/workflows/label-title-prefix.yml new file mode 100644 index 000000000000..4686e4977c64 --- /dev/null +++ b/.github/workflows/label-title-prefix.yml @@ -0,0 +1,97 @@ +name: Label title prefix +on: + issues: + types: [ labeled ] + pull_request: + types: [ labeled ] + +jobs: + add_title_prefix: + name: "Issue/PR: add title prefix" + # Define a matrix of label and prefix pairs + strategy: + matrix: + include: + - {label: 'python', prefix: 'Python'} + - {label: 'java', prefix: 'Java'} + - {label: '.NET', prefix: '.Net'} + - {label: 'copilot chat', prefix: 'Copilot Chat'} + + runs-on: ubuntu-latest + permissions: + issues: write + pull-requests: write + + steps: + - uses: actions/github-script@v6 + name: "Issue/PR: update title" + with: + script: | + // Get the label and prefix from the matrix + const label = '${{ matrix.label }}' + const prefix = '${{ matrix.prefix }}' + + labelAdded = context.payload.label.name + + // Write the contents of context to console + core.info(JSON.stringify(context, null, 2)) + + // Get the event name, title and labels + let title + switch(context.eventName) { + case 'issues': + title = context.payload.issue.title + break + case 'pull_request': + title = context.payload.pull_request.title + break + default: + core.setFailed('Unrecognited eventName: ' + context.eventName) + } + + let originalTitle = title + + // Update the title based on the label and prefix + // Check if the issue or PR has the label + if (labelAdded == label) { + // Check if the title starts with the prefix (case-sensitive) + if (!title.startsWith(prefix + ": ")) { + // If not, check if the first word is the label (case-insensitive) + if (title.match(new RegExp(`^${prefix}`, 'i'))) { + // If yes, replace it with the prefix (case-sensitive) + title = title.replace(new RegExp(`^${prefix}`, 'i'), prefix) + } else { + // If not, prepend the prefix to the title + title = prefix + ": " + title + } + } + } + + // Update the issue or PR title, if changed + if (title != originalTitle ) { + switch(context.eventName) { + case 'issues': + github.rest.issues.update({ + issue_number: context.issue.number, + owner: context.repo.owner, + repo: context.repo.repo, + title: title + }); + break + case 'pull_request': + try { + github.rest.pulls.update({ + pull_number: context.issue.number, + owner: context.repo.owner, + repo: context.repo.repo, + title: title + }); + } + catch (err) { + core.info("Update PR title failed: " + err.message) + } + break + default: + core.setFailed('Unrecognited eventName: ' + context.eventName) + } + } diff --git a/.github/workflows/node-pr.yml b/.github/workflows/node-pr.yml index 32e966752d76..1bdabff5ab72 100644 --- a/.github/workflows/node-pr.yml +++ b/.github/workflows/node-pr.yml @@ -33,10 +33,8 @@ jobs: yarndirs=() for lockfile in samples/apps/**/yarn.lock; do # loop over all yarn.lock files dir=$(dirname "$lockfile") # get the directory of the lock file - if [[ "$dir" != "samples/apps" ]]; then # exclude samples/apps directory - echo "Found yarn project in $dir" - yarndirs+=("$dir") # add the directory to the yarndirs array - fi + echo "Found yarn project in $dir" + yarndirs+=("$dir") # add the directory to the yarndirs array done echo "All yarn projects found: '${yarndirs[*]}'" diff --git a/.github/workflows/python-integration-tests.yml b/.github/workflows/python-integration-tests.yml index ad77410c9397..78e576d4c7b8 100644 --- a/.github/workflows/python-integration-tests.yml +++ b/.github/workflows/python-integration-tests.yml @@ -7,12 +7,11 @@ name: Python Integration Tests on: workflow_dispatch: push: - branches: [ "main"] - paths: - - 'python/**' + branches: ["main"] + paths: + - "python/**" schedule: - - cron: '0 */12 * * *' # Run every 12 hours: midnight UTC and noon UTC - + - cron: "0 */12 * * *" # Run every 12 hours: midnight UTC and noon UTC permissions: contents: read @@ -25,44 +24,45 @@ jobs: fail-fast: false matrix: python-version: ["3.8", "3.9", "3.10", "3.11"] - os: [ ubuntu-latest, windows-latest, macos-latest ] - + os: [ubuntu-latest, windows-latest, macos-latest] + steps: - - uses: actions/checkout@v3 - - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v4 - with: - python-version: ${{ matrix.python-version }} - - name: Install dependencies wtih hnswlib native disabled - if: matrix.os == 'macos-latest' && matrix.python-version == '3.11' - run: | - export HNSWLIB_NO_NATIVE=1 - python -m pip install --upgrade pip setuptools wheel - python -m pip install poetry pytest - cd python && poetry install --with hugging_face --with chromadb --with weaviate - - name: Install dependencies wtih hnswlib native enabled - if: matrix.os != 'macos-latest' || matrix.python-version != '3.11' - run: | - python -m pip install --upgrade pip setuptools wheel - python -m pip install poetry pytest - cd python && poetry install --with hugging_face --with chromadb --with weaviate - - name: Run Integration Tests - shell: bash - env: # Set Azure credentials secret as an input - HNSWLIB_NO_NATIVE: 1 - Python_Integration_Tests: Python_Integration_Tests - AzureOpenAI__Label: azure-text-davinci-003 - AzureOpenAIEmbedding__Label: azure-text-embedding-ada-002 - AzureOpenAI__DeploymentName: ${{ vars.AZUREOPENAI__DEPLOYMENTNAME }} - AzureOpenAIChat__DeploymentName: ${{ vars.AZUREOPENAI__CHAT__DEPLOYMENTNAME }} - AzureOpenAIEmbeddings__DeploymentName: ${{ vars.AZUREOPENAIEMBEDDING__DEPLOYMENTNAME }} - AzureOpenAI__Endpoint: ${{ secrets.AZUREOPENAI__ENDPOINT }} - AzureOpenAIEmbeddings__Endpoint: ${{ secrets.AZUREOPENAI__ENDPOINT }} - AzureOpenAI__ApiKey: ${{ secrets.AZUREOPENAI__APIKEY }} - AzureOpenAIEmbeddings__ApiKey: ${{ secrets.AZUREOPENAI__APIKEY }} - Bing__ApiKey: ${{ secrets.BING__APIKEY }} - OpenAI__ApiKey: ${{ secrets.OPENAI__APIKEY }} - run: | - cd python - poetry run pytest ./tests/integration - + - uses: actions/checkout@v3 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v4 + with: + python-version: ${{ matrix.python-version }} + - name: Install dependencies with hnswlib native disabled + if: matrix.os == 'macos-latest' && matrix.python-version == '3.11' + run: | + export HNSWLIB_NO_NATIVE=1 + python -m pip install --upgrade pip setuptools wheel + python -m pip install poetry pytest + cd python && poetry install --with hugging_face --with chromadb --with weaviate + - name: Install dependencies with hnswlib native enabled + if: matrix.os != 'macos-latest' || matrix.python-version != '3.11' + run: | + python -m pip install --upgrade pip setuptools wheel + python -m pip install poetry pytest + cd python && poetry install --with hugging_face --with chromadb --with weaviate + - name: Run Integration Tests + shell: bash + env: # Set Azure credentials secret as an input + HNSWLIB_NO_NATIVE: 1 + Python_Integration_Tests: Python_Integration_Tests + AzureOpenAI__Label: azure-text-davinci-003 + AzureOpenAIEmbedding__Label: azure-text-embedding-ada-002 + AzureOpenAI__DeploymentName: ${{ vars.AZUREOPENAI__DEPLOYMENTNAME }} + AzureOpenAIChat__DeploymentName: ${{ vars.AZUREOPENAI__CHAT__DEPLOYMENTNAME }} + AzureOpenAIEmbeddings__DeploymentName: ${{ vars.AZUREOPENAIEMBEDDING__DEPLOYMENTNAME }} + AzureOpenAI__Endpoint: ${{ secrets.AZUREOPENAI__ENDPOINT }} + AzureOpenAIEmbeddings__Endpoint: ${{ secrets.AZUREOPENAI__ENDPOINT }} + AzureOpenAI__ApiKey: ${{ secrets.AZUREOPENAI__APIKEY }} + AzureOpenAIEmbeddings__ApiKey: ${{ secrets.AZUREOPENAI__APIKEY }} + Bing__ApiKey: ${{ secrets.BING__APIKEY }} + OpenAI__ApiKey: ${{ secrets.OPENAI__APIKEY }} + Pinecone__ApiKey: ${{ secrets.PINECONE__APIKEY }} + Pinecone__Environment: ${{ secrets.PINECONE__ENVIRONMENT }} + run: | + cd python + poetry run pytest ./tests/integration diff --git a/.gitignore b/.gitignore index 431046016861..530b14105f74 100644 --- a/.gitignore +++ b/.gitignore @@ -476,4 +476,5 @@ playwright-report/ # Static Web App deployment config swa-cli.config.json -**/copilot-chat-app/webapp/build \ No newline at end of file +**/copilot-chat-app/webapp/build +**/copilot-chat-app/webapp/node_modules \ No newline at end of file diff --git a/.vscode/settings.json b/.vscode/settings.json index 6fcbd719ff9a..f07c5679f7ee 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -57,9 +57,8 @@ } ], "eslint.options": { - "overrideConfigFile": ".eslintrc.js" + "overrideConfigFile": "./package.json" }, - "eslint.packageManager": "yarn", "files.associations": { "*.json": "jsonc" }, diff --git a/dotnet/Directory.Packages.props b/dotnet/Directory.Packages.props index 5123c7531317..465f3f3711e9 100644 --- a/dotnet/Directory.Packages.props +++ b/dotnet/Directory.Packages.props @@ -9,32 +9,32 @@ - - - + + + - + - - + + - + - - - - + + + + - + - + @@ -44,7 +44,7 @@ - + diff --git a/dotnet/SK-dotnet.sln b/dotnet/SK-dotnet.sln index c2e04a6b16e7..f6f1bfd35c2f 100644 --- a/dotnet/SK-dotnet.sln +++ b/dotnet/SK-dotnet.sln @@ -62,8 +62,6 @@ Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "connectors", "connectors", EndProject Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Connectors.UnitTests", "src\Connectors\Connectors.UnitTests\Connectors.UnitTests.csproj", "{EB3FC57F-E591-4C88-BCD5-B6A1BC635168}" EndProject -Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "CopilotChatWebApi", "..\samples\apps\copilot-chat-app\webapi\CopilotChatWebApi.csproj", "{CCABF515-2C79-453E-A5A2-69C69B8D172E}" -EndProject Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Connectors.Memory.Qdrant", "src\Connectors\Connectors.Memory.Qdrant\Connectors.Memory.Qdrant.csproj", "{5DEBAA62-F117-496A-8778-FED3604B70E2}" EndProject Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Connectors.Memory.Sqlite", "src\Connectors\Connectors.Memory.Sqlite\Connectors.Memory.Sqlite.csproj", "{EC004F12-2F60-4EDD-B3CD-3A504900D929}" @@ -243,9 +241,6 @@ Global {EB3FC57F-E591-4C88-BCD5-B6A1BC635168}.Publish|Any CPU.Build.0 = Release|Any CPU {EB3FC57F-E591-4C88-BCD5-B6A1BC635168}.Release|Any CPU.ActiveCfg = Release|Any CPU {EB3FC57F-E591-4C88-BCD5-B6A1BC635168}.Release|Any CPU.Build.0 = Release|Any CPU - {CCABF515-2C79-453E-A5A2-69C69B8D172E}.Debug|Any CPU.ActiveCfg = Debug|Any CPU - {CCABF515-2C79-453E-A5A2-69C69B8D172E}.Publish|Any CPU.ActiveCfg = Release|Any CPU - {CCABF515-2C79-453E-A5A2-69C69B8D172E}.Release|Any CPU.ActiveCfg = Release|Any CPU {5DEBAA62-F117-496A-8778-FED3604B70E2}.Debug|Any CPU.ActiveCfg = Debug|Any CPU {5DEBAA62-F117-496A-8778-FED3604B70E2}.Debug|Any CPU.Build.0 = Debug|Any CPU {5DEBAA62-F117-496A-8778-FED3604B70E2}.Publish|Any CPU.ActiveCfg = Publish|Any CPU @@ -403,7 +398,6 @@ Global {BC70A5D8-2125-4C37-8C0E-C903EAFA9772} = {FA3720F1-C99A-49B2-9577-A940257098BF} {0247C2C9-86C3-45BA-8873-28B0948EDC0C} = {831DDCA2-7D2C-4C31-80DB-6BDB3E1F7AE0} {EB3FC57F-E591-4C88-BCD5-B6A1BC635168} = {0247C2C9-86C3-45BA-8873-28B0948EDC0C} - {CCABF515-2C79-453E-A5A2-69C69B8D172E} = {FA3720F1-C99A-49B2-9577-A940257098BF} {5DEBAA62-F117-496A-8778-FED3604B70E2} = {0247C2C9-86C3-45BA-8873-28B0948EDC0C} {EC004F12-2F60-4EDD-B3CD-3A504900D929} = {0247C2C9-86C3-45BA-8873-28B0948EDC0C} {EA61C289-7928-4B78-A9C1-7AAD61F907CD} = {0247C2C9-86C3-45BA-8873-28B0948EDC0C} diff --git a/dotnet/samples/KernelSyntaxExamples/Example39_Postgres.cs b/dotnet/samples/KernelSyntaxExamples/Example39_Postgres.cs index af9a2b7ecffa..b9ac79e1e27c 100644 --- a/dotnet/samples/KernelSyntaxExamples/Example39_Postgres.cs +++ b/dotnet/samples/KernelSyntaxExamples/Example39_Postgres.cs @@ -20,7 +20,7 @@ public static async Task RunAsync() dataSourceBuilder.UseVector(); using NpgsqlDataSource dataSource = dataSourceBuilder.Build(); - PostgresMemoryStore memoryStore = new(dataSource, vectorSize: 1536, schema: "public", numberOfLists: 100); + PostgresMemoryStore memoryStore = new(dataSource, vectorSize: 1536, schema: "public"); IKernel kernel = Kernel.Builder .WithLogger(ConsoleLogger.Log) diff --git a/dotnet/samples/KernelSyntaxExamples/KernelSyntaxExamples.csproj b/dotnet/samples/KernelSyntaxExamples/KernelSyntaxExamples.csproj index 996ea3b9b97d..0660b619b123 100644 --- a/dotnet/samples/KernelSyntaxExamples/KernelSyntaxExamples.csproj +++ b/dotnet/samples/KernelSyntaxExamples/KernelSyntaxExamples.csproj @@ -44,6 +44,10 @@ + + + + diff --git a/dotnet/src/Connectors/Connectors.Memory.Chroma/ChromaClient.cs b/dotnet/src/Connectors/Connectors.Memory.Chroma/ChromaClient.cs index 87fd45448e47..d80cfb9833c9 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Chroma/ChromaClient.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Chroma/ChromaClient.cs @@ -106,11 +106,11 @@ public async IAsyncEnumerable ListCollectionsAsync([EnumeratorCancellati } /// - public async Task AddEmbeddingsAsync(string collectionId, string[] ids, float[][] embeddings, object[]? metadatas = null, CancellationToken cancellationToken = default) + public async Task UpsertEmbeddingsAsync(string collectionId, string[] ids, float[][] embeddings, object[]? metadatas = null, CancellationToken cancellationToken = default) { - this._logger.LogDebug("Adding embeddings to collection with id: {0}", collectionId); + this._logger.LogDebug("Upserting embeddings to collection with id: {0}", collectionId); - using var request = AddEmbeddingsRequest.Create(collectionId, ids, embeddings, metadatas).Build(); + using var request = UpsertEmbeddingsRequest.Create(collectionId, ids, embeddings, metadatas).Build(); await this.ExecuteHttpRequestAsync(request, cancellationToken).ConfigureAwait(false); } diff --git a/dotnet/src/Connectors/Connectors.Memory.Chroma/ChromaMemoryStore.cs b/dotnet/src/Connectors/Connectors.Memory.Chroma/ChromaMemoryStore.cs index 3a7f7d4192dd..08d80d435eea 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Chroma/ChromaMemoryStore.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Chroma/ChromaMemoryStore.cs @@ -212,7 +212,7 @@ public async IAsyncEnumerable UpsertBatchAsync(string collectionName, IE metadatas[i] = recordsArray[i].Metadata; } - await this._chromaClient.AddEmbeddingsAsync(collection.Id, ids, embeddings, metadatas, cancellationToken).ConfigureAwait(false); + await this._chromaClient.UpsertEmbeddingsAsync(collection.Id, ids, embeddings, metadatas, cancellationToken).ConfigureAwait(false); foreach (var record in recordsArray) { diff --git a/dotnet/src/Connectors/Connectors.Memory.Chroma/Http/ApiSchema/Internal/AddEmbeddingsRequest.cs b/dotnet/src/Connectors/Connectors.Memory.Chroma/Http/ApiSchema/Internal/UpsertEmbeddingsRequest.cs similarity index 68% rename from dotnet/src/Connectors/Connectors.Memory.Chroma/Http/ApiSchema/Internal/AddEmbeddingsRequest.cs rename to dotnet/src/Connectors/Connectors.Memory.Chroma/Http/ApiSchema/Internal/UpsertEmbeddingsRequest.cs index 04872516fbc5..bda2ca179c50 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Chroma/Http/ApiSchema/Internal/AddEmbeddingsRequest.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Chroma/Http/ApiSchema/Internal/UpsertEmbeddingsRequest.cs @@ -5,7 +5,7 @@ namespace Microsoft.SemanticKernel.Connectors.Memory.Chroma.Http.ApiSchema.Internal; -internal sealed class AddEmbeddingsRequest +internal sealed class UpsertEmbeddingsRequest { [JsonIgnore] public string CollectionId { get; set; } @@ -19,19 +19,19 @@ internal sealed class AddEmbeddingsRequest [JsonPropertyName("metadatas")] public object[]? Metadatas { get; set; } - public static AddEmbeddingsRequest Create(string collectionId, string[] ids, float[][] embeddings, object[]? metadatas = null) + public static UpsertEmbeddingsRequest Create(string collectionId, string[] ids, float[][] embeddings, object[]? metadatas = null) { - return new AddEmbeddingsRequest(collectionId, ids, embeddings, metadatas); + return new UpsertEmbeddingsRequest(collectionId, ids, embeddings, metadatas); } public HttpRequestMessage Build() { - return HttpRequest.CreatePostRequest($"collections/{this.CollectionId}/add", this); + return HttpRequest.CreatePostRequest($"collections/{this.CollectionId}/upsert", this); } #region private ================================================================================ - private AddEmbeddingsRequest(string collectionId, string[] ids, float[][] embeddings, object[]? metadatas = null) + private UpsertEmbeddingsRequest(string collectionId, string[] ids, float[][] embeddings, object[]? metadatas = null) { this.CollectionId = collectionId; this.Ids = ids; diff --git a/dotnet/src/Connectors/Connectors.Memory.Chroma/IChromaClient.cs b/dotnet/src/Connectors/Connectors.Memory.Chroma/IChromaClient.cs index c1e7718548d1..9260bde408f3 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Chroma/IChromaClient.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Chroma/IChromaClient.cs @@ -42,14 +42,14 @@ public interface IChromaClient IAsyncEnumerable ListCollectionsAsync(CancellationToken cancellationToken = default); /// - /// Adds embedding to specified collection. + /// Upserts embedding to specified collection. /// /// Collection identifier generated by Chroma. /// Array of embedding identifiers. /// Array of embedding vectors. /// Array of embedding metadatas. /// The to monitor for cancellation requests. The default is . - Task AddEmbeddingsAsync(string collectionId, string[] ids, float[][] embeddings, object[]? metadatas = null, CancellationToken cancellationToken = default); + Task UpsertEmbeddingsAsync(string collectionId, string[] ids, float[][] embeddings, object[]? metadatas = null, CancellationToken cancellationToken = default); /// /// Returns embeddings from specified collection. diff --git a/dotnet/src/Connectors/Connectors.Memory.Pinecone/Model/IndexDefinition.cs b/dotnet/src/Connectors/Connectors.Memory.Pinecone/Model/IndexDefinition.cs index fd1716acd49f..37efe6a3979d 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Pinecone/Model/IndexDefinition.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Pinecone/Model/IndexDefinition.cs @@ -3,7 +3,6 @@ using System.Net.Http; using System.Text; using System.Text.Json.Serialization; -using Microsoft.SemanticKernel.Connectors.Memory.Pinecone.Http; namespace Microsoft.SemanticKernel.Connectors.Memory.Pinecone.Model; diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/Connectors.Memory.Postgres.csproj b/dotnet/src/Connectors/Connectors.Memory.Postgres/Connectors.Memory.Postgres.csproj index 66f12ab2a80d..b1972599424e 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Postgres/Connectors.Memory.Postgres.csproj +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/Connectors.Memory.Postgres.csproj @@ -20,6 +20,9 @@ + + + diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/IPostgresDbClient.cs b/dotnet/src/Connectors/Connectors.Memory.Postgres/IPostgresDbClient.cs index 59d69fbb0da5..51064a4e3839 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Postgres/IPostgresDbClient.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/IPostgresDbClient.cs @@ -1,5 +1,6 @@ // Copyright (c) Microsoft. All rights reserved. +using System; using System.Collections.Generic; using System.Threading; using System.Threading.Tasks; @@ -13,76 +14,95 @@ namespace Microsoft.SemanticKernel.Connectors.Memory.Postgres; public interface IPostgresDbClient { /// - /// Check if a collection exists. + /// Check if a table exists. /// - /// The name assigned to a collection of entries. + /// The name assigned to a table of entries. /// The to monitor for cancellation requests. The default is . /// - Task DoesCollectionExistsAsync(string collectionName, CancellationToken cancellationToken = default); + Task DoesTableExistsAsync(string tableName, CancellationToken cancellationToken = default); /// - /// Create a collection. + /// Create a table. /// - /// The name assigned to a collection of entries. + /// The name assigned to a table of entries. /// The to monitor for cancellation requests. The default is . /// - Task CreateCollectionAsync(string collectionName, CancellationToken cancellationToken = default); + Task CreateTableAsync(string tableName, CancellationToken cancellationToken = default); /// - /// Get all collections. + /// Get all tables. /// /// The to monitor for cancellation requests. The default is . - /// - IAsyncEnumerable GetCollectionsAsync(CancellationToken cancellationToken = default); + /// A group of tables. + IAsyncEnumerable GetTablesAsync(CancellationToken cancellationToken = default); /// - /// Delete a collection. + /// Delete a table. /// - /// The name assigned to a collection of entries. + /// The name assigned to a table of entries. /// The to monitor for cancellation requests. The default is . /// - Task DeleteCollectionAsync(string collectionName, CancellationToken cancellationToken = default); + Task DeleteTableAsync(string tableName, CancellationToken cancellationToken = default); /// - /// Upsert entry into a collection. + /// Upsert entry into a table. /// - /// The name assigned to a collection of entries. + /// The name assigned to a table of entries. /// The key of the entry to upsert. /// The metadata of the entry. /// The embedding of the entry. - /// The timestamp of the entry + /// The timestamp of the entry. Its 'DateTimeKind' must be /// The to monitor for cancellation requests. The default is . /// - Task UpsertAsync(string collectionName, string key, string? metadata, Vector? embedding, long? timestamp, CancellationToken cancellationToken = default); + Task UpsertAsync(string tableName, string key, string? metadata, Vector? embedding, DateTime? timestamp, CancellationToken cancellationToken = default); /// /// Gets the nearest matches to the . /// - /// The name assigned to a collection of entries. - /// The to compare the collection's embeddings with. + /// The name assigned to a table of entries. + /// The to compare the table's embeddings with. /// The maximum number of similarity results to return. /// The minimum relevance threshold for returned results. /// If true, the embeddings will be returned in the entries. /// The to monitor for cancellation requests. The default is . - /// - IAsyncEnumerable<(PostgresMemoryEntry, double)> GetNearestMatchesAsync(string collectionName, Vector embeddingFilter, int limit, double minRelevanceScore = 0, bool withEmbeddings = false, CancellationToken cancellationToken = default); + /// An asynchronous stream of objects that the nearest matches to the . + IAsyncEnumerable<(PostgresMemoryEntry, double)> GetNearestMatchesAsync(string tableName, Vector embedding, int limit, double minRelevanceScore = 0, bool withEmbeddings = false, CancellationToken cancellationToken = default); /// /// Read a entry by its key. /// - /// The name assigned to a collection of entries. + /// The name assigned to a table of entries. /// The key of the entry to read. - /// If true, the embeddings will be returned in the entries. + /// If true, the embeddings will be returned in the entry. /// The to monitor for cancellation requests. The default is . /// - Task ReadAsync(string collectionName, string key, bool withEmbeddings = false, CancellationToken cancellationToken = default); + Task ReadAsync(string tableName, string key, bool withEmbeddings = false, CancellationToken cancellationToken = default); + + /// + /// Read multiple entries by their keys. + /// + /// The name assigned to a table of entries. + /// The keys of the entries to read. + /// If true, the embeddings will be returned in the entries. + /// The to monitor for cancellation requests. The default is . + /// An asynchronous stream of objects that match the given keys. + IAsyncEnumerable ReadBatchAsync(string tableName, IEnumerable keys, bool withEmbeddings = false, CancellationToken cancellationToken = default); /// /// Delete a entry by its key. /// - /// The name assigned to a collection of entries. + /// The name assigned to a table of entries. /// The key of the entry to delete. /// The to monitor for cancellation requests. The default is . /// - Task DeleteAsync(string collectionName, string key, CancellationToken cancellationToken = default); + Task DeleteAsync(string tableName, string key, CancellationToken cancellationToken = default); + + /// + /// Delete multiple entries by their key. + /// + /// The name assigned to a table of entries. + /// The keys of the entries to delete. + /// The to monitor for cancellation requests. The default is . + /// + Task DeleteBatchAsync(string tableName, IEnumerable keys, CancellationToken cancellationToken = default); } diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresDbClient.cs b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresDbClient.cs index e957f01f7668..0ce375618414 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresDbClient.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresDbClient.cs @@ -2,10 +2,12 @@ using System; using System.Collections.Generic; +using System.Linq; using System.Runtime.CompilerServices; using System.Threading; using System.Threading.Tasks; using Npgsql; +using NpgsqlTypes; using Pgvector; namespace Microsoft.SemanticKernel.Connectors.Memory.Postgres; @@ -22,220 +24,249 @@ public class PostgresDbClient : IPostgresDbClient /// Postgres data source. /// Schema of collection tables. /// Embedding vector size. - /// Specifies the number of lists for indexing. Higher values can improve recall but may impact performance. The default value is 1000. More info - public PostgresDbClient(NpgsqlDataSource dataSource, string schema, int vectorSize, int numberOfLists) + public PostgresDbClient(NpgsqlDataSource dataSource, string schema, int vectorSize) { this._dataSource = dataSource; this._schema = schema; this._vectorSize = vectorSize; - this._numberOfLists = numberOfLists; } - /// - /// Check if a collection exists. - /// - /// The name assigned to a collection of entries. - /// The to monitor for cancellation requests. The default is . - /// - public async Task DoesCollectionExistsAsync( - string collectionName, - CancellationToken cancellationToken = default) + /// + public async Task DoesTableExistsAsync(string tableName, CancellationToken cancellationToken = default) { - using NpgsqlConnection connection = await this._dataSource.OpenConnectionAsync(cancellationToken).ConfigureAwait(false); - - using NpgsqlCommand cmd = connection.CreateCommand(); - cmd.CommandText = $@" - SELECT table_name - FROM information_schema.tables - WHERE table_schema = @schema - AND table_type = 'BASE TABLE' - AND table_name = '{collectionName}'"; - cmd.Parameters.AddWithValue("@schema", this._schema); - - using var dataReader = await cmd.ExecuteReaderAsync(cancellationToken).ConfigureAwait(false); - if (await dataReader.ReadAsync(cancellationToken).ConfigureAwait(false)) + NpgsqlConnection connection = await this._dataSource.OpenConnectionAsync(cancellationToken).ConfigureAwait(false); + + await using (connection) { - return dataReader.GetString(dataReader.GetOrdinal("table_name")) == collectionName; + using NpgsqlCommand cmd = connection.CreateCommand(); + cmd.CommandText = $@" + SELECT table_name + FROM information_schema.tables + WHERE table_schema = @schema + AND table_type = 'BASE TABLE' + AND table_name = '{tableName}'"; + cmd.Parameters.AddWithValue("@schema", this._schema); + + using NpgsqlDataReader dataReader = await cmd.ExecuteReaderAsync(cancellationToken).ConfigureAwait(false); + if (await dataReader.ReadAsync(cancellationToken).ConfigureAwait(false)) + { + return dataReader.GetString(dataReader.GetOrdinal("table_name")) == tableName; + } + + return false; } - - return false; } - /// - /// Create a collection. - /// - /// The name assigned to a collection of entries. - /// The to monitor for cancellation requests. The default is . - /// - public async Task CreateCollectionAsync(string collectionName, CancellationToken cancellationToken = default) + /// + public async Task CreateTableAsync(string tableName, CancellationToken cancellationToken = default) { - using NpgsqlConnection connection = await this._dataSource.OpenConnectionAsync(cancellationToken).ConfigureAwait(false); - - await this.CreateTableAsync(connection, collectionName, cancellationToken).ConfigureAwait(false); + NpgsqlConnection connection = await this._dataSource.OpenConnectionAsync(cancellationToken).ConfigureAwait(false); - await this.CreateIndexAsync(connection, collectionName, cancellationToken).ConfigureAwait(false); + await using (connection) + { + using NpgsqlCommand cmd = connection.CreateCommand(); + cmd.CommandText = $@" + CREATE TABLE IF NOT EXISTS {this.GetFullTableName(tableName)} ( + key TEXT NOT NULL, + metadata JSONB, + embedding vector({this._vectorSize}), + timestamp TIMESTAMP WITH TIME ZONE, + PRIMARY KEY (key))"; + await cmd.ExecuteNonQueryAsync(cancellationToken).ConfigureAwait(false); + } } - /// - /// Get all collections. - /// - /// The to monitor for cancellation requests. The default is . - /// - public async IAsyncEnumerable GetCollectionsAsync([EnumeratorCancellation] CancellationToken cancellationToken = default) + /// + public async IAsyncEnumerable GetTablesAsync([EnumeratorCancellation] CancellationToken cancellationToken = default) { - using NpgsqlConnection connection = await this._dataSource.OpenConnectionAsync(cancellationToken).ConfigureAwait(false); - - using NpgsqlCommand cmd = connection.CreateCommand(); - cmd.CommandText = @" - SELECT table_name - FROM information_schema.tables - WHERE table_schema = @schema - AND table_type = 'BASE TABLE'"; - cmd.Parameters.AddWithValue("@schema", this._schema); - - using var dataReader = await cmd.ExecuteReaderAsync(cancellationToken).ConfigureAwait(false); - while (await dataReader.ReadAsync(cancellationToken).ConfigureAwait(false)) + NpgsqlConnection connection = await this._dataSource.OpenConnectionAsync(cancellationToken).ConfigureAwait(false); + + await using (connection) { - yield return dataReader.GetString(dataReader.GetOrdinal("table_name")); + using NpgsqlCommand cmd = connection.CreateCommand(); + cmd.CommandText = @" + SELECT table_name + FROM information_schema.tables + WHERE table_schema = @schema + AND table_type = 'BASE TABLE'"; + cmd.Parameters.AddWithValue("@schema", this._schema); + + using NpgsqlDataReader dataReader = await cmd.ExecuteReaderAsync(cancellationToken).ConfigureAwait(false); + while (await dataReader.ReadAsync(cancellationToken).ConfigureAwait(false)) + { + yield return dataReader.GetString(dataReader.GetOrdinal("table_name")); + } } } - /// - /// Delete a collection. - /// - /// The name assigned to a collection of entries. - /// The to monitor for cancellation requests. The default is . - /// - public async Task DeleteCollectionAsync(string collectionName, CancellationToken cancellationToken = default) + /// + public async Task DeleteTableAsync(string tableName, CancellationToken cancellationToken = default) { - using NpgsqlConnection connection = await this._dataSource.OpenConnectionAsync(cancellationToken).ConfigureAwait(false); + NpgsqlConnection connection = await this._dataSource.OpenConnectionAsync(cancellationToken).ConfigureAwait(false); - using NpgsqlCommand cmd = connection.CreateCommand(); - cmd.CommandText = $"DROP TABLE IF EXISTS {this.GetTableName(collectionName)}"; + await using (connection) + { + using NpgsqlCommand cmd = connection.CreateCommand(); + cmd.CommandText = $"DROP TABLE IF EXISTS {this.GetFullTableName(tableName)}"; - await cmd.ExecuteNonQueryAsync(cancellationToken).ConfigureAwait(false); + await cmd.ExecuteNonQueryAsync(cancellationToken).ConfigureAwait(false); + } } - /// - /// Upsert entry into a collection. - /// - /// The name assigned to a collection of entries. - /// The key of the entry to upsert. - /// The metadata of the entry. - /// The embedding of the entry. - /// The timestamp of the entry - /// The to monitor for cancellation requests. The default is . - /// - public async Task UpsertAsync(string collectionName, string key, - string? metadata, Vector? embedding, long? timestamp, CancellationToken cancellationToken = default) + /// + public async Task UpsertAsync(string tableName, string key, + string? metadata, Vector? embedding, DateTime? timestamp, CancellationToken cancellationToken = default) { - using NpgsqlConnection connection = await this._dataSource.OpenConnectionAsync(cancellationToken).ConfigureAwait(false); - - using NpgsqlCommand cmd = connection.CreateCommand(); - cmd.CommandText = $@" - INSERT INTO {this.GetTableName(collectionName)} (key, metadata, embedding, timestamp) - VALUES(@key, @metadata, @embedding, @timestamp) - ON CONFLICT (key) - DO UPDATE SET metadata=@metadata, embedding=@embedding, timestamp=@timestamp"; - cmd.Parameters.AddWithValue("@key", key); - cmd.Parameters.AddWithValue("@metadata", NpgsqlTypes.NpgsqlDbType.Jsonb, metadata ?? (object)DBNull.Value); - cmd.Parameters.AddWithValue("@embedding", embedding ?? (object)DBNull.Value); - cmd.Parameters.AddWithValue("@timestamp", timestamp ?? (object)DBNull.Value); - - await cmd.ExecuteNonQueryAsync(cancellationToken).ConfigureAwait(false); + NpgsqlConnection connection = await this._dataSource.OpenConnectionAsync(cancellationToken).ConfigureAwait(false); + + await using (connection) + { + using NpgsqlCommand cmd = connection.CreateCommand(); + cmd.CommandText = $@" + INSERT INTO {this.GetFullTableName(tableName)} (key, metadata, embedding, timestamp) + VALUES(@key, @metadata, @embedding, @timestamp) + ON CONFLICT (key) + DO UPDATE SET metadata=@metadata, embedding=@embedding, timestamp=@timestamp"; + cmd.Parameters.AddWithValue("@key", key); + cmd.Parameters.AddWithValue("@metadata", NpgsqlTypes.NpgsqlDbType.Jsonb, metadata ?? (object)DBNull.Value); + cmd.Parameters.AddWithValue("@embedding", embedding ?? (object)DBNull.Value); + cmd.Parameters.AddWithValue("@timestamp", NpgsqlTypes.NpgsqlDbType.TimestampTz, timestamp ?? (object)DBNull.Value); + + await cmd.ExecuteNonQueryAsync(cancellationToken).ConfigureAwait(false); + } } - /// - /// Gets the nearest matches to the . - /// - /// The name assigned to a collection of entries. - /// The to compare the collection's embeddings with. - /// The maximum number of similarity results to return. - /// The minimum relevance threshold for returned results. - /// If true, the embeddings will be returned in the entries. - /// The to monitor for cancellation requests. The default is . - /// + /// public async IAsyncEnumerable<(PostgresMemoryEntry, double)> GetNearestMatchesAsync( - string collectionName, Vector embeddingFilter, int limit, double minRelevanceScore = 0, bool withEmbeddings = false, + string tableName, Vector embedding, int limit, double minRelevanceScore = 0, bool withEmbeddings = false, [EnumeratorCancellation] CancellationToken cancellationToken = default) { - var queryColumns = "key, metadata, timestamp"; + string queryColumns = "key, metadata, timestamp"; if (withEmbeddings) { queryColumns = "*"; } - using NpgsqlConnection connection = await this._dataSource.OpenConnectionAsync(cancellationToken).ConfigureAwait(false); + NpgsqlConnection connection = await this._dataSource.OpenConnectionAsync(cancellationToken).ConfigureAwait(false); - using NpgsqlCommand cmd = connection.CreateCommand(); - cmd.CommandText = @$" - SELECT * FROM (SELECT {queryColumns}, 1 - (embedding <=> @embedding) AS cosine_similarity FROM {this.GetTableName(collectionName)} - ) AS sk_memory_cosine_similarity_table - WHERE cosine_similarity >= @min_relevance_score - ORDER BY cosine_similarity DESC - Limit @limit"; - cmd.Parameters.AddWithValue("@embedding", embeddingFilter); - cmd.Parameters.AddWithValue("@collection", collectionName); - cmd.Parameters.AddWithValue("@min_relevance_score", minRelevanceScore); - cmd.Parameters.AddWithValue("@limit", limit); - - using var dataReader = await cmd.ExecuteReaderAsync(cancellationToken).ConfigureAwait(false); - - while (await dataReader.ReadAsync(cancellationToken).ConfigureAwait(false)) + await using (connection) { - double cosineSimilarity = dataReader.GetDouble(dataReader.GetOrdinal("cosine_similarity")); - yield return (await this.ReadEntryAsync(dataReader, withEmbeddings, cancellationToken).ConfigureAwait(false), cosineSimilarity); + using NpgsqlCommand cmd = connection.CreateCommand(); + cmd.CommandText = @$" + SELECT * FROM (SELECT {queryColumns}, 1 - (embedding <=> @embedding) AS cosine_similarity FROM {this.GetFullTableName(tableName)} + ) AS sk_memory_cosine_similarity_table + WHERE cosine_similarity >= @min_relevance_score + ORDER BY cosine_similarity DESC + Limit @limit"; + cmd.Parameters.AddWithValue("@embedding", embedding); + cmd.Parameters.AddWithValue("@collection", tableName); + cmd.Parameters.AddWithValue("@min_relevance_score", minRelevanceScore); + cmd.Parameters.AddWithValue("@limit", limit); + + using NpgsqlDataReader dataReader = await cmd.ExecuteReaderAsync(cancellationToken).ConfigureAwait(false); + + while (await dataReader.ReadAsync(cancellationToken).ConfigureAwait(false)) + { + double cosineSimilarity = dataReader.GetDouble(dataReader.GetOrdinal("cosine_similarity")); + yield return (await this.ReadEntryAsync(dataReader, withEmbeddings, cancellationToken).ConfigureAwait(false), cosineSimilarity); + } } } - /// - /// Read a entry by its key. - /// - /// The name assigned to a collection of entries. - /// The key of the entry to read. - /// If true, the embeddings will be returned in the entries. - /// The to monitor for cancellation requests. The default is . - /// - public async Task ReadAsync(string collectionName, string key, + /// + public async Task ReadAsync(string tableName, string key, bool withEmbeddings = false, CancellationToken cancellationToken = default) { - var queryColumns = "key, metadata, timestamp"; + string queryColumns = "key, metadata, timestamp"; if (withEmbeddings) { queryColumns = "*"; } - using NpgsqlConnection connection = await this._dataSource.OpenConnectionAsync(cancellationToken).ConfigureAwait(false); + NpgsqlConnection connection = await this._dataSource.OpenConnectionAsync(cancellationToken).ConfigureAwait(false); + + await using (connection) + { + using NpgsqlCommand cmd = connection.CreateCommand(); + cmd.CommandText = $"SELECT {queryColumns} FROM {this.GetFullTableName(tableName)} WHERE key=@key"; + cmd.Parameters.AddWithValue("@key", key); + + using NpgsqlDataReader dataReader = await cmd.ExecuteReaderAsync(cancellationToken).ConfigureAwait(false); + if (await dataReader.ReadAsync(cancellationToken).ConfigureAwait(false)) + { + return await this.ReadEntryAsync(dataReader, withEmbeddings, cancellationToken).ConfigureAwait(false); + } + + return null; + } + } + + /// + public async IAsyncEnumerable ReadBatchAsync(string tableName, IEnumerable keys, bool withEmbeddings = false, + [EnumeratorCancellation] CancellationToken cancellationToken = default) + { + string[] keysArray = keys.ToArray(); + if (keysArray.Length == 0) + { + yield break; + } + + string queryColumns = "key, metadata, timestamp"; + if (withEmbeddings) + { + queryColumns = "*"; + } - using NpgsqlCommand cmd = connection.CreateCommand(); - cmd.CommandText = $"SELECT {queryColumns} FROM {this.GetTableName(collectionName)} WHERE key=@key"; - cmd.Parameters.AddWithValue("@key", key); + NpgsqlConnection connection = await this._dataSource.OpenConnectionAsync(cancellationToken).ConfigureAwait(false); - using var dataReader = await cmd.ExecuteReaderAsync(cancellationToken).ConfigureAwait(false); - if (await dataReader.ReadAsync(cancellationToken).ConfigureAwait(false)) + await using (connection) { - return await this.ReadEntryAsync(dataReader, withEmbeddings, cancellationToken).ConfigureAwait(false); + using NpgsqlCommand cmd = connection.CreateCommand(); + cmd.CommandText = $"SELECT {queryColumns} FROM {this.GetFullTableName(tableName)} WHERE key=ANY(@keys)"; + cmd.Parameters.AddWithValue("@keys", NpgsqlDbType.Array | NpgsqlDbType.Text, keysArray); + + using NpgsqlDataReader dataReader = await cmd.ExecuteReaderAsync(cancellationToken).ConfigureAwait(false); + while (await dataReader.ReadAsync(cancellationToken).ConfigureAwait(false)) + { + yield return await this.ReadEntryAsync(dataReader, withEmbeddings, cancellationToken).ConfigureAwait(false); + } } + } - return null; + /// + public async Task DeleteAsync(string tableName, string key, CancellationToken cancellationToken = default) + { + NpgsqlConnection connection = await this._dataSource.OpenConnectionAsync(cancellationToken).ConfigureAwait(false); + + await using (connection) + { + using NpgsqlCommand cmd = connection.CreateCommand(); + cmd.CommandText = $"DELETE FROM {this.GetFullTableName(tableName)} WHERE key=@key"; + cmd.Parameters.AddWithValue("@key", key); + + await cmd.ExecuteNonQueryAsync(cancellationToken).ConfigureAwait(false); + } } - /// - /// Delete a entry by its key. - /// - /// The name assigned to a collection of entries. - /// The key of the entry to delete. - /// The to monitor for cancellation requests. The default is . - /// - public async Task DeleteAsync(string collectionName, string key, CancellationToken cancellationToken = default) + /// + public async Task DeleteBatchAsync(string tableName, IEnumerable keys, CancellationToken cancellationToken = default) { - using NpgsqlConnection connection = await this._dataSource.OpenConnectionAsync(cancellationToken).ConfigureAwait(false); + string[] keysArray = keys.ToArray(); + if (keysArray.Length == 0) + { + return; + } - using NpgsqlCommand cmd = connection.CreateCommand(); - cmd.CommandText = $"DELETE FROM {this.GetTableName(collectionName)} WHERE key=@key"; - cmd.Parameters.AddWithValue("@key", key); + NpgsqlConnection connection = await this._dataSource.OpenConnectionAsync(cancellationToken).ConfigureAwait(false); - await cmd.ExecuteNonQueryAsync(cancellationToken).ConfigureAwait(false); + await using (connection) + { + using NpgsqlCommand cmd = connection.CreateCommand(); + cmd.CommandText = $"DELETE FROM {this.GetFullTableName(tableName)} WHERE key=ANY(@keys)"; + cmd.Parameters.AddWithValue("@keys", NpgsqlDbType.Array | NpgsqlDbType.Text, keysArray); + + await cmd.ExecuteNonQueryAsync(cancellationToken).ConfigureAwait(false); + } } #region private ================================================================================ @@ -243,7 +274,6 @@ public async Task DeleteAsync(string collectionName, string key, CancellationTok private readonly NpgsqlDataSource _dataSource; private readonly int _vectorSize; private readonly string _schema; - private readonly int _numberOfLists; /// /// Read a entry. @@ -257,54 +287,18 @@ private async Task ReadEntryAsync(NpgsqlDataReader dataRead string key = dataReader.GetString(dataReader.GetOrdinal("key")); string metadata = dataReader.GetString(dataReader.GetOrdinal("metadata")); Vector? embedding = withEmbeddings ? await dataReader.GetFieldValueAsync(dataReader.GetOrdinal("embedding"), cancellationToken).ConfigureAwait(false) : null; - long? timestamp = await dataReader.GetFieldValueAsync(dataReader.GetOrdinal("timestamp"), cancellationToken).ConfigureAwait(false); + DateTime? timestamp = await dataReader.GetFieldValueAsync(dataReader.GetOrdinal("timestamp"), cancellationToken).ConfigureAwait(false); return new PostgresMemoryEntry() { Key = key, MetadataString = metadata, Embedding = embedding, Timestamp = timestamp }; } /// - /// Create a collection as table. - /// - /// An opened instance. - /// The name assigned to a collection of entries. - /// The to monitor for cancellation requests. The default is . - /// - private async Task CreateTableAsync(NpgsqlConnection connection, string collectionName, CancellationToken cancellationToken = default) - { - using NpgsqlCommand cmd = connection.CreateCommand(); - cmd.CommandText = $@" - CREATE TABLE IF NOT EXISTS {this.GetTableName(collectionName)} ( - key TEXT NOT NULL, - metadata JSONB, - embedding vector({this._vectorSize}), - timestamp BIGINT, - PRIMARY KEY (key))"; - await cmd.ExecuteNonQueryAsync(cancellationToken).ConfigureAwait(false); - } - - /// - /// Create index for collection table. - /// - /// An opened instance. - /// The name assigned to a collection of entries. - /// The to monitor for cancellation requests. The default is . - /// - private async Task CreateIndexAsync(NpgsqlConnection connection, string collectionName, CancellationToken cancellationToken = default) - { - using NpgsqlCommand cmd = connection.CreateCommand(); - cmd.CommandText = $@" - CREATE INDEX IF NOT EXISTS {collectionName}_ix - ON {this.GetTableName(collectionName)} USING ivfflat (embedding vector_cosine_ops) WITH (lists = {this._numberOfLists})"; - await cmd.ExecuteNonQueryAsync(cancellationToken).ConfigureAwait(false); - } - - /// - /// Get table name from collection name. + /// Get full table name with schema from table name. /// - /// + /// /// - private string GetTableName(string collectionName) + private string GetFullTableName(string tableName) { - return $"{this._schema}.\"{collectionName}\""; + return $"{this._schema}.\"{tableName}\""; } #endregion } diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresKernelBuilderExtensions.cs b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresKernelBuilderExtensions.cs index 440dfcd9dcf8..4baa6728c2df 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresKernelBuilderExtensions.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresKernelBuilderExtensions.cs @@ -18,18 +18,16 @@ public static class PostgresKernelBuilderExtensions /// The instance /// Postgres data source. /// Embedding vector size. - /// Schema of collection tables. - /// Specifies the number of lists for indexing. Higher values can improve recall but may impact performance. The default value is 1000. More info + /// Schema of collection tables. The default value is "public". /// Self instance public static KernelBuilder WithPostgresMemoryStore(this KernelBuilder builder, NpgsqlDataSource dataSource, int vectorSize, - string schema = PostgresMemoryStore.DefaultSchema, - int numberOfLists = PostgresMemoryStore.DefaultNumberOfLists) + string schema = PostgresMemoryStore.DefaultSchema) { builder.WithMemoryStorage((parameters) => { - return new PostgresMemoryStore(dataSource, vectorSize, schema, numberOfLists); + return new PostgresMemoryStore(dataSource, vectorSize, schema); }); return builder; diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresMemoryEntry.cs b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresMemoryEntry.cs index 29a4b8f31f90..a7429b44c157 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresMemoryEntry.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresMemoryEntry.cs @@ -1,5 +1,6 @@ // Copyright (c) Microsoft. All rights reserved. +using System; using Pgvector; namespace Microsoft.SemanticKernel.Connectors.Memory.Postgres; @@ -25,7 +26,7 @@ public record struct PostgresMemoryEntry public Vector? Embedding { get; set; } /// - /// Optional timestamp. + /// Optional timestamp. Its 'DateTimeKind' is /// - public long? Timestamp { get; set; } + public DateTime? Timestamp { get; set; } } diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresMemoryStore.cs b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresMemoryStore.cs index 269624651ec5..b95ac91f8fb1 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresMemoryStore.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresMemoryStore.cs @@ -1,6 +1,5 @@ // Copyright (c) Microsoft. All rights reserved. -using System; using System.Collections.Generic; using System.Linq; using System.Runtime.CompilerServices; @@ -23,7 +22,6 @@ namespace Microsoft.SemanticKernel.Connectors.Memory.Postgres; public class PostgresMemoryStore : IMemoryStore { internal const string DefaultSchema = "public"; - internal const int DefaultNumberOfLists = 1000; /// /// Initializes a new instance of the class. @@ -31,9 +29,8 @@ public class PostgresMemoryStore : IMemoryStore /// Postgres data source. /// Embedding vector size. /// Database schema of collection tables. The default value is "public". - /// Specifies the number of lists for indexing. Higher values can improve recall but may impact performance. The default value is 1000. More info - public PostgresMemoryStore(NpgsqlDataSource dataSource, int vectorSize, string schema = DefaultSchema, int numberOfLists = DefaultNumberOfLists) - : this(new PostgresDbClient(dataSource, schema, vectorSize, numberOfLists)) + public PostgresMemoryStore(NpgsqlDataSource dataSource, int vectorSize, string schema = DefaultSchema) + : this(new PostgresDbClient(dataSource, schema, vectorSize)) { } @@ -47,7 +44,7 @@ public async Task CreateCollectionAsync(string collectionName, CancellationToken { Verify.NotNullOrWhiteSpace(collectionName); - await this._postgresDbClient.CreateCollectionAsync(collectionName, cancellationToken).ConfigureAwait(false); + await this._postgresDbClient.CreateTableAsync(collectionName, cancellationToken).ConfigureAwait(false); } /// @@ -55,13 +52,13 @@ public async Task DoesCollectionExistAsync(string collectionName, Cancella { Verify.NotNullOrWhiteSpace(collectionName); - return await this._postgresDbClient.DoesCollectionExistsAsync(collectionName, cancellationToken).ConfigureAwait(false); + return await this._postgresDbClient.DoesTableExistsAsync(collectionName, cancellationToken).ConfigureAwait(false); } /// public async IAsyncEnumerable GetCollectionsAsync([EnumeratorCancellation] CancellationToken cancellationToken = default) { - await foreach (var collection in this._postgresDbClient.GetCollectionsAsync(cancellationToken).ConfigureAwait(false)) + await foreach (string collection in this._postgresDbClient.GetTablesAsync(cancellationToken).ConfigureAwait(false)) { yield return collection; } @@ -72,7 +69,7 @@ public async Task DeleteCollectionAsync(string collectionName, CancellationToken { Verify.NotNullOrWhiteSpace(collectionName); - await this._postgresDbClient.DeleteCollectionAsync(collectionName, cancellationToken).ConfigureAwait(false); + await this._postgresDbClient.DeleteTableAsync(collectionName, cancellationToken).ConfigureAwait(false); } /// @@ -89,7 +86,7 @@ public async IAsyncEnumerable UpsertBatchAsync(string collectionName, IE { Verify.NotNullOrWhiteSpace(collectionName); - foreach (var record in records) + foreach (MemoryRecord record in records) { yield return await this.InternalUpsertAsync(collectionName, record, cancellationToken).ConfigureAwait(false); } @@ -100,7 +97,11 @@ public async IAsyncEnumerable UpsertBatchAsync(string collectionName, IE { Verify.NotNullOrWhiteSpace(collectionName); - return await this.InternalGetAsync(collectionName, key, withEmbedding, cancellationToken).ConfigureAwait(false); + PostgresMemoryEntry? entry = await this._postgresDbClient.ReadAsync(collectionName, key, withEmbedding, cancellationToken).ConfigureAwait(false); + + if (!entry.HasValue) { return null; } + + return this.GetMemoryRecordFromEntry(entry.Value); } /// @@ -109,13 +110,9 @@ public async IAsyncEnumerable GetBatchAsync(string collectionName, { Verify.NotNullOrWhiteSpace(collectionName); - foreach (var key in keys) + await foreach (PostgresMemoryEntry entry in this._postgresDbClient.ReadBatchAsync(collectionName, keys, withEmbeddings, cancellationToken).ConfigureAwait(false)) { - var result = await this.InternalGetAsync(collectionName, key, withEmbeddings, cancellationToken).ConfigureAwait(false); - if (result != null) - { - yield return result; - } + yield return this.GetMemoryRecordFromEntry(entry); } } @@ -132,10 +129,7 @@ public async Task RemoveBatchAsync(string collectionName, IEnumerable ke { Verify.NotNullOrWhiteSpace(collectionName); - foreach (var key in keys) - { - await this._postgresDbClient.DeleteAsync(collectionName, key, cancellationToken).ConfigureAwait(false); - } + await this._postgresDbClient.DeleteBatchAsync(collectionName, keys, cancellationToken).ConfigureAwait(false); } /// @@ -155,21 +149,16 @@ public async Task RemoveBatchAsync(string collectionName, IEnumerable ke } IAsyncEnumerable<(PostgresMemoryEntry, double)> results = this._postgresDbClient.GetNearestMatchesAsync( - collectionName: collectionName, - embeddingFilter: new Vector(embedding.Vector.ToArray()), + tableName: collectionName, + embedding: new Vector(embedding.Vector.ToArray()), limit: limit, minRelevanceScore: minRelevanceScore, withEmbeddings: withEmbeddings, cancellationToken: cancellationToken); - await foreach (var (entry, cosineSimilarity) in results.ConfigureAwait(false)) + await foreach ((PostgresMemoryEntry entry, double cosineSimilarity) in results.ConfigureAwait(false)) { - MemoryRecord record = MemoryRecord.FromJsonMetadata( - json: entry.MetadataString, - this.GetEmbeddingForEntry(entry), - entry.Key, - ParseTimestamp(entry.Timestamp)); - yield return (record, cosineSimilarity); + yield return (this.GetMemoryRecordFromEntry(entry), cosineSimilarity); } } @@ -190,52 +179,28 @@ public async Task RemoveBatchAsync(string collectionName, IEnumerable ke private readonly IPostgresDbClient _postgresDbClient; - private static long? ToTimestampLong(DateTimeOffset? timestamp) - { - return timestamp?.ToUnixTimeMilliseconds(); - } - - private static DateTimeOffset? ParseTimestamp(long? timestamp) - { - if (timestamp.HasValue) - { - return DateTimeOffset.FromUnixTimeMilliseconds(timestamp.Value); - } - - return null; - } - private async Task InternalUpsertAsync(string collectionName, MemoryRecord record, CancellationToken cancellationToken) { record.Key = record.Metadata.Id; await this._postgresDbClient.UpsertAsync( - collectionName: collectionName, + tableName: collectionName, key: record.Key, metadata: record.GetSerializedMetadata(), embedding: new Vector(record.Embedding.Vector.ToArray()), - timestamp: ToTimestampLong(record.Timestamp), + timestamp: record.Timestamp?.UtcDateTime, cancellationToken: cancellationToken).ConfigureAwait(false); return record.Key; } - private async Task InternalGetAsync(string collectionName, string key, bool withEmbedding, CancellationToken cancellationToken) + private MemoryRecord GetMemoryRecordFromEntry(PostgresMemoryEntry entry) { - PostgresMemoryEntry? entry = await this._postgresDbClient.ReadAsync(collectionName, key, withEmbedding, cancellationToken).ConfigureAwait(false); - - if (!entry.HasValue) { return null; } - return MemoryRecord.FromJsonMetadata( - json: entry.Value.MetadataString, - embedding: this.GetEmbeddingForEntry(entry.Value), - entry.Value.Key, - ParseTimestamp(entry.Value.Timestamp)); - } - - private Embedding GetEmbeddingForEntry(PostgresMemoryEntry entry) - { - return entry.Embedding != null ? new Embedding(entry.Embedding!.ToArray()) : Embedding.Empty; + json: entry.MetadataString, + embedding: entry.Embedding != null ? new Embedding(entry.Embedding!.ToArray()) : Embedding.Empty, + key: entry.Key, + timestamp: entry.Timestamp?.ToLocalTime()); } #endregion diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/README.md b/dotnet/src/Connectors/Connectors.Memory.Postgres/README.md index 4ba3424224ae..0dd94e053643 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Postgres/README.md +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/README.md @@ -35,20 +35,57 @@ sk_demo=# CREATE EXTENSION vector; 3. To use Postgres as a semantic memory store: ```csharp -NpgsqlDataSourceBuilder dataSourceBuilder = new NpgsqlDataSourceBuilder("Host=localhost;Port=5432;Database=sk_memory;User Id=postgres;Password=mysecretpassword"); +NpgsqlDataSourceBuilder dataSourceBuilder = new NpgsqlDataSourceBuilder("Host=localhost;Port=5432;Database=sk_demo;User Id=postgres;Password=mysecretpassword"); dataSourceBuilder.UseVector(); NpgsqlDataSource dataSource = dataSourceBuilder.Build(); -PostgresMemoryStore memoryStore = new PostgresMemoryStore(dataSource, vectorSize: 1536/*, schema: "public", numberOfLists: 1000 */); +PostgresMemoryStore memoryStore = new PostgresMemoryStore(dataSource, vectorSize: 1536/*, schema: "public" */); IKernel kernel = Kernel.Builder .WithLogger(ConsoleLogger.Log) .WithOpenAITextEmbeddingGenerationService("text-embedding-ada-002", Env.Var("OPENAI_API_KEY")) .WithMemoryStorage(memoryStore) - //.WithPostgresMemoryStore(dataSource, vectorSize: 1536, schema: "public", numberOfLists: 1000) // This method offers an alternative approach to registering Postgres memory store. + //.WithPostgresMemoryStore(dataSource, vectorSize: 1536, schema: "public") // This method offers an alternative approach to registering Postgres memory store. .Build(); ``` +### Create Index + +> By default, pgvector performs exact nearest neighbor search, which provides perfect recall. + +> You can add an index to use approximate nearest neighbor search, which trades some recall for performance. Unlike typical indexes, you will see different results for queries after adding an approximate index. + +> Three keys to achieving good recall are: +> - Create the index after the table has some data +> - Choose an appropriate number of lists - a good place to start is rows / 1000 for up to 1M rows and sqrt(rows) for over 1M rows +> - When querying, specify an appropriate number of probes (higher is better for recall, lower is better for speed) - a good place to start is sqrt(lists) + +Please read [the documentation](https://github.com/pgvector/pgvector#indexing) for more information. + +Based on the data rows of your collection table, consider the following statement to create an index. + +```sql +DO $$ +DECLARE + collection TEXT; + c_count INTEGER; +BEGIN + SELECT 'REPLACE YOUR COLLECTION TABLE NAME' INTO collection; + + -- Get count of records in collection + EXECUTE format('SELECT count(*) FROM public.%I;', collection) INTO c_count; + + -- Create Index (https://github.com/pgvector/pgvector#indexing) + IF c_count > 10000000 THEN + EXECUTE format('CREATE INDEX %I ON public.%I USING ivfflat (embedding vector_cosine_ops) WITH (lists = %s);', + collection || '_ix', collection, ROUND(sqrt(c_count))); + ELSIF c_count > 10000 THEN + EXECUTE format('CREATE INDEX %I ON public.%I USING ivfflat (embedding vector_cosine_ops) WITH (lists = %s);', + collection || '_ix', collection, c_count / 1000); + END IF; +END $$; +``` + ## Migration from older versions Since Postgres Memory connector has been re-implemented, the new implementation uses a separate table to store each Collection. @@ -64,6 +101,7 @@ We provide the following migration script to help you migrate to the new structu DO $$ DECLARE r record; + c_count integer; BEGIN FOR r IN SELECT DISTINCT collection FROM sk_memory_table LOOP @@ -75,14 +113,23 @@ BEGIN key TEXT NOT NULL, metadata JSONB, embedding vector(1536), - timestamp BIGINT, + timestamp TIMESTAMP WITH TIME ZONE, PRIMARY KEY (key) );', r.collection); - - -- Create Index (You can modify the size of lists according to your data needs. Its default value is 1000.) - EXECUTE format('CREATE INDEX %I - ON public.%I USING ivfflat (embedding vector_cosine_ops) WITH (lists = 1000);', - r.collection || '_ix', r.collection); + + -- Get count of records in collection + SELECT count(*) INTO c_count FROM sk_memory_table WHERE collection = r.collection AND key <> ''; + + -- Create Index (https://github.com/pgvector/pgvector#indexing) + IF c_count > 10000000 THEN + EXECUTE format('CREATE INDEX %I + ON public.%I USING ivfflat (embedding vector_cosine_ops) WITH (lists = %s);', + r.collection || '_ix', r.collection, ROUND(sqrt(c_count))); + ELSIF c_count > 10000 THEN + EXECUTE format('CREATE INDEX %I + ON public.%I USING ivfflat (embedding vector_cosine_ops) WITH (lists = %s);', + r.collection || '_ix', r.collection, c_count / 1000); + END IF; END LOOP; END $$; @@ -93,7 +140,7 @@ DECLARE BEGIN FOR r IN SELECT DISTINCT collection FROM sk_memory_table LOOP EXECUTE format('INSERT INTO public.%I (key, metadata, embedding, timestamp) - SELECT key, metadata::JSONB, embedding, timestamp + SELECT key, metadata::JSONB, embedding, to_timestamp(timestamp / 1000.0) AT TIME ZONE ''UTC'' FROM sk_memory_table WHERE collection = %L AND key <> '''';', r.collection, r.collection); END LOOP; END $$; diff --git a/dotnet/src/Connectors/Connectors.Memory.Qdrant/Http/ApiSchema/QdrantFilter.cs b/dotnet/src/Connectors/Connectors.Memory.Qdrant/Http/ApiSchema/QdrantFilter.cs new file mode 100644 index 000000000000..d9bccce15807 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Memory.Qdrant/Http/ApiSchema/QdrantFilter.cs @@ -0,0 +1,225 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.Collections.Generic; +using System.Text.Json.Serialization; +using Microsoft.SemanticKernel.Connectors.Memory.Qdrant.Diagnostics; + +namespace Microsoft.SemanticKernel.Connectors.Memory.Qdrant.Http.ApiSchema; + +public sealed class QdrantFilter : IValidatable +{ + [JsonPropertyName("must")] + public List Conditions { get; set; } = new(); + + public void Validate() + { + Verify.NotNull(this.Conditions, "Conditions is NULL"); + foreach (var condition in this.Conditions) + { + if (condition is IValidatable validatable) + { + validatable.Validate(); + } + } + } + + public QdrantFilter Must(params Condition[] conditions) + { + this.Conditions.AddRange(conditions); + return this; + } + + internal QdrantFilter ValueMustMatch(string key, object value) + { + this.Conditions.Add(new MatchCondition + { + Key = key, + Match = new Match { Value = value } + }); + + return this; + } + + internal QdrantFilter CoordinatesWithinRadius(string key, GeoRadius radius) + { + this.Conditions.Add(new GeoRadiusCondition + { + Key = key, + GeoRadius = radius + }); + + return this; + } + + [JsonDerivedType(typeof(MatchCondition))] + [JsonDerivedType(typeof(RangeCondition))] + [JsonDerivedType(typeof(GeoBoundingBoxCondition))] + [JsonDerivedType(typeof(GeoRadiusCondition))] + public abstract class Condition + { + [JsonPropertyName("key")] + public string Key { get; set; } = string.Empty; + } + + public sealed class MatchCondition : Condition, IValidatable + { + [JsonPropertyName("match")] + public Match? Match { get; set; } + + public void Validate() + { + Verify.NotNullOrEmpty(this.Key, "Match key is NULL"); + Verify.NotNull(this.Match, "Match condition is NULL"); + this.Match!.Validate(); + } + } + + public sealed class RangeCondition : Condition, IValidatable + { + [JsonPropertyName("range")] + public Range? Range { get; set; } + + public void Validate() + { + Verify.NotNullOrEmpty(this.Key, "Match key is NULL"); + Verify.NotNull(this.Range, "Range condition is NULL"); + this.Range!.Validate(); + } + } + + public sealed class GeoBoundingBoxCondition : Condition, IValidatable + { + [JsonPropertyName("geo_bounding_box")] + public GeoBoundingBox? GeoBoundingBox { get; set; } + + public void Validate() + { + Verify.NotNullOrEmpty(this.Key, "Match key is NULL"); + Verify.NotNull(this.GeoBoundingBox, "Geo bounding box condition is NULL"); + this.GeoBoundingBox!.Validate(); + } + } + + public sealed class GeoRadiusCondition : Condition, IValidatable + { + [JsonPropertyName("geo_radius")] + public GeoRadius? GeoRadius { get; set; } + + public void Validate() + { + Verify.NotNullOrEmpty(this.Key, "Match key is NULL"); + Verify.NotNull(this.GeoRadius, "Geo radius condition is NULL"); + this.GeoRadius!.Validate(); + } + } + + public sealed class Range : IValidatable + { + [JsonPropertyName("gt")] + public float? GreaterThan { get; set; } + + [JsonPropertyName("gte")] + public float? GreaterThanOrEqual { get; set; } + + [JsonPropertyName("lt")] + public float? LowerThan { get; set; } + + [JsonPropertyName("lte")] + public float? LowerThanOrEqual { get; set; } + + public void Validate() + { + Verify.True( + this.GreaterThan.HasValue || this.GreaterThanOrEqual.HasValue || this.LowerThan.HasValue || this.LowerThanOrEqual.HasValue, + "No range conditions are specified"); + } + } + + public class Match : IValidatable + { + [JsonPropertyName("value")] + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] + public object? Value { get; set; } + + [JsonPropertyName("text")] + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] + public object? Text { get; set; } + + [JsonPropertyName("any")] + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] + public List? Any { get; set; } + + [JsonPropertyName("except")] + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] + public List? Except { get; set; } + + public void Validate() + { + Verify.True( + this.Value != null || this.Text != null || this.Any != null || this.Except != null, + "No match conditions are specified"); + } + } + + public class GeoRadius : IValidatable + { + public GeoRadius(Coordinates center, float radius) + { + this.Center = center; + this.Radius = radius; + } + + [JsonPropertyName("center")] + public Coordinates Center { get; set; } + + [JsonPropertyName("radius")] + public float Radius { get; set; } + + public void Validate() + { + Verify.NotNull(this.Center, "Geo radius center is NULL"); + } + } + + public class GeoBoundingBox : IValidatable + { + public GeoBoundingBox(Coordinates bottomRight, Coordinates topLeft) + { + this.BottomRight = bottomRight; + this.TopLeft = topLeft; + } + + [JsonPropertyName("bottom_right")] + public Coordinates BottomRight { get; set; } + + [JsonPropertyName("top_left")] + public Coordinates TopLeft { get; set; } + + public void Validate() + { + Verify.NotNull(this.BottomRight, "Geo bounding box bottom right is NULL"); + Verify.NotNull(this.TopLeft, "Geo bounding box top left is NULL"); + } + } + + public class Coordinates : IValidatable + { + public Coordinates(float latitude, float longitude) + { + this.Latitude = latitude; + this.Longitude = longitude; + } + + [JsonPropertyName("lat")] + public float Latitude { get; set; } + + [JsonPropertyName("lon")] + public float Longitude { get; set; } + + public void Validate() + { + Verify.True(this.Latitude >= -90 && this.Latitude <= 90, "Latitude is out of range"); + Verify.True(this.Longitude >= -180 && this.Longitude <= 180, "Longitude is out of range"); + } + } +} diff --git a/dotnet/src/Connectors/Connectors.Memory.Qdrant/Http/ApiSchema/SearchVectorsRequest.cs b/dotnet/src/Connectors/Connectors.Memory.Qdrant/Http/ApiSchema/SearchVectorsRequest.cs index 65003e619b2e..ca3e3ef7939d 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Qdrant/Http/ApiSchema/SearchVectorsRequest.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Qdrant/Http/ApiSchema/SearchVectorsRequest.cs @@ -13,7 +13,7 @@ internal sealed class SearchVectorsRequest : IValidatable public IEnumerable StartingVector { get; set; } = System.Array.Empty(); [JsonPropertyName("filter")] - public Filter Filters { get; set; } + public QdrantFilter Filters { get; set; } [JsonPropertyName("limit")] public int Limit { get; set; } @@ -68,6 +68,14 @@ public SearchVectorsRequest HavingTags(IEnumerable? tags) return this; } + public SearchVectorsRequest WithFilters(QdrantFilter? filters) + { + if (filters == null) { return this; } + + this.Filters = filters; + return this; + } + public SearchVectorsRequest WithScoreThreshold(double threshold) { this.ScoreThreshold = threshold; @@ -119,74 +127,6 @@ public HttpRequestMessage Build() payload: this); } - internal sealed class Filter : IValidatable - { - internal sealed class Match : IValidatable - { - [JsonPropertyName("value")] - public object Value { get; set; } - - public Match() - { - this.Value = string.Empty; - } - - public void Validate() - { - } - } - - internal sealed class Must : IValidatable - { - [JsonPropertyName("key")] - public string Key { get; set; } - - [JsonPropertyName("match")] - public Match Match { get; set; } - - public Must() - { - this.Match = new(); - this.Key = string.Empty; - } - - public Must(string key, object value) : this() - { - this.Key = key; - this.Match.Value = value; - } - - public void Validate() - { - Verify.NotNull(this.Key, "The filter key is NULL"); - Verify.NotNull(this.Match, "The filter match is NULL"); - } - } - - [JsonPropertyName("must")] - public List Conditions { get; set; } - - internal Filter() - { - this.Conditions = new(); - } - - internal Filter ValueMustMatch(string key, object value) - { - this.Conditions.Add(new Must(key, value)); - return this; - } - - public void Validate() - { - Verify.NotNull(this.Conditions, "Filter conditions are NULL"); - foreach (var x in this.Conditions) - { - x.Validate(); - } - } - } - #region private ================================================================================ private readonly string _collectionName; @@ -194,7 +134,7 @@ public void Validate() private SearchVectorsRequest(string collectionName) { this._collectionName = collectionName; - this.Filters = new Filter(); + this.Filters = new QdrantFilter(); this.WithPayload = false; this.WithVector = false; diff --git a/dotnet/src/Connectors/Connectors.Memory.Qdrant/IQdrantVectorDbClient.cs b/dotnet/src/Connectors/Connectors.Memory.Qdrant/IQdrantVectorDbClient.cs index f426f4dd5b3b..7c215a76d27b 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Qdrant/IQdrantVectorDbClient.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Qdrant/IQdrantVectorDbClient.cs @@ -3,6 +3,7 @@ using System.Collections.Generic; using System.Threading; using System.Threading.Tasks; +using Microsoft.SemanticKernel.Connectors.Memory.Qdrant.Http.ApiSchema; namespace Microsoft.SemanticKernel.Connectors.Memory.Qdrant; @@ -62,6 +63,7 @@ public IAsyncEnumerable GetVectorsByIdAsync(string collectio /// The name assigned to a collection of vectors. /// The vector to compare the collection's vectors with. /// The minimum relevance threshold for returned results. + /// Filter applied during search. /// The maximum number of similarity results to return. /// Whether to include the vector data in the returned results. /// Qdrant tags used to filter the results. @@ -70,9 +72,10 @@ public IAsyncEnumerable GetVectorsByIdAsync(string collectio string collectionName, IEnumerable target, double threshold, + QdrantFilter? filters = default, int top = 1, bool withVectors = false, - IEnumerable? requiredTags = null, + IEnumerable? requiredTags = default, CancellationToken cancellationToken = default); /// diff --git a/dotnet/src/Connectors/Connectors.Memory.Qdrant/QdrantMemoryStore.cs b/dotnet/src/Connectors/Connectors.Memory.Qdrant/QdrantMemoryStore.cs index 38753c57c81e..88561f89081d 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Qdrant/QdrantMemoryStore.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Qdrant/QdrantMemoryStore.cs @@ -10,6 +10,7 @@ using Microsoft.Extensions.Logging; using Microsoft.SemanticKernel.AI.Embeddings; using Microsoft.SemanticKernel.Connectors.Memory.Qdrant.Diagnostics; +using Microsoft.SemanticKernel.Connectors.Memory.Qdrant.Http.ApiSchema; using Microsoft.SemanticKernel.Memory; namespace Microsoft.SemanticKernel.Connectors.Memory.Qdrant; @@ -20,7 +21,7 @@ namespace Microsoft.SemanticKernel.Connectors.Memory.Qdrant; /// The Embedding data is saved to a Qdrant Vector Database instance specified in the constructor by url and port. /// The embedding data persists between subsequent instances and has similarity search capability. /// -public class QdrantMemoryStore : IMemoryStore +public class QdrantMemoryStore : IMemoryStore { /// /// The Qdrant Vector Database memory store logger. @@ -347,12 +348,57 @@ public async Task RemoveWithPointIdBatchAsync(string collectionName, IEnumerable double minRelevanceScore = 0, bool withEmbeddings = false, [EnumeratorCancellation] CancellationToken cancellationToken = default) + { + await foreach (var s in this.GetNearestMatchesAsync( + collectionName: collectionName, + embedding: embedding, + filters: null, + limit: limit, + minRelevanceScore: minRelevanceScore, + withEmbeddings: withEmbeddings, + cancellationToken: cancellationToken)) + { + yield return s; + }; + } + + /// + public async Task<(MemoryRecord, double)?> GetNearestMatchAsync( + string collectionName, + Embedding embedding, + double minRelevanceScore = 0, + bool withEmbedding = false, + CancellationToken cancellationToken = default) + { + var results = this.GetNearestMatchesAsync( + collectionName: collectionName, + embedding: embedding, + minRelevanceScore: minRelevanceScore, + limit: 1, + withEmbeddings: withEmbedding, + cancellationToken: cancellationToken); + + var record = await results.FirstOrDefaultAsync(cancellationToken).ConfigureAwait(false); + + return (record.Item1, record.Item2); + } + + /// + public async IAsyncEnumerable<(MemoryRecord, double)> GetNearestMatchesAsync( + string collectionName, + Embedding embedding, + QdrantFilter? filters, + int limit, + double minRelevanceScore = 0, + bool withEmbeddings = false, + [EnumeratorCancellation] CancellationToken cancellationToken = default) { IAsyncEnumerator<(QdrantVectorRecord, double)> enumerator = this._qdrantClient .FindNearestInCollectionAsync( collectionName: collectionName, target: embedding.Vector, threshold: minRelevanceScore, + filters: filters, top: limit, withVectors: withEmbeddings, cancellationToken: cancellationToken) @@ -393,27 +439,6 @@ public async Task RemoveWithPointIdBatchAsync(string collectionName, IEnumerable } while (hasResult); } - /// - public async Task<(MemoryRecord, double)?> GetNearestMatchAsync( - string collectionName, - Embedding embedding, - double minRelevanceScore = 0, - bool withEmbedding = false, - CancellationToken cancellationToken = default) - { - var results = this.GetNearestMatchesAsync( - collectionName: collectionName, - embedding: embedding, - minRelevanceScore: minRelevanceScore, - limit: 1, - withEmbeddings: withEmbedding, - cancellationToken: cancellationToken); - - var record = await results.FirstOrDefaultAsync(cancellationToken).ConfigureAwait(false); - - return (record.Item1, record.Item2); - } - #region private ================================================================================ private readonly IQdrantVectorDbClient _qdrantClient; @@ -467,6 +492,5 @@ private async Task ConvertFromMemoryRecordAsync( return vectorData; } - #endregion } diff --git a/dotnet/src/Connectors/Connectors.Memory.Qdrant/QdrantVectorDbClient.cs b/dotnet/src/Connectors/Connectors.Memory.Qdrant/QdrantVectorDbClient.cs index 0c3e634b2896..ff24f18c2d93 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Qdrant/QdrantVectorDbClient.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Qdrant/QdrantVectorDbClient.cs @@ -311,9 +311,10 @@ public async Task UpsertVectorsAsync(string collectionName, IEnumerable target, double threshold, + QdrantFilter? filters = default, int top = 1, bool withVectors = false, - IEnumerable? requiredTags = null, + IEnumerable? requiredTags = default, [EnumeratorCancellation] CancellationToken cancellationToken = default) { this._logger.LogDebug("Searching top {0} nearest vectors", top); @@ -324,6 +325,7 @@ public async Task UpsertVectorsAsync(string collectionName, IEnumerableruntime; build; native; contentfiles; analyzers; buildtransitive all + + + + diff --git a/dotnet/src/Connectors/Connectors.UnitTests/Memory/Postgres/PostgresMemoryStoreTests.cs b/dotnet/src/Connectors/Connectors.UnitTests/Memory/Postgres/PostgresMemoryStoreTests.cs index 581d623a0c54..e358466ed960 100644 --- a/dotnet/src/Connectors/Connectors.UnitTests/Memory/Postgres/PostgresMemoryStoreTests.cs +++ b/dotnet/src/Connectors/Connectors.UnitTests/Memory/Postgres/PostgresMemoryStoreTests.cs @@ -26,7 +26,7 @@ public PostgresMemoryStoreTests() { this._postgresDbClientMock = new Mock(); this._postgresDbClientMock - .Setup(client => client.DoesCollectionExistsAsync(CollectionName, CancellationToken.None)) + .Setup(client => client.DoesTableExistsAsync(CollectionName, CancellationToken.None)) .ReturnsAsync(true); } @@ -40,7 +40,7 @@ public async Task ItCanCreateCollectionAsync() await store.CreateCollectionAsync(CollectionName); // Assert - this._postgresDbClientMock.Verify(client => client.CreateCollectionAsync(CollectionName, CancellationToken.None), Times.Once()); + this._postgresDbClientMock.Verify(client => client.CreateTableAsync(CollectionName, CancellationToken.None), Times.Once()); } [Fact] @@ -53,7 +53,7 @@ public async Task ItCanDeleteCollectionAsync() await store.DeleteCollectionAsync(CollectionName); // Assert - this._postgresDbClientMock.Verify(client => client.DeleteCollectionAsync(CollectionName, CancellationToken.None), Times.Once()); + this._postgresDbClientMock.Verify(client => client.DeleteTableAsync(CollectionName, CancellationToken.None), Times.Once()); } [Fact] @@ -76,7 +76,7 @@ public async Task ItReturnsFalseWhenCollectionDoesNotExistAsync() const string collectionName = "non-existent-collection"; this._postgresDbClientMock - .Setup(client => client.DoesCollectionExistsAsync(collectionName, CancellationToken.None)) + .Setup(client => client.DoesTableExistsAsync(collectionName, CancellationToken.None)) .ReturnsAsync(false); var store = new PostgresMemoryStore(this._postgresDbClientMock.Object); @@ -192,12 +192,13 @@ public async Task ItCanGetMemoryRecordBatchFromCollectionAsync() .ReturnsAsync(this.GetPostgresMemoryEntryFromMemoryRecord(memoryRecord)); } - var doesNotExistMemoryKey = "fake-record-key"; - this._postgresDbClientMock - .Setup(client => client.ReadAsync(CollectionName, doesNotExistMemoryKey, true, CancellationToken.None)) - .ReturnsAsync((PostgresMemoryEntry?)null); + memoryRecordKeys.Insert(0, "non-existent-record-key-1"); + memoryRecordKeys.Insert(2, "non-existent-record-key-2"); + memoryRecordKeys.Add("non-existent-record-key-3"); - memoryRecordKeys.Add(doesNotExistMemoryKey); + this._postgresDbClientMock + .Setup(client => client.ReadBatchAsync(CollectionName, memoryRecordKeys, true, CancellationToken.None)) + .Returns(expectedMemoryRecords.Select(memoryRecord => this.GetPostgresMemoryEntryFromMemoryRecord(memoryRecord)).ToAsyncEnumerable()); var store = new PostgresMemoryStore(this._postgresDbClientMock.Object); @@ -205,6 +206,7 @@ public async Task ItCanGetMemoryRecordBatchFromCollectionAsync() var actualMemoryRecords = await store.GetBatchAsync(CollectionName, memoryRecordKeys, withEmbeddings: true).ToListAsync(); // Assert + this._postgresDbClientMock.Verify(client => client.ReadBatchAsync(CollectionName, memoryRecordKeys, true, CancellationToken.None), Times.Once()); Assert.Equal(expectedMemoryRecords.Length, actualMemoryRecords.Count); for (var i = 0; i < expectedMemoryRecords.Length; i++) @@ -220,7 +222,7 @@ public async Task ItCanReturnCollectionsAsync() var expectedCollections = new List { "fake-collection-1", "fake-collection-2", "fake-collection-3" }; this._postgresDbClientMock - .Setup(client => client.GetCollectionsAsync(CancellationToken.None)) + .Setup(client => client.GetTablesAsync(CancellationToken.None)) .Returns(expectedCollections.ToAsyncEnumerable()); var store = new PostgresMemoryStore(this._postgresDbClientMock.Object); @@ -262,10 +264,7 @@ public async Task ItCanRemoveBatchAsync() await store.RemoveBatchAsync(CollectionName, memoryRecordKeys); // Assert - foreach (var memoryRecordKey in memoryRecordKeys) - { - this._postgresDbClientMock.Verify(client => client.DeleteAsync(CollectionName, memoryRecordKey, CancellationToken.None), Times.Once()); - } + this._postgresDbClientMock.Verify(client => client.DeleteBatchAsync(CollectionName, memoryRecordKeys, CancellationToken.None), Times.Once()); } #region private ================================================================================ @@ -293,7 +292,8 @@ private MemoryRecord GetRandomMemoryRecord(Embedding? embedding = null) description: "description-" + Guid.NewGuid().ToString(), embedding: memoryEmbedding, additionalMetadata: "metadata-" + Guid.NewGuid().ToString(), - key: id); + key: id, + timestamp: DateTimeOffset.Now); } private PostgresMemoryEntry GetPostgresMemoryEntryFromMemoryRecord(MemoryRecord memoryRecord) @@ -303,7 +303,7 @@ private PostgresMemoryEntry GetPostgresMemoryEntryFromMemoryRecord(MemoryRecord Key = memoryRecord.Key, Embedding = new Pgvector.Vector(memoryRecord.Embedding.Vector.ToArray()), MetadataString = memoryRecord.GetSerializedMetadata(), - Timestamp = memoryRecord.Timestamp?.ToUnixTimeMilliseconds() + Timestamp = memoryRecord.Timestamp?.UtcDateTime }; } diff --git a/dotnet/src/Connectors/Connectors.UnitTests/Memory/Qdrant/QdrantKernelBuilderExtensionsTests.cs b/dotnet/src/Connectors/Connectors.UnitTests/Memory/Qdrant/QdrantKernelBuilderExtensionsTests.cs index aee72a868d39..6e49b9a36828 100644 --- a/dotnet/src/Connectors/Connectors.UnitTests/Memory/Qdrant/QdrantKernelBuilderExtensionsTests.cs +++ b/dotnet/src/Connectors/Connectors.UnitTests/Memory/Qdrant/QdrantKernelBuilderExtensionsTests.cs @@ -6,6 +6,8 @@ using System.Text; using System.Threading.Tasks; using Microsoft.SemanticKernel; +using Microsoft.SemanticKernel.Connectors.Memory.Qdrant.Http.ApiSchema; +using Microsoft.SemanticKernel.Memory; using Xunit; namespace SemanticKernel.Connectors.UnitTests.Memory.Qdrant; @@ -41,6 +43,21 @@ public async Task QdrantMemoryStoreShouldBeProperlyInitialized() Assert.Equal("https://fake-random-qdrant-host/collections", this.messageHandlerStub?.RequestUri?.AbsoluteUri); } + [Fact] + public void ItUsesFilterableSemanticTextMemoryWhenUsingQdrantMemoryStore() + { + //Arrange + var builder = new KernelBuilder(); + builder.WithQdrantMemoryStore("https://fake-random-qdrant-host", 123); + builder.WithAzureTextEmbeddingGenerationService("fake-deployment-name", "https://fake-random-text-embedding-generation-host/fake-path", "fake-api-key"); + + //Act + var kernel = builder.Build(); + + //Assert + Assert.IsType>(kernel.Memory); + } + public void Dispose() { this.httpClient.Dispose(); diff --git a/dotnet/src/Connectors/Connectors.UnitTests/Memory/Qdrant/QdrantMemoryStoreTests3.cs b/dotnet/src/Connectors/Connectors.UnitTests/Memory/Qdrant/QdrantMemoryStoreTests3.cs index b1c42eb16176..919a6324ea80 100644 --- a/dotnet/src/Connectors/Connectors.UnitTests/Memory/Qdrant/QdrantMemoryStoreTests3.cs +++ b/dotnet/src/Connectors/Connectors.UnitTests/Memory/Qdrant/QdrantMemoryStoreTests3.cs @@ -12,6 +12,7 @@ using Microsoft.SemanticKernel.AI.Embeddings; using Microsoft.SemanticKernel.Connectors.Memory.Pinecone; using Microsoft.SemanticKernel.Connectors.Memory.Qdrant; +using Microsoft.SemanticKernel.Connectors.Memory.Qdrant.Http.ApiSchema; using Microsoft.SemanticKernel.Memory; using Moq; using Moq.Protected; @@ -40,6 +41,7 @@ public async Task GetNearestMatchesAsyncCallsDoNotReturnVectorsUnlessSpecifiedAs It.IsAny(), It.IsAny>(), It.IsAny(), + It.IsAny(), It.IsAny(), It.IsAny(), null, @@ -75,6 +77,7 @@ public async Task GetNearestMatchesAsyncCallsDoNotReturnVectorsUnlessSpecifiedAs It.IsAny(), It.IsAny>(), It.IsAny(), + It.IsAny(), 1, false, null, @@ -84,6 +87,7 @@ public async Task GetNearestMatchesAsyncCallsDoNotReturnVectorsUnlessSpecifiedAs It.IsAny(), It.IsAny>(), It.IsAny(), + It.IsAny(), 1, true, null, @@ -93,6 +97,7 @@ public async Task GetNearestMatchesAsyncCallsDoNotReturnVectorsUnlessSpecifiedAs It.IsAny(), It.IsAny>(), It.IsAny(), + It.IsAny(), 3, false, null, @@ -102,6 +107,7 @@ public async Task GetNearestMatchesAsyncCallsDoNotReturnVectorsUnlessSpecifiedAs It.IsAny(), It.IsAny>(), It.IsAny(), + It.IsAny(), 3, true, null, @@ -119,6 +125,7 @@ public async Task ItReturnsEmptyTupleIfNearestMatchNotFoundAsync() It.IsAny(), It.IsAny>(), It.IsAny(), + It.IsAny(), It.IsAny(), It.IsAny(), null, @@ -138,6 +145,7 @@ public async Task ItReturnsEmptyTupleIfNearestMatchNotFoundAsync() It.IsAny(), It.IsAny>(), It.IsAny(), + It.IsAny(), It.IsAny(), It.IsAny(), null, @@ -171,6 +179,7 @@ public async Task ItWillReturnTheNearestMatchAsATupleAsync() It.IsAny(), It.IsAny>(), It.IsAny(), + It.IsAny(), It.IsAny(), It.IsAny(), null, @@ -190,6 +199,7 @@ public async Task ItWillReturnTheNearestMatchAsATupleAsync() It.IsAny(), It.IsAny>(), It.IsAny(), + It.IsAny(), It.IsAny(), It.IsAny(), null, @@ -213,6 +223,7 @@ public async Task ItReturnsEmptyListIfNearestMatchesNotFoundAsync() It.IsAny(), It.IsAny>(), It.IsAny(), + It.IsAny(), It.IsAny(), It.IsAny(), null, @@ -390,4 +401,45 @@ public async Task ScoredVectorSupportsStringIds() } } } + + [Fact] + public async Task ItPassesQdrantFilterToQdrantClient() + { + // Arrange + var filters = new QdrantFilter(); + var mockQdrantClient = new Mock(); + mockQdrantClient + .Setup>(x => x.FindNearestInCollectionAsync( + It.IsAny(), + It.IsAny>(), + It.IsAny(), + It.IsAny(), + It.IsAny(), + It.IsAny(), + null, + It.IsAny())) + .Returns(AsyncEnumerable.Empty<(QdrantVectorRecord, double)>()); + + var vectorStore = new QdrantMemoryStore(mockQdrantClient.Object, logger: null); + + // Act + await vectorStore.GetNearestMatchesAsync( + collectionName: "test_collection", + embedding: this._embedding, + filters: filters, + limit: 3, + minRelevanceScore: 0.0).ToListAsync(); + + // Assert + mockQdrantClient.Verify>(x => x.FindNearestInCollectionAsync( + It.IsAny(), + It.IsAny>(), + It.IsAny(), + It.Is(qf => qf == filters), + It.IsAny(), + It.IsAny(), + null, + It.IsAny()), + Times.Once()); + } } diff --git a/dotnet/src/IntegrationTests/Connectors/Memory/Chroma/ChromaMemoryStoreTests.cs b/dotnet/src/IntegrationTests/Connectors/Memory/Chroma/ChromaMemoryStoreTests.cs index da475d7ba328..92795f02d8a4 100644 --- a/dotnet/src/IntegrationTests/Connectors/Memory/Chroma/ChromaMemoryStoreTests.cs +++ b/dotnet/src/IntegrationTests/Connectors/Memory/Chroma/ChromaMemoryStoreTests.cs @@ -251,9 +251,9 @@ public async Task ItCanGetNearestMatchAsync() // Arrange var collectionName = this.GetRandomCollectionName(); - var expectedRecord1 = this.GetRandomMemoryRecord(new Embedding(new[] { 10f, 10f, 10f })); - var expectedRecord2 = this.GetRandomMemoryRecord(new Embedding(new[] { 5f, 5f, 5f })); - var expectedRecord3 = this.GetRandomMemoryRecord(new Embedding(new[] { 1f, 1f, 1f })); + var expectedRecord1 = this.GetRandomMemoryRecord(embedding: new Embedding(new[] { 10f, 10f, 10f })); + var expectedRecord2 = this.GetRandomMemoryRecord(embedding: new Embedding(new[] { 5f, 5f, 5f })); + var expectedRecord3 = this.GetRandomMemoryRecord(embedding: new Embedding(new[] { 1f, 1f, 1f })); var searchEmbedding = new Embedding(new[] { 2f, 2f, 2f }); @@ -282,9 +282,9 @@ public async Task ItCanGetNearestMatchesAsync() // Arrange var collectionName = this.GetRandomCollectionName(); - var expectedRecord1 = this.GetRandomMemoryRecord(new Embedding(new[] { 10f, 10f, 10f })); - var expectedRecord2 = this.GetRandomMemoryRecord(new Embedding(new[] { 5f, 5f, 5f })); - var expectedRecord3 = this.GetRandomMemoryRecord(new Embedding(new[] { 1f, 1f, 1f })); + var expectedRecord1 = this.GetRandomMemoryRecord(embedding: new Embedding(new[] { 10f, 10f, 10f })); + var expectedRecord2 = this.GetRandomMemoryRecord(embedding: new Embedding(new[] { 5f, 5f, 5f })); + var expectedRecord3 = this.GetRandomMemoryRecord(embedding: new Embedding(new[] { 1f, 1f, 1f })); var searchEmbedding = new Embedding(new[] { 2f, 2f, 2f }); @@ -330,6 +330,55 @@ public async Task ItReturnsNoMatchesFromEmptyCollection() Assert.Null(nearestMatch.Value.Item1); } + [Fact(Skip = SkipReason)] + public async Task ItCanUpsertSameMemoryRecordMultipleTimesAsync() + { + // Arrange + var collectionName = this.GetRandomCollectionName(); + var expectedRecord = this.GetRandomMemoryRecord(); + + await this._chromaMemoryStore.CreateCollectionAsync(collectionName); + + // Act + await this._chromaMemoryStore.UpsertAsync(collectionName, expectedRecord); + await this._chromaMemoryStore.UpsertAsync(collectionName, expectedRecord); + await this._chromaMemoryStore.UpsertAsync(collectionName, expectedRecord); + + // Assert + var actualRecord = await this._chromaMemoryStore.GetAsync(collectionName, expectedRecord.Key, true); + + Assert.NotNull(actualRecord); + + this.AssertMemoryRecordEqual(expectedRecord, actualRecord); + } + + [Fact(Skip = SkipReason)] + public async Task ItCanUpsertDifferentMemoryRecordsWithSameKeyMultipleTimesAsync() + { + // Arrange + var collectionName = this.GetRandomCollectionName(); + var expectedRecord1 = this.GetRandomMemoryRecord(); + var key = expectedRecord1.Key; + + await this._chromaMemoryStore.CreateCollectionAsync(collectionName); + await this._chromaMemoryStore.UpsertAsync(collectionName, expectedRecord1); + + var actualRecord1 = await this._chromaMemoryStore.GetAsync(collectionName, key, withEmbedding: true); + + Assert.NotNull(actualRecord1); + this.AssertMemoryRecordEqual(expectedRecord1, actualRecord1); + + // Act + var expectedRecord2 = this.GetRandomMemoryRecord(key: key); + await this._chromaMemoryStore.UpsertAsync(collectionName, expectedRecord2); + + // Assert + var actualRecord2 = await this._chromaMemoryStore.GetAsync(collectionName, key, withEmbedding: true); + + Assert.NotNull(actualRecord2); + this.AssertMemoryRecordEqual(expectedRecord2, actualRecord2); + } + public void Dispose() { this.Dispose(true); @@ -366,18 +415,18 @@ private string GetRandomCollectionName() return "sk-test-" + Guid.NewGuid(); } - private MemoryRecord GetRandomMemoryRecord(Embedding? embedding = null) + private MemoryRecord GetRandomMemoryRecord(string? key = null, Embedding? embedding = null) { - var id = Guid.NewGuid().ToString(); - var memoryEmbedding = embedding ?? new Embedding(new[] { 1f, 3f, 5f }); + var recordKey = key ?? Guid.NewGuid().ToString(); + var recordEmbedding = embedding ?? new Embedding(new[] { 1f, 3f, 5f }); return MemoryRecord.LocalRecord( - id: id, + id: recordKey, text: "text-" + Guid.NewGuid().ToString(), description: "description-" + Guid.NewGuid().ToString(), - embedding: memoryEmbedding, + embedding: recordEmbedding, additionalMetadata: "metadata-" + Guid.NewGuid().ToString(), - key: id); + key: recordKey); } #endregion diff --git a/dotnet/src/IntegrationTests/IntegrationTests.csproj b/dotnet/src/IntegrationTests/IntegrationTests.csproj index 56a65c361226..28efab76da42 100644 --- a/dotnet/src/IntegrationTests/IntegrationTests.csproj +++ b/dotnet/src/IntegrationTests/IntegrationTests.csproj @@ -27,6 +27,10 @@ runtime; build; native; contentfiles; analyzers; buildtransitive all + + + + diff --git a/dotnet/src/IntegrationTests/RedirectOutput.cs b/dotnet/src/IntegrationTests/RedirectOutput.cs index 5aa902e6c9a2..64c374c57c78 100644 --- a/dotnet/src/IntegrationTests/RedirectOutput.cs +++ b/dotnet/src/IntegrationTests/RedirectOutput.cs @@ -27,9 +27,9 @@ public override void WriteLine(string? value) this._logs.AppendLine(value); } - public IDisposable? BeginScope(TState state) where TState : notnull + public IDisposable BeginScope(TState state) { - return null; + return null!; } public bool IsEnabled(LogLevel logLevel) diff --git a/dotnet/src/IntegrationTests/XunitLogger.cs b/dotnet/src/IntegrationTests/XunitLogger.cs index a61dffec371f..9118d31676a8 100644 --- a/dotnet/src/IntegrationTests/XunitLogger.cs +++ b/dotnet/src/IntegrationTests/XunitLogger.cs @@ -28,7 +28,7 @@ public void Log(LogLevel logLevel, EventId eventId, TState state, Except public bool IsEnabled(LogLevel logLevel) => true; /// - public IDisposable BeginScope(TState state) where TState : notnull + public IDisposable BeginScope(TState state) => this; /// diff --git a/dotnet/src/SemanticKernel.Abstractions/IKernel.cs b/dotnet/src/SemanticKernel.Abstractions/IKernel.cs index 2426c2b31687..9af659454cad 100644 --- a/dotnet/src/SemanticKernel.Abstractions/IKernel.cs +++ b/dotnet/src/SemanticKernel.Abstractions/IKernel.cs @@ -34,6 +34,11 @@ public interface IKernel /// ISemanticTextMemory Memory { get; } + /// + /// Semantic memory instance with filtering capabilities + /// + ISemanticTextMemory? GetFilterableMemory(); + /// /// Reference to the engine rendering prompt templates /// diff --git a/dotnet/src/SemanticKernel.Abstractions/Memory/IMemoryStore.cs b/dotnet/src/SemanticKernel.Abstractions/Memory/IMemoryStore.cs index 997dba6a9d17..8c8882dfd886 100644 --- a/dotnet/src/SemanticKernel.Abstractions/Memory/IMemoryStore.cs +++ b/dotnet/src/SemanticKernel.Abstractions/Memory/IMemoryStore.cs @@ -133,3 +133,30 @@ public interface IMemoryStore bool withEmbedding = false, CancellationToken cancellationToken = default); } + +/// +/// An interface for storing and retrieving indexed objects in a data store with support for metadata filtering. +/// +/// Type of filter used for metadata filtering. +public interface IMemoryStore : IMemoryStore +{ + /// + /// Gets the nearest matches to the of type with payload meeting filtering conditions. Does not guarantee that the collection exists. + /// + /// The name associated with a collection of embeddings. + /// The to compare the collection's embeddings with. + /// Filters to be applied during search. + /// The maximum number of similarity results to return. + /// The minimum relevance threshold for returned results. + /// If true, the embeddings will be returned in the memory records. + /// The to monitor for cancellation requests. The default is . + /// A group of tuples where item1 is a and item2 is its similarity score as a . + IAsyncEnumerable<(MemoryRecord, double)> GetNearestMatchesAsync( + string collectionName, + Embedding embedding, + TFilter filters, + int limit, + double minRelevanceScore = 0.0, + bool withEmbeddings = false, + CancellationToken cancellationToken = default); +} diff --git a/dotnet/src/SemanticKernel.Abstractions/Memory/ISemanticTextMemory.cs b/dotnet/src/SemanticKernel.Abstractions/Memory/ISemanticTextMemory.cs index 76aac625b98c..b1bf769bd867 100644 --- a/dotnet/src/SemanticKernel.Abstractions/Memory/ISemanticTextMemory.cs +++ b/dotnet/src/SemanticKernel.Abstractions/Memory/ISemanticTextMemory.cs @@ -96,3 +96,30 @@ public IAsyncEnumerable SearchAsync( /// A group of collection names. public Task> GetCollectionsAsync(CancellationToken cancellationToken = default); } + +/// +/// An interface for semantic memory that creates and recalls memories associated with text with support for metadata filtering. +/// +/// Type of filter used for metada filtering. +public interface ISemanticTextMemory : ISemanticTextMemory +{ + /// + /// Find some information in memory. + /// + /// Collection to search. + /// What to search for. + /// Filters to be applied during search. + /// How many results to return. + /// Minimum relevance score, from 0 to 1, where 1 means exact match. + /// Whether to return the embeddings of the memories found. + /// The to monitor for cancellation requests. The default is . + /// Memories found. + public IAsyncEnumerable SearchAsync( + string collection, + string query, + TFilter filters, + int limit = 1, + double minRelevanceScore = 0.7, + bool withEmbeddings = false, + CancellationToken cancellationToken = default); +} diff --git a/dotnet/src/SemanticKernel.Abstractions/Orchestration/SKContext.cs b/dotnet/src/SemanticKernel.Abstractions/Orchestration/SKContext.cs index 0b6c873652d6..acf06eedf77a 100644 --- a/dotnet/src/SemanticKernel.Abstractions/Orchestration/SKContext.cs +++ b/dotnet/src/SemanticKernel.Abstractions/Orchestration/SKContext.cs @@ -104,6 +104,11 @@ public SKContext Fail(string errorDescription, Exception? exception = null) /// public ISemanticTextMemory Memory { get; } + /// + /// Semantic memory with filtering capabilities + /// + public ISemanticTextMemory? GetFilterableMemory() => this.Memory as ISemanticTextMemory; + /// /// Read only skills collection /// diff --git a/dotnet/src/SemanticKernel/Kernel.cs b/dotnet/src/SemanticKernel/Kernel.cs index 11e6bcdc23fe..aff9ae69cce6 100644 --- a/dotnet/src/SemanticKernel/Kernel.cs +++ b/dotnet/src/SemanticKernel/Kernel.cs @@ -45,6 +45,9 @@ public sealed class Kernel : IKernel, IDisposable /// public ISemanticTextMemory Memory => this._memory; + /// + public ISemanticTextMemory? GetFilterableMemory() => this._memory as ISemanticTextMemory; + /// public IReadOnlySkillCollection Skills => this._skillCollection.ReadOnlySkillCollection; diff --git a/dotnet/src/SemanticKernel/Memory/MemoryConfiguration.cs b/dotnet/src/SemanticKernel/Memory/MemoryConfiguration.cs index caa177670924..ff4cb01bea05 100644 --- a/dotnet/src/SemanticKernel/Memory/MemoryConfiguration.cs +++ b/dotnet/src/SemanticKernel/Memory/MemoryConfiguration.cs @@ -1,6 +1,6 @@ // Copyright (c) Microsoft. All rights reserved. -using System.Diagnostics.CodeAnalysis; +using System.Linq; using Microsoft.SemanticKernel.AI.Embeddings; using Microsoft.SemanticKernel.Diagnostics; using Microsoft.SemanticKernel.Memory; @@ -34,12 +34,47 @@ public static void UseMemory(this IKernel kernel, IMemoryStore storage, string? /// Kernel instance /// Embedding generator /// Memory storage - [SuppressMessage("Reliability", "CA2000:Dispose objects before losing scope", Justification = "The embeddingGenerator object is disposed by the kernel")] public static void UseMemory(this IKernel kernel, ITextEmbeddingGeneration embeddingGenerator, IMemoryStore storage) { Verify.NotNull(storage); Verify.NotNull(embeddingGenerator); - kernel.RegisterMemory(new SemanticTextMemory(storage, embeddingGenerator)); + var memory = CreateSemanticTextMemory(storage, embeddingGenerator); + + kernel.RegisterMemory(memory); + } + + /// + /// Create or based on . + /// + /// Memory storage + /// Embedding generator + /// + private static ISemanticTextMemory CreateSemanticTextMemory(IMemoryStore memoryStore, ITextEmbeddingGeneration embeddingGenerator) + { + var memoryStoreType = memoryStore.GetType(); + var filterableMemoryStoreInterfaceType = memoryStoreType + .GetInterfaces() + .SingleOrDefault(i => i.IsGenericType && i.GetGenericTypeDefinition() == typeof(IMemoryStore<>)); + + if (filterableMemoryStoreInterfaceType != null) + { + var filterType = filterableMemoryStoreInterfaceType + .GetGenericArguments() + .Single(); + + var textEmbeddingGenerationType = typeof(ITextEmbeddingGeneration); + var filterableSemanticTextMemoryType = typeof(SemanticTextMemory<>).MakeGenericType(filterType); + var constructor = filterableSemanticTextMemoryType.GetConstructor(new[] { memoryStoreType, textEmbeddingGenerationType }); + if (constructor == null) + { + throw new System.InvalidOperationException( + $"No {filterableSemanticTextMemoryType.FullName} constructor with parameter types: {memoryStoreType.FullName}, {textEmbeddingGenerationType.FullName} found."); + } + + return (ISemanticTextMemory)constructor.Invoke(new object[] { memoryStore, embeddingGenerator }); + } + + return new SemanticTextMemory(memoryStore, embeddingGenerator); } } diff --git a/dotnet/src/SemanticKernel/Memory/SemanticTextMemory.cs b/dotnet/src/SemanticKernel/Memory/SemanticTextMemory.cs index 57b94fc8262d..dbddf11c45c6 100644 --- a/dotnet/src/SemanticKernel/Memory/SemanticTextMemory.cs +++ b/dotnet/src/SemanticKernel/Memory/SemanticTextMemory.cs @@ -11,9 +11,9 @@ namespace Microsoft.SemanticKernel.Memory; /// -/// Implementation of ./>. +/// Implementation of . /// -public sealed class SemanticTextMemory : ISemanticTextMemory, IDisposable +public class SemanticTextMemory : ISemanticTextMemory, IDisposable { private readonly ITextEmbeddingGeneration _embeddingGenerator; private readonly IMemoryStore _storage; @@ -125,10 +125,69 @@ public async Task> GetCollectionsAsync(CancellationToken cancellat public void Dispose() { - // ReSharper disable once SuspiciousTypeConversion.Global - if (this._embeddingGenerator is IDisposable emb) { emb.Dispose(); } + this.Dispose(true); + GC.SuppressFinalize(this); + } + + protected virtual void Dispose(bool disposing) + { + if (disposing) + { + // ReSharper disable once SuspiciousTypeConversion.Global + if (this._embeddingGenerator is IDisposable emb) { emb.Dispose(); } + + // ReSharper disable once SuspiciousTypeConversion.Global + if (this._storage is IDisposable storage) { storage.Dispose(); } + } + } +} + +/// +/// Implementation of . +/// +/// Type of filter used for metada filtering. +public sealed class SemanticTextMemory : SemanticTextMemory, ISemanticTextMemory +{ + private readonly IMemoryStore _storage; + private readonly ITextEmbeddingGeneration _embeddingGenerator; + + public SemanticTextMemory( + IMemoryStore storage, + ITextEmbeddingGeneration embeddingGenerator) : base(storage, embeddingGenerator) + { + this._embeddingGenerator = embeddingGenerator; + this._storage = storage; + } + + /// + public async IAsyncEnumerable SearchAsync( + string collection, + string query, + TFilter filters, + int limit = 1, + double minRelevanceScore = 0.7, + bool withEmbeddings = false, + [EnumeratorCancellation] CancellationToken cancellationToken = default) + { + Embedding queryEmbedding = await this._embeddingGenerator.GenerateEmbeddingAsync(query, cancellationToken).ConfigureAwait(false); + + IAsyncEnumerable<(MemoryRecord, double)> results = this._storage.GetNearestMatchesAsync( + collectionName: collection, + embedding: queryEmbedding, + filters: filters, + limit: limit, + minRelevanceScore: minRelevanceScore, + withEmbeddings: withEmbeddings, + cancellationToken: cancellationToken); - // ReSharper disable once SuspiciousTypeConversion.Global - if (this._storage is IDisposable storage) { storage.Dispose(); } + await foreach ((MemoryRecord, double) result in results.WithCancellation(cancellationToken)) + { + yield return MemoryQueryResult.FromMemoryRecord(result.Item1, result.Item2); + } + } + + protected override void Dispose(bool disposing) + { + base.Dispose(disposing); } } diff --git a/dotnet/src/SemanticKernel/SkillDefinition/SKFunction.cs b/dotnet/src/SemanticKernel/SkillDefinition/SKFunction.cs index 7bdb35a3b431..83ef515324f7 100644 --- a/dotnet/src/SemanticKernel/SkillDefinition/SKFunction.cs +++ b/dotnet/src/SemanticKernel/SkillDefinition/SKFunction.cs @@ -922,7 +922,7 @@ private static void TrackUniqueParameterType(ref bool hasParameterType, MethodIn // If that fails, try with the invariant culture and allow any exception to propagate. try { - return converter.ConvertFromString(context: null, cultureInfo ?? CultureInfo.CurrentCulture, input); + return converter.ConvertFromString(context: null, cultureInfo, input); } catch (Exception e) when (!e.IsCriticalException() && cultureInfo != CultureInfo.InvariantCulture) { @@ -974,7 +974,7 @@ private static void TrackUniqueParameterType(ref bool hasParameterType, MethodIn return null!; } - return converter.ConvertToString(context: null, cultureInfo ?? CultureInfo.InvariantCulture, input); + return converter.ConvertToString(context: null, cultureInfo, input); }; } diff --git a/dotnet/src/Skills/Skills.UnitTests/XunitHelpers/XunitLogger.cs b/dotnet/src/Skills/Skills.UnitTests/XunitHelpers/XunitLogger.cs index d8e2c929f49f..f2c7e2848c87 100644 --- a/dotnet/src/Skills/Skills.UnitTests/XunitHelpers/XunitLogger.cs +++ b/dotnet/src/Skills/Skills.UnitTests/XunitHelpers/XunitLogger.cs @@ -28,7 +28,7 @@ public void Log(LogLevel logLevel, EventId eventId, TState state, Except public bool IsEnabled(LogLevel logLevel) => true; /// - public IDisposable BeginScope(TState state) where TState : notnull + public IDisposable BeginScope(TState state) => this; /// diff --git a/python/poetry.lock b/python/poetry.lock index b598dad41c13..a29c4d7b0c45 100644 --- a/python/poetry.lock +++ b/python/poetry.lock @@ -1237,13 +1237,13 @@ files = [ [[package]] name = "ipykernel" -version = "6.23.3" +version = "6.24.0" description = "IPython Kernel for Jupyter" optional = false python-versions = ">=3.8" files = [ - {file = "ipykernel-6.23.3-py3-none-any.whl", hash = "sha256:bc00662dc44d4975b668cdb5fefb725e38e9d8d6e28441a519d043f38994922d"}, - {file = "ipykernel-6.23.3.tar.gz", hash = "sha256:dd4e18116357f36a1e459b3768412371bee764c51844cbf25c4ed1eb9cae4a54"}, + {file = "ipykernel-6.24.0-py3-none-any.whl", hash = "sha256:2f5fffc7ad8f1fd5aadb4e171ba9129d9668dbafa374732cf9511ada52d6547f"}, + {file = "ipykernel-6.24.0.tar.gz", hash = "sha256:29cea0a716b1176d002a61d0b0c851f34536495bc4ef7dd0222c88b41b816123"}, ] [package.dependencies] @@ -2892,28 +2892,28 @@ use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"] [[package]] name = "ruff" -version = "0.0.275" +version = "0.0.277" description = "An extremely fast Python linter, written in Rust." optional = false python-versions = ">=3.7" files = [ - {file = "ruff-0.0.275-py3-none-macosx_10_7_x86_64.whl", hash = "sha256:5e6554a072e7ce81eb6f0bec1cebd3dcb0e358652c0f4900d7d630d61691e914"}, - {file = "ruff-0.0.275-py3-none-macosx_10_9_x86_64.macosx_11_0_arm64.macosx_10_9_universal2.whl", hash = "sha256:1cc599022fe5ffb143a965b8d659eb64161ab8ab4433d208777eab018a1aab67"}, - {file = "ruff-0.0.275-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5206fc1cd8c1c1deadd2e6360c0dbcd690f1c845da588ca9d32e4a764a402c60"}, - {file = "ruff-0.0.275-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:0c4e6468da26f77b90cae35319d310999f471a8c352998e9b39937a23750149e"}, - {file = "ruff-0.0.275-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:0dbdea02942131dbc15dd45f431d152224f15e1dd1859fcd0c0487b658f60f1a"}, - {file = "ruff-0.0.275-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:22efd9f41af27ef8fb9779462c46c35c89134d33e326c889971e10b2eaf50c63"}, - {file = "ruff-0.0.275-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:2c09662112cfa22d7467a19252a546291fd0eae4f423e52b75a7a2000a1894db"}, - {file = "ruff-0.0.275-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:80043726662144876a381efaab88841c88e8df8baa69559f96b22d4fa216bef1"}, - {file = "ruff-0.0.275-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5859ee543b01b7eb67835dfd505faa8bb7cc1550f0295c92c1401b45b42be399"}, - {file = "ruff-0.0.275-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:c8ace4d40a57b5ea3c16555f25a6b16bc5d8b2779ae1912ce2633543d4e9b1da"}, - {file = "ruff-0.0.275-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:8347fc16aa185aae275906c4ac5b770e00c896b6a0acd5ba521f158801911998"}, - {file = "ruff-0.0.275-py3-none-musllinux_1_2_i686.whl", hash = "sha256:ec43658c64bfda44fd84bbea9da8c7a3b34f65448192d1c4dd63e9f4e7abfdd4"}, - {file = "ruff-0.0.275-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:508b13f7ca37274cceaba4fb3ea5da6ca192356323d92acf39462337c33ad14e"}, - {file = "ruff-0.0.275-py3-none-win32.whl", hash = "sha256:6afb1c4422f24f361e877937e2a44b3f8176774a476f5e33845ebfe887dd5ec2"}, - {file = "ruff-0.0.275-py3-none-win_amd64.whl", hash = "sha256:d9b264d78621bf7b698b6755d4913ab52c19bd28bee1a16001f954d64c1a1220"}, - {file = "ruff-0.0.275-py3-none-win_arm64.whl", hash = "sha256:a19ce3bea71023eee5f0f089dde4a4272d088d5ac0b675867e074983238ccc65"}, - {file = "ruff-0.0.275.tar.gz", hash = "sha256:a63a0b645da699ae5c758fce19188e901b3033ec54d862d93fcd042addf7f38d"}, + {file = "ruff-0.0.277-py3-none-macosx_10_7_x86_64.whl", hash = "sha256:3250b24333ef419b7a232080d9724ccc4d2da1dbbe4ce85c4caa2290d83200f8"}, + {file = "ruff-0.0.277-py3-none-macosx_10_9_x86_64.macosx_11_0_arm64.macosx_10_9_universal2.whl", hash = "sha256:3e60605e07482183ba1c1b7237eca827bd6cbd3535fe8a4ede28cbe2a323cb97"}, + {file = "ruff-0.0.277-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7baa97c3d7186e5ed4d5d4f6834d759a27e56cf7d5874b98c507335f0ad5aadb"}, + {file = "ruff-0.0.277-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:74e4b206cb24f2e98a615f87dbe0bde18105217cbcc8eb785bb05a644855ba50"}, + {file = "ruff-0.0.277-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:479864a3ccd8a6a20a37a6e7577bdc2406868ee80b1e65605478ad3b8eb2ba0b"}, + {file = "ruff-0.0.277-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:468bfb0a7567443cec3d03cf408d6f562b52f30c3c29df19927f1e0e13a40cd7"}, + {file = "ruff-0.0.277-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f32ec416c24542ca2f9cc8c8b65b84560530d338aaf247a4a78e74b99cd476b4"}, + {file = "ruff-0.0.277-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:14a7b2f00f149c5a295f188a643ac25226ff8a4d08f7a62b1d4b0a1dc9f9b85c"}, + {file = "ruff-0.0.277-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a9879f59f763cc5628aa01c31ad256a0f4dc61a29355c7315b83c2a5aac932b5"}, + {file = "ruff-0.0.277-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:f612e0a14b3d145d90eb6ead990064e22f6f27281d847237560b4e10bf2251f3"}, + {file = "ruff-0.0.277-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:323b674c98078be9aaded5b8b51c0d9c424486566fb6ec18439b496ce79e5998"}, + {file = "ruff-0.0.277-py3-none-musllinux_1_2_i686.whl", hash = "sha256:3a43fbe026ca1a2a8c45aa0d600a0116bec4dfa6f8bf0c3b871ecda51ef2b5dd"}, + {file = "ruff-0.0.277-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:734165ea8feb81b0d53e3bf523adc2413fdb76f1264cde99555161dd5a725522"}, + {file = "ruff-0.0.277-py3-none-win32.whl", hash = "sha256:88d0f2afb2e0c26ac1120e7061ddda2a566196ec4007bd66d558f13b374b9efc"}, + {file = "ruff-0.0.277-py3-none-win_amd64.whl", hash = "sha256:6fe81732f788894a00f6ade1fe69e996cc9e485b7c35b0f53fb00284397284b2"}, + {file = "ruff-0.0.277-py3-none-win_arm64.whl", hash = "sha256:2d4444c60f2e705c14cd802b55cd2b561d25bf4311702c463a002392d3116b22"}, + {file = "ruff-0.0.277.tar.gz", hash = "sha256:2dab13cdedbf3af6d4427c07f47143746b6b95d9e4a254ac369a0edb9280a0d2"}, ] [[package]] @@ -4030,4 +4030,4 @@ cffi = ["cffi (>=1.11)"] [metadata] lock-version = "2.0" python-versions = "^3.8" -content-hash = "18d8ee6959d05dd3cda4d33d69aa0baf2f9de98ac792bb54a119c2b7d57444a0" +content-hash = "9ae4e4236289823b370e24a704d54c886e213a81e03a96bb6381db4e24442ebe" diff --git a/python/pyproject.toml b/python/pyproject.toml index 402575c0a77a..6c6998c71bc3 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -19,7 +19,7 @@ pre-commit = "3.3.3" black = {version = "23.3.0", allow-prereleases = true} ipykernel = "^6.21.1" pytest = "7.4.0" -ruff = "0.0.275" +ruff = "0.0.277" pytest-asyncio = "0.21.0" [tool.poetry.group.hugging_face.dependencies] diff --git a/python/semantic_kernel/memory/semantic_text_memory.py b/python/semantic_kernel/memory/semantic_text_memory.py index 0d6846fb8992..fce46f34e01f 100644 --- a/python/semantic_kernel/memory/semantic_text_memory.py +++ b/python/semantic_kernel/memory/semantic_text_memory.py @@ -56,7 +56,9 @@ async def save_information_async( ): await self._storage.create_collection_async(collection_name=collection) - embedding = await self._embeddings_generator.generate_embeddings_async([text]) + embedding = ( + await self._embeddings_generator.generate_embeddings_async([text]) + )[0] data = MemoryRecord.local_record( id=id, text=text, @@ -94,7 +96,9 @@ async def save_reference_async( ): await self._storage.create_collection_async(collection_name=collection) - embedding = await self._embeddings_generator.generate_embeddings_async([text]) + embedding = ( + await self._embeddings_generator.generate_embeddings_async([text]) + )[0] data = MemoryRecord.reference_record( external_id=external_id, source_name=external_source_name, @@ -142,9 +146,9 @@ async def search_async( Returns: List[MemoryQueryResult] -- The list of MemoryQueryResult found. """ - query_embedding = await self._embeddings_generator.generate_embeddings_async( - [query] - ) + query_embedding = ( + await self._embeddings_generator.generate_embeddings_async([query]) + )[0] results = await self._storage.get_nearest_matches_async( collection_name=collection, embedding=query_embedding, diff --git a/python/tests/integration/connectors/memory/test_pinecone.py b/python/tests/integration/connectors/memory/test_pinecone.py index 6e0427b940fa..f73a91617aea 100644 --- a/python/tests/integration/connectors/memory/test_pinecone.py +++ b/python/tests/integration/connectors/memory/test_pinecone.py @@ -25,12 +25,12 @@ def get_pinecone_config(): if "Python_Integration_Tests" in os.environ: api_key = os.environ["Pinecone__ApiKey"] - org_id = None + environment = os.environ["Pinecone__Environment"] else: # Load credentials from .env file - api_key, org_id = sk.pinecone_settings_from_dot_env() + api_key, environment = sk.pinecone_settings_from_dot_env() - return api_key, org_id + return api_key, environment @pytest.fixture diff --git a/samples/apps/.eslintrc.js b/samples/apps/.eslintrc.js index 7b4da7967410..3755ebea6fee 100644 --- a/samples/apps/.eslintrc.js +++ b/samples/apps/.eslintrc.js @@ -1,46 +1,33 @@ module.exports = { env: { - browser: true, es2021: true, }, - extends: ['plugin:react/recommended', 'standard-with-typescript'], - ignorePatterns: ['build', '.*.js', '*.config.js', 'node_modules'], - overrides: [], + extends: [ + 'eslint:recommended', + 'plugin:react/recommended', + 'plugin:react-hooks/recommended', + 'plugin:@typescript-eslint/recommended', + 'plugin:@typescript-eslint/recommended-requiring-type-checking', + 'plugin:@typescript-eslint/strict', + ], + ignorePatterns: ['build', '.*.js', 'node_modules'], parserOptions: { project: './tsconfig.json', ecmaVersion: 'latest', sourceType: 'module', }, - plugins: ['react', '@typescript-eslint', 'import', 'react-hooks', 'react-security'], rules: { - '@typescript-eslint/brace-style': ['off'], - '@typescript-eslint/space-before-function-paren': [ - 'error', - { anonymous: 'always', named: 'never', asyncArrow: 'always' }, - ], - '@typescript-eslint/semi': ['error', 'always'], + '@typescript-eslint/array-type': ['error', { default: 'array-simple' }], '@typescript-eslint/triple-slash-reference': ['error', { types: 'prefer-import' }], - '@typescript-eslint/indent': ['off'], - '@typescript-eslint/comma-dangle': ['error', 'always-multiline'], + '@typescript-eslint/non-nullable-type-assertion-style': 'off', '@typescript-eslint/strict-boolean-expressions': 'off', - '@typescript-eslint/member-delimiter-style': [ - 'error', - { - multiline: { - delimiter: 'semi', - requireLast: true, - }, - singleline: { - delimiter: 'semi', - requireLast: false, - }, - }, - ], '@typescript-eslint/explicit-function-return-type': 'off', - 'react/jsx-props-no-spreading': 'warn', - 'react-hooks/rules-of-hooks': 'error', - 'react-hooks/exhaustive-deps': 'warn', + '@typescript-eslint/consistent-type-imports': 'off', + '@typescript-eslint/no-empty-function': 'off', + '@typescript-eslint/no-explicit-any': 'off', 'react/react-in-jsx-scope': 'off', + 'react/prop-types': 'off', + 'react/jsx-props-no-spreading': 'off', }, settings: { react: { diff --git a/samples/apps/copilot-chat-app/deploy/deploy-webapi.sh b/samples/apps/copilot-chat-app/deploy/deploy-webapi.sh index ae4658f59ea1..6e0ea511e588 100644 --- a/samples/apps/copilot-chat-app/deploy/deploy-webapi.sh +++ b/samples/apps/copilot-chat-app/deploy/deploy-webapi.sh @@ -5,13 +5,13 @@ set -e usage() { - echo "Usage: $0 -d DEPLOYMENT_NAME -s SUBSCRIPTION --ai AI_SERVICE_TYPE -aikey AI_SERVICE_KEY [OPTIONS]" + echo "Usage: $0 -d DEPLOYMENT_NAME -s SUBSCRIPTION -rg RESOURCE_GROUP [OPTIONS]" echo "" echo "Arguments:" - echo " -s, --subscription SUBSCRIPTION Subscription to which to make the deployment (mandatory)" + echo " -d, --deployment-name DEPLOYMENT_NAME Name of the deployment from a 'deploy-azure.sh' deployment (mandatory)" + echo " -s, --subscription SUBSCRIPTION Subscription to which to make the deployment (mandatory)" echo " -rg, --resource-group RESOURCE_GROUP Resource group name from a 'deploy-azure.sh' deployment (mandatory)" - echo " -d, --deployment-name DEPLOYMENT_NAME Name of the deployment from a 'deploy-azure.sh' deployment (mandatory)" - echo " -p, --package PACKAGE_FILE_PATH Path to the WebAPI package file from a 'package-webapi.sh' run (mandatory)" + echo " -p, --package PACKAGE_FILE_PATH Path to the WebAPI package file from a 'package-webapi.sh' run" } # Parse arguments @@ -47,11 +47,14 @@ while [[ $# -gt 0 ]]; do done # Check mandatory arguments -if [[ -z "$DEPLOYMENT_NAME" ]] || [[ -z "$SUBSCRIPTION" ]] || [[ -z "$RESOURCE_GROUP" ]] || [[ -z "$PACKAGE_FILE_PATH" ]]; then +if [[ -z "$DEPLOYMENT_NAME" ]] || [[ -z "$SUBSCRIPTION" ]] || [[ -z "$RESOURCE_GROUP" ]]; then usage exit 1 fi +# Set defaults +: "${PACKAGE_FILE_PATH:="$(dirname "$0")/out/webapi.zip"}" + # Ensure $PACKAGE_FILE_PATH exists if [[ ! -f "$PACKAGE_FILE_PATH" ]]; then echo "Package file '$PACKAGE_FILE_PATH' does not exist. Have you run 'package-webapi.sh' yet?" diff --git a/samples/apps/copilot-chat-app/deploy/main.bicep b/samples/apps/copilot-chat-app/deploy/main.bicep index d34054442f45..7307dea84fc1 100644 --- a/samples/apps/copilot-chat-app/deploy/main.bicep +++ b/samples/apps/copilot-chat-app/deploy/main.bicep @@ -208,6 +208,10 @@ resource appServiceWebConfig 'Microsoft.Web/sites/config@2022-09-01' = { name: 'ChatStore:Cosmos:ChatMemorySourcesContainer' value: 'chatmemorysources' } + { + name: 'ChatStore:Cosmos:ChatParticipantsContainer' + value: 'chatparticipants' + } { name: 'ChatStore:Cosmos:ConnectionString' value: deployCosmosDB ? cosmosAccount.listConnectionStrings().connectionStrings[0].connectionString : '' diff --git a/samples/apps/copilot-chat-app/deploy/main.json b/samples/apps/copilot-chat-app/deploy/main.json index 137b136aa886..1f75396319fe 100644 --- a/samples/apps/copilot-chat-app/deploy/main.json +++ b/samples/apps/copilot-chat-app/deploy/main.json @@ -4,8 +4,8 @@ "metadata": { "_generator": { "name": "bicep", - "version": "0.17.1.54307", - "templateHash": "14923066769528474387" + "version": "0.16.2.56959", + "templateHash": "18037485528098010448" } }, "parameters": { @@ -312,6 +312,10 @@ "name": "ChatStore:Cosmos:ChatMemorySourcesContainer", "value": "chatmemorysources" }, + { + "name": "ChatStore:Cosmos:ChatParticipantsContainer", + "value": "chatparticipants" + }, { "name": "ChatStore:Cosmos:ConnectionString", "value": "[if(parameters('deployCosmosDB'), listConnectionStrings(resourceId('Microsoft.DocumentDB/databaseAccounts', toLower(format('cosmos-{0}', variables('uniqueName')))), '2023-04-15').connectionStrings[0].connectionString, '')]" diff --git a/samples/apps/copilot-chat-app/webapi/ConfigurationExtensions.cs b/samples/apps/copilot-chat-app/webapi/ConfigurationExtensions.cs index 5edc161465c4..81e7bf3697a4 100644 --- a/samples/apps/copilot-chat-app/webapi/ConfigurationExtensions.cs +++ b/samples/apps/copilot-chat-app/webapi/ConfigurationExtensions.cs @@ -37,7 +37,7 @@ public static IHostBuilder AddConfiguration(this IHostBuilder host) reloadOnChange: true); // For settings from Key Vault, see https://learn.microsoft.com/en-us/aspnet/core/security/key-vault-configuration?view=aspnetcore-8.0 - string? keyVaultUri = builderContext.Configuration["KeyVaultUri"]; + string? keyVaultUri = builderContext.Configuration["Service:KeyVault"]; if (!string.IsNullOrWhiteSpace(keyVaultUri)) { configBuilder.AddAzureKeyVault( diff --git a/samples/apps/copilot-chat-app/webapi/Options/ServiceOptions.cs b/samples/apps/copilot-chat-app/webapi/Options/ServiceOptions.cs index 58c05547297b..31f7d6d81ef2 100644 --- a/samples/apps/copilot-chat-app/webapi/Options/ServiceOptions.cs +++ b/samples/apps/copilot-chat-app/webapi/Options/ServiceOptions.cs @@ -1,6 +1,5 @@ // Copyright (c) Microsoft. All rights reserved. -using System; using System.ComponentModel.DataAnnotations; namespace SemanticKernel.Service.Options; @@ -16,7 +15,7 @@ public class ServiceOptions /// Configuration Key Vault URI /// [Url] - public Uri? KeyVaultUri { get; set; } + public string? KeyVault { get; set; } /// /// Local directory in which to load semantic skills. diff --git a/samples/apps/copilot-chat-app/webapi/appsettings.json b/samples/apps/copilot-chat-app/webapi/appsettings.json index a48432da495f..224678239fdc 100644 --- a/samples/apps/copilot-chat-app/webapi/appsettings.json +++ b/samples/apps/copilot-chat-app/webapi/appsettings.json @@ -17,7 +17,7 @@ // "Service": { // "SemanticSkillsDirectory": "", - // "KeyVaultUri": "" + // "KeyVault": "" }, // diff --git a/samples/apps/copilot-chat-app/webapp/.env.local b/samples/apps/copilot-chat-app/webapp/.env.local deleted file mode 100644 index 146010753578..000000000000 --- a/samples/apps/copilot-chat-app/webapp/.env.local +++ /dev/null @@ -1,2 +0,0 @@ -# Disable ESLint in all environments -DISABLE_ESLINT_PLUGIN=true \ No newline at end of file diff --git a/samples/apps/copilot-chat-app/webapp/package.json b/samples/apps/copilot-chat-app/webapp/package.json index cb4630846866..f6fd8dc07e17 100644 --- a/samples/apps/copilot-chat-app/webapp/package.json +++ b/samples/apps/copilot-chat-app/webapp/package.json @@ -8,6 +8,7 @@ "auth:mac": "better-vsts-npm-auth -config .npmrc", "depcheck": "depcheck --ignores=\"@types/*,typescript\" --ignore-dirs=\".vscode,.vs,.git,node_modules\" --skip-missing", "lint": "eslint src", + "lint:fix": "eslint src --fix", "prettify": "prettier --write \"src/**/*.{ts,tsx,js,jsx,json,scss,css,html,svg}\"", "serve": "serve -s build", "start": "react-scripts start", @@ -44,10 +45,17 @@ "http-server": "^14.1.1", "prettier": "^2.8.1", "serve": "^14.2.0", - "typescript": "*", + "typescript": "5.0.4", "vsts-npm-auth": "^0.42.1", "workbox-window": "^6.5.4" }, + "eslintConfig": { + "extends": [ + "react-app", + "react-app/jest", + "../../.eslintrc.js" + ] + }, "browserslist": { "production": [ ">0.2%", diff --git a/samples/apps/copilot-chat-app/webapp/src/App.tsx b/samples/apps/copilot-chat-app/webapp/src/App.tsx index 4935ec2cc462..e1fbb3268803 100644 --- a/samples/apps/copilot-chat-app/webapp/src/App.tsx +++ b/samples/apps/copilot-chat-app/webapp/src/App.tsx @@ -1,11 +1,6 @@ // Copyright (c) Microsoft. All rights reserved. -import { - AuthenticatedTemplate, - UnauthenticatedTemplate, - useIsAuthenticated, - useMsal, -} from '@azure/msal-react'; +import { AuthenticatedTemplate, UnauthenticatedTemplate, useIsAuthenticated, useMsal } from '@azure/msal-react'; import { Subtitle1, makeStyles, shorthands, tokens } from '@fluentui/react-components'; import { Alert } from '@fluentui/react-components/unstable'; import { Dismiss16Regular } from '@fluentui/react-icons'; @@ -79,21 +74,19 @@ const App: FC = () => { dispatch(setLoggedInUserId(account.homeAccountId)); if (appState === AppState.LoadingChats) { - // Load all chats from the backend. - async function loadChats() { - if (await chat.loadChats()) { + // Load all chats from memory + void chat.loadChats().then((succeeded) => { + if (succeeded) { setAppState(AppState.Chat); } - } - - loadChats(); + }); } } // eslint-disable-next-line react-hooks/exhaustive-deps }, [instance, inProgress, isAuthenticated, appState]); - const onDismissAlert = (key: string) => { - dispatch(removeAlert(key)); + const onDismissAlert = (index: number) => { + dispatch(removeAlert(index)); }; // TODO: handle error case of missing account information @@ -114,34 +107,40 @@ const App: FC = () => { Copilot Chat
- setAppState(AppState.SigningOut)} /> + { + setAppState(AppState.SigningOut); + }} + />
- {alerts && - Object.keys(alerts).map((key) => { - const alert = alerts[key]; - return ( - onDismissAlert(key)} - color="black" - /> - ), - }} - key={key} - > - {alert.message} - - ); - })} + {alerts.map(({ type, message }, index) => { + return ( + { + onDismissAlert(index); + }} + color="black" + /> + ), + }} + key={`${index}-${type}`} + > + {message} + + ); + })} {appState === AppState.ProbeForBackend && ( setAppState(AppState.LoadingChats)} + onBackendFound={() => { + setAppState(AppState.LoadingChats); + }} /> )} {appState === AppState.LoadingChats && } diff --git a/samples/apps/copilot-chat-app/webapp/src/components/FileUploader.tsx b/samples/apps/copilot-chat-app/webapp/src/components/FileUploader.tsx index 474e2e20cd73..dbeb3c771e53 100644 --- a/samples/apps/copilot-chat-app/webapp/src/components/FileUploader.tsx +++ b/samples/apps/copilot-chat-app/webapp/src/components/FileUploader.tsx @@ -18,7 +18,7 @@ export const FileUploader: React.FC = forwardRef) => { + (event: React.SyntheticEvent) => { const target = event.target as HTMLInputElement; const selectedFiles = target.files; event.stopPropagation(); @@ -50,3 +50,5 @@ export const FileUploader: React.FC = forwardRef; - onSubmit: (options: GetResponseOptions) => void; + onSubmit: (options: GetResponseOptions) => Promise; } export const ChatInput: React.FC = ({ isDraggingOver, onDragLeave, onSubmit }) => { @@ -89,17 +89,21 @@ export const ChatInput: React.FC = ({ isDraggingOver, onDragLeav React.useEffect(() => { async function initSpeechRecognizer() { const speechService = new SpeechService(process.env.REACT_APP_BACKEND_URI as string); - var response = await speechService.getSpeechTokenAsync( + const response = await speechService.getSpeechTokenAsync( await AuthHelper.getSKaaSAccessToken(instance, inProgress), ); if (response.isSuccess) { - const recognizer = await speechService.getSpeechRecognizerAsyncWithValidKey(response); + const recognizer = speechService.getSpeechRecognizerAsyncWithValidKey(response); setRecognizer(recognizer); } } - initSpeechRecognizer(); - }, [instance, inProgress]); + initSpeechRecognizer().catch((e) => { + const errorDetails = e instanceof Error ? e.message : String(e); + const errorMessage = `Unable to initialize speech recognizer. Details: ${errorDetails}`; + dispatch(addAlert({ message: errorMessage, type: AlertType.Error })); + }); + }, [dispatch, instance, inProgress]); React.useEffect(() => { const chatState = conversations[selectedId]; @@ -120,28 +124,30 @@ export const ChatInput: React.FC = ({ isDraggingOver, onDragLeav } }; - const handleImport = async (dragAndDropFile?: File) => { - setDocumentImporting(true); + const handleImport = (dragAndDropFile?: File) => { const file = dragAndDropFile ?? documentFileRef.current?.files?.[0]; if (file) { - await chat.importDocument(selectedId, file); + setDocumentImporting(true); + chat.importDocument(selectedId, file).finally(() => { + setDocumentImporting(false); + }); } - setDocumentImporting(false); // Reset the file input so that the onChange event will // be triggered even if the same file is selected again. - documentFileRef.current!.value = ''; + if (documentFileRef.current?.value) { + documentFileRef.current.value = ''; + } }; const handleSubmit = (value: string, messageType: ChatMessageType = ChatMessageType.Message) => { - try { - if (value.trim() === '') { - return; // only submit if value is not empty - } - onSubmit({ value, messageType, chatId: selectedId }); - setValue(''); - dispatch(editConversationInput({ id: selectedId, newInput: '' })); - } catch (error) { + if (value.trim() === '') { + return; // only submit if value is not empty + } + + setValue(''); + dispatch(editConversationInput({ id: selectedId, newInput: '' })); + onSubmit({ value, messageType, chatId: selectedId }).catch((error) => { const message = `Error submitting chat input: ${(error as Error).message}`; log(message); dispatch( @@ -150,17 +156,19 @@ export const ChatInput: React.FC = ({ isDraggingOver, onDragLeav message, }), ); - } + }); }; - const handleDrop = async (e: React.DragEvent) => { + const handleDrop = (e: React.DragEvent) => { onDragLeave(e); - await handleImport(e.dataTransfer?.files[0]); + handleImport(e.dataTransfer.files[0]); }; return (
-
+
+ +