Skip to content

Commit

Permalink
add related test
Browse files Browse the repository at this point in the history
  • Loading branch information
twosom committed Oct 1, 2024
1 parent 9405361 commit 4c89506
Showing 1 changed file with 57 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,12 @@
package org.apache.beam.runners.spark.translation;

import static org.junit.Assert.assertEquals;
import static org.mockito.Mockito.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

import java.util.Arrays;
import java.util.Iterator;
Expand All @@ -39,6 +45,9 @@
import org.apache.beam.sdk.values.KV;
import org.apache.beam.sdk.values.WindowingStrategy;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.primitives.Bytes;
import org.apache.spark.Partitioner;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.joda.time.Duration;
import org.joda.time.Instant;
import org.junit.Assert;
Expand Down Expand Up @@ -112,6 +121,54 @@ public void testGbkIteratorValuesCannotBeReiterated() throws Coder.NonDeterminis
}
}

@Test
@SuppressWarnings({"rawtypes", "unchecked"})
public void testGroupByKeyInGlobalWindowWithPartitioner() {
// mocking
Partitioner mockPartitioner = mock(Partitioner.class);
JavaRDD mockRdd = mock(JavaRDD.class);
Coder mockKeyCoder = mock(Coder.class);
Coder mockValueCoder = mock(Coder.class);
JavaPairRDD mockRawKeyValues = mock(JavaPairRDD.class);
JavaPairRDD mockGrouped = mock(JavaPairRDD.class);

when(mockRdd.mapToPair(any())).thenReturn(mockRawKeyValues);
when(mockRawKeyValues.groupByKey(any(Partitioner.class)))
.thenAnswer(
invocation -> {
Partitioner partitioner = invocation.getArgument(0);
assertEquals(partitioner, mockPartitioner);
return mockGrouped;
});
when(mockGrouped.map(any())).thenReturn(mock(JavaRDD.class));

GroupNonMergingWindowsFunctions.groupByKeyInGlobalWindow(
mockRdd, mockKeyCoder, mockValueCoder, mockPartitioner);

verify(mockRawKeyValues, never()).groupByKey();
verify(mockRawKeyValues, times(1)).groupByKey(any(Partitioner.class));
}

@Test
@SuppressWarnings({"rawtypes", "unchecked"})
public void testGroupByKeyInGlobalWindowWithoutPartitioner() {
// mocking
JavaRDD mockRdd = mock(JavaRDD.class);
Coder mockKeyCoder = mock(Coder.class);
Coder mockValueCoder = mock(Coder.class);
JavaPairRDD mockRawKeyValues = mock(JavaPairRDD.class);
JavaPairRDD mockGrouped = mock(JavaPairRDD.class);

when(mockRdd.mapToPair(any())).thenReturn(mockRawKeyValues);
when(mockRawKeyValues.groupByKey()).thenReturn(mockGrouped);

GroupNonMergingWindowsFunctions.groupByKeyInGlobalWindow(
mockRdd, mockKeyCoder, mockValueCoder, null);

verify(mockRawKeyValues, times(1)).groupByKey();
verify(mockRawKeyValues, never()).groupByKey(any(Partitioner.class));
}

private GroupByKeyIterator<String, Integer, GlobalWindow> createGbkIterator()
throws Coder.NonDeterministicException {
return createGbkIterator(
Expand Down

0 comments on commit 4c89506

Please sign in to comment.