Skip to content

Commit

Permalink
feat: automatically parallelize upsert and fetch operations
Browse files Browse the repository at this point in the history
  • Loading branch information
neon-sunset committed Jun 4, 2024
1 parent fa54e8d commit 0fd32b6
Showing 1 changed file with 116 additions and 5 deletions.
121 changes: 116 additions & 5 deletions src/Index.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
using System.Collections.Concurrent;
using System.Diagnostics.CodeAnalysis;
using System.Text.Json.Serialization;
using CommunityToolkit.Diagnostics;

namespace Pinecone;

Expand Down Expand Up @@ -29,12 +31,12 @@ public sealed partial record Index<
/// The URL address where the index is hosted.
/// </summary>
public string? Host { get; init; }

/// <summary>
/// Additional information about the index.
/// </summary>
public required IndexSpec Spec { get; init; }

/// <summary>
/// The current status of the index.
/// </summary>
Expand Down Expand Up @@ -132,16 +134,73 @@ public Task<ScoredVector[]> Query(
}

/// <summary>
/// Writes vector into the index. If a new value is provided for an existing vector ID, it will overwrite the previous value.
/// Writes vectors into the index. If a new value is provided for an existing vector ID, it will overwrite the previous value.
/// </summary>
/// <remarks>
/// If the sequence of vectors is countable and greater than or equal to 400, it will be batched and the batches
/// will be upserted in parallel. The default batch size is 100 and the default parallelism is 20.
/// </remarks>
/// <param name="vectors">A collection of <see cref="Vector"/> objects to upsert.</param>
/// <param name="indexNamespace">Namespace to write the vector to. If no namespace is provided, the operation applies to all namespaces.</param>
/// <returns>The number of vectors upserted.</returns>
public Task<uint> Upsert(IEnumerable<Vector> vectors, string? indexNamespace = null, CancellationToken ct = default)
public Task<uint> Upsert(
IEnumerable<Vector> vectors,
string? indexNamespace = null,
CancellationToken ct = default)
{
const int batchSize = 100;
const int parallelism = 20;
const int threshold = 400;

if (vectors.TryGetNonEnumeratedCount(out var count) && count >= threshold)
{
return Upsert(vectors, batchSize, parallelism, indexNamespace, ct);
}

return Transport.Upsert(vectors, indexNamespace, ct);
}

/// <summary>
/// Writes vectors into the index as batches in parallel. If a new value is provided for an existing vector ID, it will overwrite the previous value.
/// </summary>
/// <param name="vectors">A collection of <see cref="Vector"/> objects to upsert.</param>
/// <param name="batchSize">The number of vectors to upsert in each batch.</param>
/// <param name="parallelism">The maximum number of batches to process in parallel.</param>
/// <param name="indexNamespace">Namespace to write the vector to. If no namespace is provided, the operation applies to all namespaces.</param>
/// <returns>The number of vectors upserted.</returns>
public async Task<uint> Upsert(
IEnumerable<Vector> vectors,
int batchSize,
int parallelism,
string? indexNamespace = null,
CancellationToken ct = default)
{
Guard.IsGreaterThan(batchSize, 0);
Guard.IsGreaterThan(parallelism, 0);

if (parallelism is 1)
{
return await Transport.Upsert(vectors, indexNamespace, ct);
}

var upserted = 0u;
var batches = vectors.Chunk(batchSize);
var options = new ParallelOptions
{
CancellationToken = ct,
MaxDegreeOfParallelism = parallelism
};

// TODO: Do we need to provide more specific cancellation exception that
// includes the number of upserted vectors?
await Parallel.ForEachAsync(batches, options, async (batch, ct) =>
{
Interlocked.Add(ref upserted, await Transport.Upsert(batch, indexNamespace, ct));
});

return upserted;
}

/// <summary>
/// Updates a vector using the <see cref="Vector"/> object.
/// </summary>
Expand Down Expand Up @@ -172,16 +231,68 @@ public Task Update(
}

/// <summary>
/// Looks up and returns vectors, by ID. The returned vectors include the vector data and/or metadata.
/// Looks up and returns vectors by ID. The returned vectors include the vector data and/or metadata.
/// </summary>
/// <remarks>
/// If the sequence of IDs is countable and greater than or equal to 600, it will be batched and the batches
/// will be fetched in parallel. The default batch size is 200 and the default parallelism is 20.
/// </remarks>
/// <param name="ids">IDs of vectors to fetch.</param>
/// <param name="indexNamespace">Namespace to fetch vectors from. If no namespace is provided, the operation applies to all namespaces.</param>
/// <returns>A dictionary containing vector IDs and the corresponding <see cref="Vector"/> objects containing the vector information.</returns>
public Task<Dictionary<string, Vector>> Fetch(IEnumerable<string> ids, string? indexNamespace = null, CancellationToken ct = default)
{
const int batchSize = 200;
const int parallelism = 20;
const int threshold = 600;

if (ids.TryGetNonEnumeratedCount(out var count) && count >= threshold)
{
return Fetch(ids, batchSize, parallelism, indexNamespace, ct);
}

return Transport.Fetch(ids, indexNamespace, ct);
}

/// <summary>
/// Looks up and returns vectors by ID as batches in parallel.
/// </summary>
/// <param name="ids">IDs of vectors to fetch.</param>
/// <param name="batchSize">The number of vectors to fetch in each batch.</param>
/// <param name="parallelism">The maximum number of batches to process in parallel.</param>
/// <param name="indexNamespace">Namespace to fetch vectors from. If no namespace is provided, the operation applies to all namespaces.</param>
/// <returns>A dictionary containing vector IDs and the corresponding <see cref="Vector"/> objects containing the vector information.</returns>
public async Task<Dictionary<string, Vector>> Fetch(
IEnumerable<string> ids,
int batchSize,
int parallelism,
string? indexNamespace = null,
CancellationToken ct = default)
{
Guard.IsGreaterThan(batchSize, 0);
Guard.IsGreaterThan(parallelism, 0);

if (parallelism is 1)
{
return await Transport.Fetch(ids, indexNamespace, ct);
}

var fetched = new ConcurrentStack<Dictionary<string, Vector>>();
var batches = ids.Chunk(batchSize);
var options = new ParallelOptions
{
CancellationToken = ct,
MaxDegreeOfParallelism = parallelism
};

await Parallel.ForEachAsync(batches, options, async (batch, ct) =>
{
fetched.Push(await Transport.Fetch(batch, indexNamespace, ct));
});

return new(fetched.SelectMany(batch => batch));
}

/// <summary>
/// Deletes vectors with specified ids.
/// </summary>
Expand Down

0 comments on commit 0fd32b6

Please sign in to comment.