diff --git a/src/Index.cs b/src/Index.cs
index 0d4258a..4d7640d 100644
--- a/src/Index.cs
+++ b/src/Index.cs
@@ -1,5 +1,7 @@
+using System.Collections.Concurrent;
using System.Diagnostics.CodeAnalysis;
using System.Text.Json.Serialization;
+using CommunityToolkit.Diagnostics;
namespace Pinecone;
@@ -29,12 +31,12 @@ public sealed partial record Index<
     /// The URL address where the index is hosted.
     /// </summary>
     public string? Host { get; init; }
-
+
     /// <summary>
     /// Additional information about the index.
     /// </summary>
     public required IndexSpec Spec { get; init; }
-
+
     /// <summary>
     /// The current status of the index.
     /// </summary>
@@ -132,16 +134,73 @@ public Task<ScoredVector[]> Query(
}
     /// <summary>
-    /// Writes vector into the index. If a new value is provided for an existing vector ID, it will overwrite the previous value.
+    /// Writes vectors into the index. If a new value is provided for an existing vector ID, it will overwrite the previous value.
     /// </summary>
+    /// <remarks>
+    /// If the sequence of vectors is countable and greater than or equal to 400, it will be batched and the batches
+    /// will be upserted in parallel. The default batch size is 100 and the default parallelism is 20.
+    /// </remarks>
     /// <param name="vectors">A collection of <see cref="Vector" /> objects to upsert.</param>
     /// <param name="indexNamespace">Namespace to write the vector to. If no namespace is provided, the operation applies to all namespaces.</param>
     /// <returns>The number of vectors upserted.</returns>
-    public Task<uint> Upsert(IEnumerable<Vector> vectors, string? indexNamespace = null, CancellationToken ct = default)
+    public Task<uint> Upsert(
+        IEnumerable<Vector> vectors,
+        string? indexNamespace = null,
+        CancellationToken ct = default)
     {
+        const int batchSize = 100;
+        const int parallelism = 20;
+        const int threshold = 400;
+
+        if (vectors.TryGetNonEnumeratedCount(out var count) && count >= threshold)
+        {
+            return Upsert(vectors, batchSize, parallelism, indexNamespace, ct);
+        }
+
         return Transport.Upsert(vectors, indexNamespace, ct);
     }
+    /// <summary>
+    /// Writes vectors into the index as batches in parallel. If a new value is provided for an existing vector ID, it will overwrite the previous value.
+    /// </summary>
+    /// <param name="vectors">A collection of <see cref="Vector" /> objects to upsert.</param>
+    /// <param name="batchSize">The number of vectors to upsert in each batch.</param>
+    /// <param name="parallelism">The maximum number of batches to process in parallel.</param>
+    /// <param name="indexNamespace">Namespace to write the vector to. If no namespace is provided, the operation applies to all namespaces.</param>
+    /// <returns>The number of vectors upserted.</returns>
+    public async Task<uint> Upsert(
+        IEnumerable<Vector> vectors,
+        int batchSize,
+        int parallelism,
+        string? indexNamespace = null,
+        CancellationToken ct = default)
+    {
+        Guard.IsGreaterThan(batchSize, 0);
+        Guard.IsGreaterThan(parallelism, 0);
+
+        if (parallelism is 1)
+        {
+            return await Transport.Upsert(vectors, indexNamespace, ct);
+        }
+
+        var upserted = 0u;
+        var batches = vectors.Chunk(batchSize);
+        var options = new ParallelOptions
+        {
+            CancellationToken = ct,
+            MaxDegreeOfParallelism = parallelism
+        };
+
+        // TODO: Do we need to provide more specific cancellation exception that
+        // includes the number of upserted vectors?
+        await Parallel.ForEachAsync(batches, options, async (batch, ct) =>
+        {
+            Interlocked.Add(ref upserted, await Transport.Upsert(batch, indexNamespace, ct));
+        });
+
+        return upserted;
+    }
+
     /// <summary>
     /// Updates a vector using the <see cref="Vector" /> object.
     /// </summary>
@@ -172,16 +231,68 @@ public Task Update(
}
     /// <summary>
-    /// Looks up and returns vectors, by ID. The returned vectors include the vector data and/or metadata.
+    /// Looks up and returns vectors by ID. The returned vectors include the vector data and/or metadata.
     /// </summary>
+    /// <remarks>
+    /// If the sequence of IDs is countable and greater than or equal to 600, it will be batched and the batches
+    /// will be fetched in parallel. The default batch size is 200 and the default parallelism is 20.
+    /// </remarks>
     /// <param name="ids">IDs of vectors to fetch.</param>
     /// <param name="indexNamespace">Namespace to fetch vectors from. If no namespace is provided, the operation applies to all namespaces.</param>
     /// <returns>A dictionary containing vector IDs and the corresponding <see cref="Vector" /> objects containing the vector information.</returns>
     public Task<Dictionary<string, Vector>> Fetch(IEnumerable<string> ids, string? indexNamespace = null, CancellationToken ct = default)
     {
+        const int batchSize = 200;
+        const int parallelism = 20;
+        const int threshold = 600;
+
+        if (ids.TryGetNonEnumeratedCount(out var count) && count >= threshold)
+        {
+            return Fetch(ids, batchSize, parallelism, indexNamespace, ct);
+        }
+
         return Transport.Fetch(ids, indexNamespace, ct);
     }
+    /// <summary>
+    /// Looks up and returns vectors by ID as batches in parallel.
+    /// </summary>
+    /// <param name="ids">IDs of vectors to fetch.</param>
+    /// <param name="batchSize">The number of vectors to fetch in each batch.</param>
+    /// <param name="parallelism">The maximum number of batches to process in parallel.</param>
+    /// <param name="indexNamespace">Namespace to fetch vectors from. If no namespace is provided, the operation applies to all namespaces.</param>
+    /// <returns>A dictionary containing vector IDs and the corresponding <see cref="Vector" /> objects containing the vector information.</returns>
+    public async Task<Dictionary<string, Vector>> Fetch(
+        IEnumerable<string> ids,
+        int batchSize,
+        int parallelism,
+        string? indexNamespace = null,
+        CancellationToken ct = default)
+    {
+        Guard.IsGreaterThan(batchSize, 0);
+        Guard.IsGreaterThan(parallelism, 0);
+
+        if (parallelism is 1)
+        {
+            return await Transport.Fetch(ids, indexNamespace, ct);
+        }
+
+        var fetched = new ConcurrentStack<Dictionary<string, Vector>>();
+        var batches = ids.Chunk(batchSize);
+        var options = new ParallelOptions
+        {
+            CancellationToken = ct,
+            MaxDegreeOfParallelism = parallelism
+        };
+
+        await Parallel.ForEachAsync(batches, options, async (batch, ct) =>
+        {
+            fetched.Push(await Transport.Fetch(batch, indexNamespace, ct));
+        });
+
+        return new(fetched.SelectMany(batch => batch));
+    }
+
     /// <summary>
     /// Deletes vectors with specified ids.
     /// </summary>