Skip to content

Commit

Permalink
[Segment Replication] Add check to cancel ongoing replication with ol…
Browse files Browse the repository at this point in the history
…d primary on onNewCheckpoint on replica

Signed-off-by: Suraj Singh <surajrider@gmail.com>
  • Loading branch information
dreamer-89 committed Aug 31, 2022
1 parent 19d1a2b commit 970e4de
Show file tree
Hide file tree
Showing 4 changed files with 79 additions and 16 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,10 @@ public class SegmentReplicationTarget extends ReplicationTarget {
private final SegmentReplicationState state;
protected final MultiFileWriter multiFileWriter;

public ReplicationCheckpoint getCheckpoint() {
return this.checkpoint;
}

public SegmentReplicationTarget(
ReplicationCheckpoint checkpoint,
IndexShard indexShard,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import org.opensearch.common.Nullable;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.util.CancellableThreads;
import org.opensearch.common.util.concurrent.ConcurrentCollections;
import org.opensearch.index.shard.IndexEventListener;
import org.opensearch.index.shard.IndexShard;
import org.opensearch.index.shard.ShardId;
Expand All @@ -34,8 +35,8 @@
import org.opensearch.transport.TransportRequestHandler;
import org.opensearch.transport.TransportService;

import java.util.HashMap;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.atomic.AtomicLong;

/**
Expand All @@ -54,7 +55,8 @@ public class SegmentReplicationTargetService implements IndexEventListener {

private final SegmentReplicationSourceFactory sourceFactory;

private final Map<ShardId, ReplicationCheckpoint> latestReceivedCheckpoint = new HashMap<>();
private final Map<ShardId, ReplicationCheckpoint> latestReceivedCheckpoint = ConcurrentCollections.newConcurrentMap();


// Empty Implementation, only required while Segment Replication is under feature flag.
public static final SegmentReplicationTargetService NO_OP = new SegmentReplicationTargetService() {
Expand Down Expand Up @@ -151,15 +153,23 @@ public synchronized void onNewCheckpoint(final ReplicationCheckpoint receivedChe
} else {
latestReceivedCheckpoint.put(replicaShard.shardId(), receivedCheckpoint);
}
if (onGoingReplications.isShardReplicating(replicaShard.shardId())) {
logger.trace(
() -> new ParameterizedMessage(
"Ignoring new replication checkpoint - shard is currently replicating to checkpoint {}",
replicaShard.getLatestReplicationCheckpoint()
)
);
return;
Optional<SegmentReplicationTarget> ongoingReplicationTarget = onGoingReplications.getOngoingReplicationTarget(replicaShard.shardId());
if (ongoingReplicationTarget.isPresent()) {
final SegmentReplicationTarget target = ongoingReplicationTarget.get();
if (target.getCheckpoint().getPrimaryTerm() < receivedCheckpoint.getPrimaryTerm()) {
logger.info("Cancelling ongoing replication from old primary with primary term {}", target.getCheckpoint().getPrimaryTerm());
target.cancel("Stuck target after new primary");
} else {
logger.trace(
() -> new ParameterizedMessage(
"Ignoring new replication checkpoint - shard is currently replicating to checkpoint {}",
replicaShard.getLatestReplicationCheckpoint()
)
);
return;
}
}

final Thread thread = Thread.currentThread();
if (replicaShard.shouldProcessCheckpoint(receivedCheckpoint)) {
startReplication(receivedCheckpoint, replicaShard, new SegmentReplicationListener() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Optional;
import java.util.concurrent.ConcurrentMap;

/**
Expand Down Expand Up @@ -236,13 +237,13 @@ public boolean cancelForShard(ShardId shardId, String reason) {
}

/**
* check if a shard is currently replicating
* Get target for shard
*
* @param shardId shardId for which to check if replicating
* @return true if shard is currently replicating
* @param shardId shardId
* @return Optional ReplicationTarget for input shardId
*/
public boolean isShardReplicating(ShardId shardId) {
return onGoingTargetEvents.values().stream().anyMatch(t -> t.indexShard.shardId().equals(shardId));
public Optional<T> getOngoingReplicationTarget(ShardId shardId) {
return onGoingTargetEvents.values().stream().filter(t -> t.indexShard.shardId().equals(shardId)).findFirst();
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@ public class SegmentReplicationTargetServiceTests extends IndexShardTestCase {
private ReplicationCheckpoint initialCheckpoint;
private ReplicationCheckpoint aheadCheckpoint;

private ReplicationCheckpoint newPrimaryCheckpoint;

@Override
public void setUp() throws Exception {
super.setUp();
Expand All @@ -74,6 +76,14 @@ public void setUp() throws Exception {
initialCheckpoint.getSeqNo(),
initialCheckpoint.getSegmentInfosVersion() + 1
);

newPrimaryCheckpoint = new ReplicationCheckpoint(
initialCheckpoint.getShardId(),
initialCheckpoint.getPrimaryTerm() + 1,
initialCheckpoint.getSegmentsGen(),
initialCheckpoint.getSeqNo(),
initialCheckpoint.getSegmentInfosVersion() + 1
);
}

@Override
Expand Down Expand Up @@ -160,7 +170,7 @@ public void testShardAlreadyReplicating() throws InterruptedException {
// Create a spy of Target Service so that we can verify invocation of startReplication call with specific checkpoint on it.
SegmentReplicationTargetService serviceSpy = spy(sut);
final SegmentReplicationTarget target = new SegmentReplicationTarget(
checkpoint,
initialCheckpoint,
replicaShard,
replicationSource,
mock(SegmentReplicationTargetService.SegmentReplicationListener.class)
Expand All @@ -185,9 +195,47 @@ public void testShardAlreadyReplicating() throws InterruptedException {

// wait for the new checkpoint to arrive, before the listener completes.
latch.await(30, TimeUnit.SECONDS);
verify(targetSpy, times(0)).cancel(any());
verify(serviceSpy, times(0)).startReplication(eq(aheadCheckpoint), eq(replicaShard), any());
}

public void testOnNewCheckpointFromNewPrimaryCancelOngoingReplication() throws IOException, InterruptedException {
// Create a spy of Target Service so that we can verify invocation of startReplication call with specific checkpoint on it.
SegmentReplicationTargetService serviceSpy = spy(sut);
// Create a Mockito spy of target to stub response of few method calls.
final SegmentReplicationTarget targetSpy = spy(
new SegmentReplicationTarget(
initialCheckpoint,
replicaShard,
replicationSource,
mock(SegmentReplicationTargetService.SegmentReplicationListener.class)
)
);

CountDownLatch latch = new CountDownLatch(1);
// Mocking response when startReplication is called on targetSpy we send a new checkpoint to serviceSpy and later reduce countdown
// of latch.
doAnswer(invocation -> {
final ActionListener<Void> listener = invocation.getArgument(0);
// a new checkpoint arrives before we've completed.
serviceSpy.onNewCheckpoint(newPrimaryCheckpoint, replicaShard);
listener.onResponse(null);
latch.countDown();
return null;
}).when(targetSpy).startReplication(any());
doNothing().when(targetSpy).onDone();

// start replication. This adds the target to on-ongoing replication collection
serviceSpy.startReplication(targetSpy);

// wait for the new checkpoint to arrive, before the listener completes.
latch.await(5, TimeUnit.SECONDS);
doNothing().when(targetSpy).startReplication(any());
verify(targetSpy, times(1)).cancel(any());
verify(serviceSpy, times(1)).startReplication(eq(newPrimaryCheckpoint), eq(replicaShard), any());
closeShards(replicaShard);
}

public void testNewCheckpointBehindCurrentCheckpoint() {
SegmentReplicationTargetService spy = spy(sut);
spy.onNewCheckpoint(checkpoint, replicaShard);
Expand Down

0 comments on commit 970e4de

Please sign in to comment.