[Segment Replication] Add check to cancel ongoing replication with old primary on onNewCheckpoint on replica #4363

Merged

1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ Inspired from [Keep a Changelog](https://keepachangelog.com/en/1.0.0/)
- Fixed cancellation of segment replication events ([#4225](https://github.com/opensearch-project/OpenSearch/pull/4225))
- Bugs for dependabot changelog verifier workflow ([#4364](https://github.com/opensearch-project/OpenSearch/pull/4364))
- Fix flaky random test `NRTReplicationEngineTests.testUpdateSegments` ([#4352](https://github.com/opensearch-project/OpenSearch/pull/4352))
- [Segment Replication] Add check to cancel ongoing replication with old primary on onNewCheckpoint on replica ([#4363](https://github.com/opensearch-project/OpenSearch/pull/4363))

### Security
- CVE-2022-25857 org.yaml:snakeyaml DOS vulnerability ([#4341](https://github.com/opensearch-project/OpenSearch/pull/4341))

SegmentReplicationTarget.java

@@ -56,6 +56,10 @@ public class SegmentReplicationTarget extends ReplicationTarget {
    private final SegmentReplicationState state;
    protected final MultiFileWriter multiFileWriter;

+    public ReplicationCheckpoint getCheckpoint() {
+        return this.checkpoint;
+    }

    public SegmentReplicationTarget(
        ReplicationCheckpoint checkpoint,
        IndexShard indexShard,
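Note: the new accessor exposes the checkpoint an in-flight target was constructed with, so the target service can compare primary terms when another checkpoint arrives. A minimal sketch of that comparison with a hypothetical helper name (the merged change performs this check inline in onNewCheckpoint, shown in the next file):

```java
// Illustrative only; not part of this PR. A received checkpoint carrying a higher primary
// term means a new primary was elected after the ongoing target started replicating.
static boolean startedUnderOlderPrimary(SegmentReplicationTarget ongoing, ReplicationCheckpoint received) {
    return ongoing.getCheckpoint().getPrimaryTerm() < received.getPrimaryTerm();
}
```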

SegmentReplicationTargetService.java

@@ -18,6 +18,7 @@
import org.opensearch.common.Nullable;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.util.CancellableThreads;
+import org.opensearch.common.util.concurrent.ConcurrentCollections;
import org.opensearch.index.shard.IndexEventListener;
import org.opensearch.index.shard.IndexShard;
import org.opensearch.index.shard.ShardId;
@@ -34,7 +35,6 @@
import org.opensearch.transport.TransportRequestHandler;
import org.opensearch.transport.TransportService;

-import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.atomic.AtomicLong;

@@ -54,7 +54,7 @@ public class SegmentReplicationTargetService implements IndexEventListener {

    private final SegmentReplicationSourceFactory sourceFactory;

-    private final Map<ShardId, ReplicationCheckpoint> latestReceivedCheckpoint = new HashMap<>();
+    private final Map<ShardId, ReplicationCheckpoint> latestReceivedCheckpoint = ConcurrentCollections.newConcurrentMap();

    // Empty Implementation, only required while Segment Replication is under feature flag.
    public static final SegmentReplicationTargetService NO_OP = new SegmentReplicationTargetService() {
@@ -151,14 +151,23 @@ public synchronized void onNewCheckpoint(final ReplicationCheckpoint receivedCheckpoint, final IndexShard replicaShard) {
        } else {
            latestReceivedCheckpoint.put(replicaShard.shardId(), receivedCheckpoint);
        }
-        if (onGoingReplications.isShardReplicating(replicaShard.shardId())) {
-            logger.trace(
-                () -> new ParameterizedMessage(
-                    "Ignoring new replication checkpoint - shard is currently replicating to checkpoint {}",
-                    replicaShard.getLatestReplicationCheckpoint()
-                )
-            );
-            return;
+        SegmentReplicationTarget ongoingReplicationTarget = onGoingReplications.getOngoingReplicationTarget(replicaShard.shardId());
+        if (ongoingReplicationTarget != null) {
+            if (ongoingReplicationTarget.getCheckpoint().getPrimaryTerm() < receivedCheckpoint.getPrimaryTerm()) {
+                logger.trace(
+                    "Cancelling ongoing replication from old primary with primary term {}",
+                    ongoingReplicationTarget.getCheckpoint().getPrimaryTerm()
+                );
+                onGoingReplications.cancel(ongoingReplicationTarget.getId(), "Cancelling stuck target after new primary");
+            } else {
+                logger.trace(
+                    () -> new ParameterizedMessage(
+                        "Ignoring new replication checkpoint - shard is currently replicating to checkpoint {}",
+                        replicaShard.getLatestReplicationCheckpoint()
+                    )
+                );
+                return;
+            }
        }
        final Thread thread = Thread.currentThread();
        if (replicaShard.shouldProcessCheckpoint(receivedCheckpoint)) {
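Note: when the received checkpoint carries a higher primary term than the checkpoint the in-flight target was created for, the stale target is cancelled and, since that branch does not return, execution falls through to the existing shouldProcessCheckpoint/startReplication path. A rough sketch of the resulting behavior; the method and variable names below are invented and object wiring is assumed:

```java
// Sketch of the failover scenario this change handles; construction of the service, shard and
// checkpoints is omitted. checkpointFromOldPrimary has primary term N, checkpointFromNewPrimary N + 1.
void sketchFailoverHandling(
    SegmentReplicationTargetService targetService,
    IndexShard replicaShard,
    ReplicationCheckpoint checkpointFromOldPrimary,
    ReplicationCheckpoint checkpointFromNewPrimary
) {
    // Assuming the replica is behind, this starts a SegmentReplicationTarget for the term-N checkpoint.
    targetService.onNewCheckpoint(checkpointFromOldPrimary, replicaShard);

    // Before this PR: ignored while a target was in flight. After this PR: the term-N target is
    // cancelled ("Cancelling stuck target after new primary") and a new replication is started
    // against the term-(N + 1) checkpoint, as the new unit test further down verifies.
    targetService.onNewCheckpoint(checkpointFromNewPrimary, replicaShard);
}
```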

ReplicationCollection.java

@@ -49,6 +49,7 @@
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.ConcurrentMap;
+import java.util.stream.Collectors;

/**
* This class holds a collection of all on going replication events on the current node (i.e., the node is the target node
@@ -236,13 +237,18 @@ public boolean cancelForShard(ShardId shardId, String reason) {
    }

    /**
-     * check if a shard is currently replicating
+     * Get target for shard
     *
-     * @param shardId shardId for which to check if replicating
-     * @return true if shard is currently replicating
+     * @param shardId shardId
+     * @return ReplicationTarget for input shardId
     */
-    public boolean isShardReplicating(ShardId shardId) {
-        return onGoingTargetEvents.values().stream().anyMatch(t -> t.indexShard.shardId().equals(shardId));
+    public T getOngoingReplicationTarget(ShardId shardId) {
+        final List<T> replicationTargetList = onGoingTargetEvents.values()
+            .stream()
+            .filter(t -> t.indexShard.shardId().equals(shardId))
+            .collect(Collectors.toList());
+        assert replicationTargetList.size() <= 1 : "More than one on-going replication targets";
+        return replicationTargetList.size() > 0 ? replicationTargetList.get(0) : null;
    }
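Note: the new method returns null when nothing is replicating for the shard, and the at-most-one invariant is enforced only via assert, so it trips only when JVM assertions are enabled (as they are in tests). A sketch of the expected caller pattern; the variable names and the reason string are illustrative:

```java
// collection is assumed to be a ReplicationCollection<SegmentReplicationTarget>, shardId a ShardId.
SegmentReplicationTarget ongoing = collection.getOngoingReplicationTarget(shardId);
if (ongoing != null) {
    // Callers must handle the null case themselves; production code cannot rely on the assert above.
    collection.cancel(ongoing.getId(), "cancelling replication started under an older primary");
}
```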

SegmentReplicationTargetServiceTests.java

@@ -49,6 +49,8 @@ public class SegmentReplicationTargetServiceTests extends IndexShardTestCase {
    private ReplicationCheckpoint initialCheckpoint;
    private ReplicationCheckpoint aheadCheckpoint;

+    private ReplicationCheckpoint newPrimaryCheckpoint;

    @Override
    public void setUp() throws Exception {
        super.setUp();
@@ -74,6 +76,13 @@ public void setUp() throws Exception {
            initialCheckpoint.getSeqNo(),
            initialCheckpoint.getSegmentInfosVersion() + 1
        );
+        newPrimaryCheckpoint = new ReplicationCheckpoint(
+            initialCheckpoint.getShardId(),
+            initialCheckpoint.getPrimaryTerm() + 1,
+            initialCheckpoint.getSegmentsGen(),
+            initialCheckpoint.getSeqNo(),
+            initialCheckpoint.getSegmentInfosVersion() + 1
+        );
    }

    @Override
@@ -160,7 +169,7 @@ public void testShardAlreadyReplicating() throws InterruptedException {
        // Create a spy of Target Service so that we can verify invocation of startReplication call with specific checkpoint on it.
        SegmentReplicationTargetService serviceSpy = spy(sut);
        final SegmentReplicationTarget target = new SegmentReplicationTarget(
-            checkpoint,
+            initialCheckpoint,
            replicaShard,
            replicationSource,
            mock(SegmentReplicationTargetService.SegmentReplicationListener.class)
@@ -185,9 +194,47 @@

        // wait for the new checkpoint to arrive, before the listener completes.
        latch.await(30, TimeUnit.SECONDS);
+        verify(targetSpy, times(0)).cancel(any());
        verify(serviceSpy, times(0)).startReplication(eq(aheadCheckpoint), eq(replicaShard), any());
    }

+    public void testOnNewCheckpointFromNewPrimaryCancelOngoingReplication() throws IOException, InterruptedException {
+        // Create a spy of Target Service so that we can verify invocation of startReplication call with specific checkpoint on it.
+        SegmentReplicationTargetService serviceSpy = spy(sut);
+        // Create a Mockito spy of target to stub response of few method calls.
+        final SegmentReplicationTarget targetSpy = spy(
+            new SegmentReplicationTarget(
+                initialCheckpoint,
+                replicaShard,
+                replicationSource,
+                mock(SegmentReplicationTargetService.SegmentReplicationListener.class)
+            )
+        );

+        CountDownLatch latch = new CountDownLatch(1);
+        // Mocking response when startReplication is called on targetSpy we send a new checkpoint to serviceSpy and later reduce countdown
+        // of latch.
+        doAnswer(invocation -> {
+            final ActionListener<Void> listener = invocation.getArgument(0);
+            // a new checkpoint arrives before we've completed.
+            serviceSpy.onNewCheckpoint(newPrimaryCheckpoint, replicaShard);
+            listener.onResponse(null);
+            latch.countDown();
+            return null;
+        }).when(targetSpy).startReplication(any());
+        doNothing().when(targetSpy).onDone();

+        // start replication. This adds the target to on-ongoing replication collection
+        serviceSpy.startReplication(targetSpy);

+        // wait for the new checkpoint to arrive, before the listener completes.
+        latch.await(5, TimeUnit.SECONDS);
+        doNothing().when(targetSpy).startReplication(any());
+        verify(targetSpy, times(1)).cancel("Cancelling stuck target after new primary");
+        verify(serviceSpy, times(1)).startReplication(eq(newPrimaryCheckpoint), eq(replicaShard), any());
+        closeShards(replicaShard);
+    }

    public void testNewCheckpointBehindCurrentCheckpoint() {
        SegmentReplicationTargetService spy = spy(sut);
        spy.onNewCheckpoint(checkpoint, replicaShard);

ReplicationCollectionTests.java

@@ -105,7 +105,25 @@ public void onFailure(ReplicationState state, OpenSearchException e, boolean sendShardFailure) {
                collection.cancel(recoveryId, "meh");
            }
        }
    }

+    public void testMultiReplicationsForSingleShard() throws Exception {
+        try (ReplicationGroup shards = createGroup(0)) {
+            final ReplicationCollection<RecoveryTarget> collection = new ReplicationCollection<>(logger, threadPool);
+            final IndexShard shard1 = shards.addReplica();
+            final IndexShard shard2 = shards.addReplica();
+            final long recoveryId = startRecovery(collection, shards.getPrimaryNode(), shard1);
+            final long recoveryId2 = startRecovery(collection, shards.getPrimaryNode(), shard2);
+            try {
+                collection.getOngoingReplicationTarget(shard1.shardId());
+            } catch (AssertionError e) {
+                assertEquals(e.getMessage(), "More than one on-going replication targets");
+            } finally {
+                collection.cancel(recoveryId, "meh");
+                collection.cancel(recoveryId2, "meh");
+            }
+            closeShards(shard1, shard2);
+        }
+    }

    public void testRecoveryCancellation() throws Exception {