Skip to content

Commit

Permalink
Free search vector from artificial size limit
Browse files Browse the repository at this point in the history
The earlier implementation imposed a size limit of 1 for
the search vector. This perhaps was not needed.
It should be flexible enough to work as a standard aggregate
function. So, it wouldn't make sense to impose that limit
and as a result use-cases.
  • Loading branch information
ttanay committed Sep 10, 2023
1 parent a76372b commit b9e6e09
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 16 deletions.
36 changes: 21 additions & 15 deletions src/list_distance.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ struct DistanceStateVector {
Vector state_vector;
};

// TODO: Maybe use better names?
// Note: `search_l` should be a constant vector
// Take two lists - `l` and `search_l`
// Compute the distance between the tuples resulting from the cross product `l` and `search_l`
Expand Down Expand Up @@ -119,18 +120,22 @@ static void ListDistanceFunction(DataChunk &args, ExpressionState &state, Vector

D_ASSERT(aggr.function.update);

// Iterate over `l` list
// TODO: Add some sort of an iterator interface for DuckDB's Vector
UnifiedVectorFormat l_data;
UnifiedVectorFormat search_l_data;
l.ToUnifiedFormat(count, l_data);
search_l.ToUnifiedFormat(count, search_l_data);
auto l_entries = UnifiedVectorFormat::GetData<list_entry_t>(l_data);
auto search_l_entries = UnifiedVectorFormat::GetData<list_entry_t>(search_l_data);

auto l_list_size = ListVector::GetListSize(l);
auto search_l_list_size = ListVector::GetListSize(search_l);
auto &l_child = ListVector::GetEntry(l);
auto &search_l_child = ListVector::GetEntry(search_l);
UnifiedVectorFormat l_child_data;
UnifiedVectorFormat search_l_child_data;
l_child.ToUnifiedFormat(l_list_size, l_child_data);
search_l_child.ToUnifiedFormat(search_l_list_size, search_l_child_data);

// state_buffer holds the state for each list of this chunk
idx_t size = aggr.function.state_size();
Expand All @@ -144,54 +149,55 @@ static void ListDistanceFunction(DataChunk &args, ExpressionState &state, Vector
Vector state_vector_update = Vector(LogicalType::POINTER);
auto states_update = FlatVector::GetData<data_ptr_t>(state_vector_update);

// // Get the first index of the search_l since there won't be any others
D_ASSERT(search_l.length == 1);

for (idx_t i = 0; i < count; i++) {
// initialize the state for this list
auto state_ptr = state_buffer.get() + size * i;
states[i] = state_ptr;
aggr.function.initialize(states[i]);

auto l_index = l_data.sel->get_index(i);
auto search_l_index = search_l_data.sel->get_index(i);
const auto &l_entry = l_entries[l_index];
// D_ASSERT(l_entry.length == search_l_entry.length);
const auto &search_l_entry = search_l_entries[search_l_index];
D_ASSERT(l_entry.length == search_l_entry.length);

// nothing to do for this list
if (!l_data.validity.RowIsValid(l_index)) {
if (!l_data.validity.RowIsValid(l_index) || !search_l_data.validity.RowIsValid(search_l_index)) {
result_validity.SetInvalid(i);
continue;
}
if (l_entry.length == 0)
if (l_entry.length == 0 || search_l_entry.length == 0)
continue;

SelectionVector l_sel_vector(STANDARD_VECTOR_SIZE);
// SelectionVector search_l_sel_vector(STANDARD_VECTOR_SIZE);
SelectionVector search_l_sel_vector(STANDARD_VECTOR_SIZE);

// Assumes that that all vectors are of the same length/size
idx_t states_idx = 0;
// A selection index: 0..l_entry.length; value of selection index is updated to the latest values
// B selection index: 0..l_entry.length; value of selection index is the same first l_entry.length vectors
// Iterate over both lists to compute the distance
for (idx_t j = 0; j < l_entry.length; j++) {
if (states_idx == STANDARD_VECTOR_SIZE) {
// Do the update and reset the states_idx
Vector l_slice(l_child, l_sel_vector, states_idx);
Vector inputs[] = {l_slice, search_l_child};
Vector search_l_slice(search_l_child, search_l_sel_vector, states_idx);
Vector inputs[] = {l_slice, search_l_slice};
aggr.function.update(inputs, aggr_input_data, 2, state_vector_update, states_idx);

states_idx = 0;
}

idx_t actual_idx = l_child_data.sel->get_index(l_entry.offset + j);
l_sel_vector.set_index(states_idx, actual_idx);
// search_l_sel_vector.set_index(states_idx, actual_idx);
idx_t l_actual_idx = l_child_data.sel->get_index(l_entry.offset + j);
idx_t search_l_actual_idx = search_l_child_data.sel->get_index(search_l_entry.offset + j);
l_sel_vector.set_index(states_idx, l_actual_idx);
search_l_sel_vector.set_index(states_idx, search_l_actual_idx);
states_update[states_idx] = state_ptr;
states_idx++;
}

if (states_idx != 0) {
Vector l_slice(l_child, l_sel_vector, states_idx);
Vector inputs[] = {l_slice, search_l_child};
Vector search_l_slice(search_l_child, search_l_sel_vector, states_idx);
Vector inputs[] = {l_slice, search_l_slice};
aggr.function.update(inputs, aggr_input_data, 2, state_vector_update, states_idx);
}
}
Expand Down
36 changes: 36 additions & 0 deletions test/sql/list_distance.test
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# name: test/sql/list_distance_inputs.test
# description: test list_distance inputs
# group: [list_distance]

# Test for different-sized inputs of list_distance

require vector

# prepare table
statement ok
CREATE TABLE vectors(v1 DOUBLE[10], v2 DOUBLE[10]);

statement ok
INSERT INTO vectors VALUES
([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 789.0, 10.0, 11.0], [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 181.0, 10.0, 10.0]),
([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 181.0, 10.0, 10.0], [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 789.0, 10.0, 11.0]),
([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 785.0, 10.0, 11.0], [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 181.0, 10.0, 10.0]);

# distance w.r.t a single search vector
query R
SELECT list_distance(v1, [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 785.0, 10.0, 11.0], 'l2distance') FROM vectors;
----
4.0
604.0008278140023
0.0

# distance for multiple columns of a table
query R
SELECT list_euclidean_distance(v1, v2) FROM vectors;
----
608.0008223678649
608.0008223678649
604.0008278140023

statement ok
DROP TABLE vectors;
5 changes: 4 additions & 1 deletion test/sql/vector.test
Original file line number Diff line number Diff line change
Expand Up @@ -92,4 +92,7 @@ SELECT list_cosine_similarity([1,2],[2,3]);
query R
SELECT list_distance([1,2],[2,3], 'cosine_similarity') + list_distance([1,2],[2,3], 'cosine_distance');
----
1
1

statement ok
DROP TABLE vectors;

0 comments on commit b9e6e09

Please sign in to comment.