Skip to content

Commit

Permalink
Make LambdaClientFactory common to sink and processor
Browse files Browse the repository at this point in the history
Signed-off-by: Srikanth Govindarajan <srigovs@amazon.com>
  • Loading branch information
srikanthjg committed Aug 6, 2024
1 parent 0cf675b commit 2fe00e8
Show file tree
Hide file tree
Showing 8 changed files with 125 additions and 240 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.dataprepper.plugins.lambda.sink;
package org.opensearch.dataprepper.plugins.lambda.common.client;

import org.opensearch.dataprepper.aws.api.AwsCredentialsOptions;
import org.opensearch.dataprepper.aws.api.AwsCredentialsSupplier;
Expand All @@ -16,20 +16,21 @@
public final class LambdaClientFactory {
private LambdaClientFactory() { }

static LambdaClient createLambdaClient(final LambdaSinkConfig lambdaSinkConfig,
final AwsCredentialsSupplier awsCredentialsSupplier) {
final AwsCredentialsOptions awsCredentialsOptions = convertToCredentialsOptions(lambdaSinkConfig.getAwsAuthenticationOptions());
public static LambdaClient createLambdaClient(final AwsAuthenticationOptions awsAuthenticationOptions,
final int maxConnectionRetries,
final AwsCredentialsSupplier awsCredentialsSupplier) {
final AwsCredentialsOptions awsCredentialsOptions = convertToCredentialsOptions(awsAuthenticationOptions);
final AwsCredentialsProvider awsCredentialsProvider = awsCredentialsSupplier.getProvider(awsCredentialsOptions);

return LambdaClient.builder()
.region(lambdaSinkConfig.getAwsAuthenticationOptions().getAwsRegion())
.region(awsAuthenticationOptions.getAwsRegion())
.credentialsProvider(awsCredentialsProvider)
.overrideConfiguration(createOverrideConfiguration(lambdaSinkConfig)).build();
.overrideConfiguration(createOverrideConfiguration(maxConnectionRetries)).build();

}

private static ClientOverrideConfiguration createOverrideConfiguration(final LambdaSinkConfig lambdaSinkConfig) {
final RetryPolicy retryPolicy = RetryPolicy.builder().numRetries(lambdaSinkConfig.getMaxConnectionRetries()).build();
private static ClientOverrideConfiguration createOverrideConfiguration(final int maxConnectionRetries) {
final RetryPolicy retryPolicy = RetryPolicy.builder().numRetries(maxConnectionRetries).build();
return ClientOverrideConfiguration.builder()
.retryPolicy(retryPolicy)
.build();
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import org.opensearch.dataprepper.plugins.lambda.common.accumlator.Buffer;
import org.opensearch.dataprepper.plugins.lambda.common.accumlator.BufferFactory;
import org.opensearch.dataprepper.plugins.lambda.common.accumlator.InMemoryBufferFactory;
import org.opensearch.dataprepper.plugins.lambda.common.client.LambdaClientFactory;
import org.opensearch.dataprepper.plugins.lambda.common.codec.LambdaJsonCodec;
import org.opensearch.dataprepper.plugins.lambda.common.config.BatchOptions;
import org.opensearch.dataprepper.plugins.lambda.common.util.ThresholdCheck;
Expand Down Expand Up @@ -114,14 +115,16 @@ public LambdaProcessor(final PluginMetrics pluginMetrics, final LambdaProcessorC
if (mode != null && mode.equalsIgnoreCase(LambdaProcessorConfig.SYNCHRONOUS_MODE)) {
invocationType = SYNC_INVOCATION_TYPE;
} else {
throw new RuntimeException("mode has to be synchronous or asynchronous");
throw new RuntimeException("Unsupported mode " + mode);
}

codec = new LambdaJsonCodec(batchKey);
bufferedEventHandles = new LinkedList<>();
events = new ArrayList();

lambdaClient = LambdaClientFactory.createLambdaClient(lambdaProcessorConfig, awsCredentialsSupplier);
lambdaClient = LambdaClientFactory.createLambdaClient(lambdaProcessorConfig.getAwsAuthenticationOptions(),
lambdaProcessorConfig.getMaxConnectionRetries()
, awsCredentialsSupplier);

this.bufferFactory = new InMemoryBufferFactory();
try {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import org.opensearch.dataprepper.model.sink.SinkContext;
import org.opensearch.dataprepper.plugins.lambda.common.accumlator.BufferFactory;
import org.opensearch.dataprepper.plugins.lambda.common.accumlator.InMemoryBufferFactory;
import org.opensearch.dataprepper.plugins.lambda.common.client.LambdaClientFactory;
import org.opensearch.dataprepper.plugins.lambda.sink.dlq.DlqPushHandler;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
Expand Down Expand Up @@ -47,7 +48,9 @@ public LambdaSink(final PluginSetting pluginSetting,
super(pluginSetting);
sinkInitialized = Boolean.FALSE;
OutputCodecContext outputCodecContext = OutputCodecContext.fromSinkContext(sinkContext);
LambdaClient lambdaClient = LambdaClientFactory.createLambdaClient(lambdaSinkConfig, awsCredentialsSupplier);
LambdaClient lambdaClient = LambdaClientFactory.createLambdaClient(lambdaSinkConfig.getAwsAuthenticationOptions(),
lambdaSinkConfig.getMaxConnectionRetries()
, awsCredentialsSupplier);
if(lambdaSinkConfig.getDlqPluginSetting() != null) {
this.dlqPushHandler = new DlqPushHandler(pluginFactory,
String.valueOf(lambdaSinkConfig.getDlqPluginSetting().get(BUCKET)),
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
package org.opensearch.dataprepper.plugins.lambda.common.client;

import static org.hamcrest.CoreMatchers.notNullValue;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.equalTo;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;
import org.mockito.ArgumentCaptor;
import static org.mockito.ArgumentMatchers.any;
import org.mockito.Mock;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.mockStatic;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import org.mockito.junit.jupiter.MockitoExtension;
import org.opensearch.dataprepper.aws.api.AwsCredentialsOptions;
import org.opensearch.dataprepper.aws.api.AwsCredentialsSupplier;
import org.opensearch.dataprepper.plugins.lambda.common.config.AwsAuthenticationOptions;
import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
import software.amazon.awssdk.core.client.config.ClientOverrideConfiguration;
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.services.lambda.LambdaClient;
import software.amazon.awssdk.services.lambda.LambdaClientBuilder;

import java.util.Map;
import java.util.UUID;

@ExtendWith(MockitoExtension.class)
class LambdaClientFactoryTest {
@Mock
private AwsCredentialsSupplier awsCredentialsSupplier;

@Mock
private AwsAuthenticationOptions awsAuthenticationOptions;

@Mock
private AwsCredentialsProvider awsCredentialsProvider;

@BeforeEach
void setUp() {
// No setup needed here as we're mocking static methods in tests
}

@Test
void createLambdaClient_with_real_LambdaClient() {
try (var mockedStaticLambdaClient = mockStatic(LambdaClient.class)) {
LambdaClientBuilder lambdaClientBuilder = mock(LambdaClientBuilder.class);
mockedStaticLambdaClient.when(LambdaClient::builder).thenReturn(lambdaClientBuilder);

when(lambdaClientBuilder.region(any(Region.class))).thenReturn(lambdaClientBuilder);
when(lambdaClientBuilder.credentialsProvider(any(AwsCredentialsProvider.class))).thenReturn(lambdaClientBuilder);
when(lambdaClientBuilder.overrideConfiguration(any(ClientOverrideConfiguration.class))).thenReturn(lambdaClientBuilder);
when(lambdaClientBuilder.build()).thenReturn(mock(LambdaClient.class));

when(awsAuthenticationOptions.getAwsRegion()).thenReturn(Region.US_EAST_1);
when(awsCredentialsSupplier.getProvider(any(AwsCredentialsOptions.class))).thenReturn(awsCredentialsProvider);

final LambdaClient lambdaClient = LambdaClientFactory.createLambdaClient(awsAuthenticationOptions, 3, awsCredentialsSupplier);

assertThat(lambdaClient, notNullValue());
}
}

@ParameterizedTest
@ValueSource(strings = {"us-east-1", "us-west-2", "eu-central-1"})
void createlambdaClient_provides_correct_inputs(final String regionString) {
try (var mockedStaticLambdaClient = mockStatic(LambdaClient.class)) {
LambdaClientBuilder lambdaClientBuilder = mock(LambdaClientBuilder.class);
mockedStaticLambdaClient.when(LambdaClient::builder).thenReturn(lambdaClientBuilder);

final Region region = Region.of(regionString);
final String stsRoleArn = UUID.randomUUID().toString();
final Map<String, String> stsHeaderOverrides = Map.of(UUID.randomUUID().toString(), UUID.randomUUID().toString());
when(awsAuthenticationOptions.getAwsRegion()).thenReturn(region);
when(awsAuthenticationOptions.getAwsStsRoleArn()).thenReturn(stsRoleArn);
when(awsAuthenticationOptions.getAwsStsHeaderOverrides()).thenReturn(stsHeaderOverrides);
when(awsCredentialsSupplier.getProvider(any())).thenReturn(awsCredentialsProvider);

when(lambdaClientBuilder.region(any(Region.class))).thenReturn(lambdaClientBuilder);
when(lambdaClientBuilder.credentialsProvider(any(AwsCredentialsProvider.class))).thenReturn(lambdaClientBuilder);
when(lambdaClientBuilder.overrideConfiguration(any(ClientOverrideConfiguration.class))).thenReturn(lambdaClientBuilder);
when(lambdaClientBuilder.build()).thenReturn(mock(LambdaClient.class));

final LambdaClient lambdaClient = LambdaClientFactory.createLambdaClient(awsAuthenticationOptions, 3, awsCredentialsSupplier);

final ArgumentCaptor<AwsCredentialsProvider> credentialsProviderArgumentCaptor = ArgumentCaptor.forClass(AwsCredentialsProvider.class);
verify(lambdaClientBuilder).credentialsProvider(credentialsProviderArgumentCaptor.capture());
final AwsCredentialsProvider actualCredentialsProvider = credentialsProviderArgumentCaptor.getValue();
assertThat(actualCredentialsProvider, equalTo(awsCredentialsProvider));

final ArgumentCaptor<AwsCredentialsOptions> optionsArgumentCaptor = ArgumentCaptor.forClass(AwsCredentialsOptions.class);
verify(awsCredentialsSupplier).getProvider(optionsArgumentCaptor.capture());

final AwsCredentialsOptions actualCredentialsOptions = optionsArgumentCaptor.getValue();
assertThat(actualCredentialsOptions.getRegion(), equalTo(region));
assertThat(actualCredentialsOptions.getStsRoleArn(), equalTo(stsRoleArn));
assertThat(actualCredentialsOptions.getStsHeaderOverrides(), equalTo(stsHeaderOverrides));
}
}
}

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
import org.opensearch.dataprepper.model.types.ByteCount;
import org.opensearch.dataprepper.plugins.lambda.common.accumlator.Buffer;
import org.opensearch.dataprepper.plugins.lambda.common.accumlator.BufferFactory;
import org.opensearch.dataprepper.plugins.lambda.common.client.LambdaClientFactory;
import org.opensearch.dataprepper.plugins.lambda.common.config.AwsAuthenticationOptions;
import org.opensearch.dataprepper.plugins.lambda.common.config.BatchOptions;
import org.opensearch.dataprepper.plugins.lambda.common.config.ThresholdOptions;
Expand All @@ -60,8 +61,6 @@

@ExtendWith(MockitoExtension.class)
public class LambdaProcessorTest {
private static final String PROCESSOR_PLUGIN_NAME = "aws_lambda";
private static final String PROCESSOR_PIPELINE_NAME = "lambda-processor-pipeline";
private static final String RESPONSE_PAYLOAD = "{\"k1\":\"v1\",\"k2\":\"v2\"}";
private static MockedStatic<LambdaClientFactory> lambdaClientFactoryMockedStatic;
private final ObjectMapper objectMapper = new ObjectMapper(new YAMLFactory().enable(YAMLGenerator.Feature.USE_PLATFORM_LINE_BREAKS));
Expand Down Expand Up @@ -143,7 +142,9 @@ public void setUp() throws IOException {

InvokeResponse resp = InvokeResponse.builder().statusCode(200).payload(SdkBytes.fromUtf8String(RESPONSE_PAYLOAD)).build();
lambdaClientFactoryMockedStatic = Mockito.mockStatic(LambdaClientFactory.class);
when(LambdaClientFactory.createLambdaClient(any(LambdaProcessorConfig.class), any(AwsCredentialsSupplier.class))).thenReturn(lambdaClient);
when(LambdaClientFactory.createLambdaClient(any(AwsAuthenticationOptions.class),
eq(3),
any(AwsCredentialsSupplier.class))).thenReturn(lambdaClient);
lenient().when(lambdaClient.invoke(any(InvokeRequest.class))).thenReturn(resp);
try {
lenient().when(bufferFactory.getBuffer(lambdaClient, lambdaProcessorConfig.getFunctionName(), "RequestResponse")).thenReturn(buffer);
Expand Down
Loading

0 comments on commit 2fe00e8

Please sign in to comment.