diff --git a/data-prepper-plugins/aws-lambda/README.md b/data-prepper-plugins/aws-lambda/README.md index 4c49873350..d995bc5202 100644 --- a/data-prepper-plugins/aws-lambda/README.md +++ b/data-prepper-plugins/aws-lambda/README.md @@ -1,4 +1,39 @@ +# Lambda Processor + +This plugin enables you to send data from your Data Prepper pipeline directly to AWS Lambda functions for further processing. + +## Usage +```aidl +lambda-pipeline: +... + processor: + - aws_lambda: + aws: + region: "us-east-1" + sts_role_arn: "" + function_name: "uploadToS3Lambda" + max_retries: 3 + mode: "synchronous" + batch: + batch_key: "osi_key" + threshold: + event_count: 3 + maximum_size: 6mb + event_collect_timeout: 15s +``` + +## Developer Guide + +The integration tests for this plugin do not run as part of the Data Prepper build. +The following command runs the integration tests: + +``` +./gradlew :data-prepper-plugins:aws-lambda:integrationTest -Dtests.processor.lambda.region="us-east-1" -Dtests.processor.lambda.functionName="lambda_test_function" -Dtests.processor.lambda.sts_role_arn="arn:aws:iam::123456789012:role/dataprepper-role + +``` + + # Lambda Sink This plugin enables you to send data from your Data Prepper pipeline directly to AWS Lambda functions for further processing. diff --git a/data-prepper-plugins/aws-lambda/build.gradle b/data-prepper-plugins/aws-lambda/build.gradle index be9280e8c8..d59f4fd066 100644 --- a/data-prepper-plugins/aws-lambda/build.gradle +++ b/data-prepper-plugins/aws-lambda/build.gradle @@ -65,6 +65,10 @@ task integrationTest(type: Test) { systemProperty 'tests.lambda.sink.functionName', System.getProperty('tests.lambda.sink.functionName') systemProperty 'tests.lambda.sink.sts_role_arn', System.getProperty('tests.lambda.sink.sts_role_arn') + systemProperty 'tests.lambda.processor.region', System.getProperty('tests.lambda.processor.region') + systemProperty 'tests.lambda.processor.functionName', System.getProperty('tests.lambda.processor.functionName') + systemProperty 'tests.lambda.processor.sts_role_arn', System.getProperty('tests.lambda.processor.sts_role_arn') + filter { includeTestsMatching '*IT' } diff --git a/data-prepper-plugins/aws-lambda/src/integrationTest/java/org/opensearch/dataprepper/plugins/lambda/processor/LambdaProcessorServiceIT.java b/data-prepper-plugins/aws-lambda/src/integrationTest/java/org/opensearch/dataprepper/plugins/lambda/processor/LambdaProcessorServiceIT.java new file mode 100644 index 0000000000..686312cfe7 --- /dev/null +++ b/data-prepper-plugins/aws-lambda/src/integrationTest/java/org/opensearch/dataprepper/plugins/lambda/processor/LambdaProcessorServiceIT.java @@ -0,0 +1,165 @@ +package org.opensearch.dataprepper.plugins.lambda.processor; + +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.dataformat.yaml.YAMLFactory; +import com.fasterxml.jackson.dataformat.yaml.YAMLGenerator; +import io.micrometer.core.instrument.Counter; +import static org.junit.jupiter.api.Assertions.assertEquals; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.extension.ExtendWith; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; +import org.mockito.Mock; +import static org.mockito.Mockito.when; +import org.mockito.MockitoAnnotations; +import org.mockito.junit.jupiter.MockitoExtension; +import org.opensearch.dataprepper.aws.api.AwsCredentialsSupplier; +import org.opensearch.dataprepper.expression.ExpressionEvaluator; +import org.opensearch.dataprepper.metrics.PluginMetrics; +import org.opensearch.dataprepper.model.configuration.PluginSetting; +import org.opensearch.dataprepper.model.event.Event; +import org.opensearch.dataprepper.model.event.JacksonEvent; +import org.opensearch.dataprepper.model.log.JacksonLog; +import org.opensearch.dataprepper.model.plugin.PluginFactory; +import org.opensearch.dataprepper.model.record.Record; +import org.opensearch.dataprepper.model.types.ByteCount; +import org.opensearch.dataprepper.plugins.lambda.common.accumlator.BufferFactory; +import org.opensearch.dataprepper.plugins.lambda.common.accumlator.InMemoryBufferFactory; +import org.opensearch.dataprepper.plugins.lambda.common.config.AwsAuthenticationOptions; +import org.opensearch.dataprepper.plugins.lambda.common.config.BatchOptions; +import org.opensearch.dataprepper.plugins.lambda.common.config.ThresholdOptions; +import software.amazon.awssdk.regions.Region; +import software.amazon.awssdk.services.lambda.LambdaClient; + +import java.time.Duration; +import java.util.ArrayList; +import java.util.Collection; +import java.util.HashMap; +import java.util.List; + +@ExtendWith(MockitoExtension.class) + +public class LambdaProcessorServiceIT { + + private LambdaClient lambdaClient; + private String functionName; + private String lambdaRegion; + private String role; + private BufferFactory bufferFactory; + @Mock + private LambdaProcessorConfig lambdaProcessorConfig; + @Mock + private BatchOptions batchOptions; + @Mock + private ThresholdOptions thresholdOptions; + @Mock + private AwsAuthenticationOptions awsAuthenticationOptions; + @Mock + private AwsCredentialsSupplier awsCredentialsSupplier; + @Mock + private PluginMetrics pluginMetrics; + @Mock + private PluginFactory pluginFactory; + @Mock + private PluginSetting pluginSetting; + @Mock + private Counter numberOfRecordsSuccessCounter; + @Mock + private Counter numberOfRecordsFailedCounter; + @Mock + private ExpressionEvaluator expressionEvaluator; + + private final ObjectMapper objectMapper = new ObjectMapper(new YAMLFactory().enable(YAMLGenerator.Feature.USE_PLATFORM_LINE_BREAKS)); + + + @BeforeEach + public void setUp() throws Exception { + MockitoAnnotations.openMocks(this); + lambdaRegion = System.getProperty("tests.lambda.processor.region"); + functionName = System.getProperty("tests.lambda.processor.functionName"); + role = System.getProperty("tests.lambda.processor.sts_role_arn"); + + final Region region = Region.of(lambdaRegion); + + lambdaClient = LambdaClient.builder() + .region(Region.of(lambdaRegion)) + .build(); + + bufferFactory = new InMemoryBufferFactory(); + + when(pluginMetrics.counter(LambdaProcessor.NUMBER_OF_RECORDS_FLUSHED_TO_LAMBDA_SUCCESS)). + thenReturn(numberOfRecordsSuccessCounter); + when(pluginMetrics.counter(LambdaProcessor.NUMBER_OF_RECORDS_FLUSHED_TO_LAMBDA_FAILED)). + thenReturn(numberOfRecordsFailedCounter); + } + + + private static Record createRecord() { + final JacksonEvent event = JacksonLog.builder().withData("[{\"name\":\"test\"}]").build(); + return new Record<>(event); + } + + public LambdaProcessor createObjectUnderTest(final String config) throws JsonProcessingException { + + final LambdaProcessorConfig lambdaProcessorConfig = objectMapper.readValue(config, LambdaProcessorConfig.class); + return new LambdaProcessor(pluginMetrics,lambdaProcessorConfig,awsCredentialsSupplier,expressionEvaluator); + } + + public LambdaProcessor createObjectUnderTest(LambdaProcessorConfig lambdaSinkConfig) throws JsonProcessingException { + return new LambdaProcessor(pluginMetrics,lambdaSinkConfig,awsCredentialsSupplier,expressionEvaluator); + } + + + private static Collection> generateRecords(int numberOfRecords) { + List> recordList = new ArrayList<>(); + + for (int rows = 1; rows <= numberOfRecords; rows++) { + HashMap eventData = new HashMap<>(); + eventData.put("name", "Person" + rows); + eventData.put("age", Integer.toString(rows)); + + Record eventRecord = new Record<>(JacksonEvent.builder().withData(eventData).withEventType("event").build()); + recordList.add(eventRecord); + } + return recordList; + } + + @ParameterizedTest + @ValueSource(ints = {1,3}) + void verify_records_to_lambda_success(final int recordCount) throws Exception { + + when(lambdaProcessorConfig.getFunctionName()).thenReturn(functionName); + when(lambdaProcessorConfig.getMaxConnectionRetries()).thenReturn(3); + when(lambdaProcessorConfig.getMode()).thenReturn("synchronous"); + + LambdaProcessor objectUnderTest = createObjectUnderTest(lambdaProcessorConfig); + + Collection> recordsData = generateRecords(recordCount); + List> recordsResult = (List>) objectUnderTest.doExecute(recordsData); + Thread.sleep(Duration.ofSeconds(10).toMillis()); + + assertEquals(recordsResult.size(),recordCount); + } + + @ParameterizedTest + @ValueSource(ints = {1,3}) + void verify_records_with_batching_to_lambda(final int recordCount) throws JsonProcessingException, InterruptedException { + + when(lambdaProcessorConfig.getFunctionName()).thenReturn(functionName); + when(lambdaProcessorConfig.getMaxConnectionRetries()).thenReturn(3); + when(lambdaProcessorConfig.getMode()).thenReturn("synchronous"); + when(thresholdOptions.getEventCount()).thenReturn(1); + when(thresholdOptions.getMaximumSize()).thenReturn(ByteCount.parse("2mb")); + when(thresholdOptions.getEventCollectTimeOut()).thenReturn(Duration.parse("PT10s")); + when(batchOptions.getBatchKey()).thenReturn("lambda_batch_key"); + when(batchOptions.getThresholdOptions()).thenReturn(thresholdOptions); + when(lambdaProcessorConfig.getBatchOptions()).thenReturn(batchOptions); + + LambdaProcessor objectUnderTest = createObjectUnderTest(lambdaProcessorConfig); + Collection> records = generateRecords(recordCount); + Collection> recordsResult = objectUnderTest.doExecute(records); + Thread.sleep(Duration.ofSeconds(10).toMillis()); + assertEquals(recordsResult.size(),recordCount); + } +} \ No newline at end of file diff --git a/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/processor/LambdaClientFactory.java b/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/processor/LambdaClientFactory.java new file mode 100644 index 0000000000..3806939052 --- /dev/null +++ b/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/processor/LambdaClientFactory.java @@ -0,0 +1,36 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.dataprepper.plugins.lambda.processor; + +import org.opensearch.dataprepper.aws.api.AwsCredentialsOptions; +import org.opensearch.dataprepper.aws.api.AwsCredentialsSupplier; +import org.opensearch.dataprepper.plugins.lambda.common.config.AwsAuthenticationOptions; +import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; +import software.amazon.awssdk.core.client.config.ClientOverrideConfiguration; +import software.amazon.awssdk.core.retry.RetryPolicy; +import software.amazon.awssdk.services.lambda.LambdaClient; + +public final class LambdaClientFactory { + private LambdaClientFactory() { + } + + public static LambdaClient createLambdaClient(final LambdaProcessorConfig lambdaProcessorConfig, final AwsCredentialsSupplier awsCredentialsSupplier) { + final AwsCredentialsOptions awsCredentialsOptions = convertToCredentialsOptions(lambdaProcessorConfig.getAwsAuthenticationOptions()); + final AwsCredentialsProvider awsCredentialsProvider = awsCredentialsSupplier.getProvider(awsCredentialsOptions); + + return LambdaClient.builder().region(lambdaProcessorConfig.getAwsAuthenticationOptions().getAwsRegion()).credentialsProvider(awsCredentialsProvider).overrideConfiguration(createOverrideConfiguration(lambdaProcessorConfig)).build(); + + } + + private static ClientOverrideConfiguration createOverrideConfiguration(final LambdaProcessorConfig lambdaProcessorConfig) { + final RetryPolicy retryPolicy = RetryPolicy.builder().numRetries(lambdaProcessorConfig.getMaxConnectionRetries()).build(); + return ClientOverrideConfiguration.builder().retryPolicy(retryPolicy).build(); + } + + private static AwsCredentialsOptions convertToCredentialsOptions(final AwsAuthenticationOptions awsAuthenticationOptions) { + return AwsCredentialsOptions.builder().withRegion(awsAuthenticationOptions.getAwsRegion()).withStsRoleArn(awsAuthenticationOptions.getAwsStsRoleArn()).withStsExternalId(awsAuthenticationOptions.getAwsStsExternalId()).withStsHeaderOverrides(awsAuthenticationOptions.getAwsStsHeaderOverrides()).build(); + } +} diff --git a/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/processor/LambdaProcessor.java b/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/processor/LambdaProcessor.java new file mode 100644 index 0000000000..01be1a40bc --- /dev/null +++ b/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/processor/LambdaProcessor.java @@ -0,0 +1,271 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.dataprepper.plugins.lambda.processor; + +import com.fasterxml.jackson.core.JsonParseException; +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; +import io.micrometer.core.instrument.Counter; +import io.micrometer.core.instrument.Timer; +import org.opensearch.dataprepper.aws.api.AwsCredentialsSupplier; +import org.opensearch.dataprepper.expression.ExpressionEvaluator; +import static org.opensearch.dataprepper.logging.DataPrepperMarkers.EVENT; +import org.opensearch.dataprepper.metrics.PluginMetrics; +import org.opensearch.dataprepper.model.annotations.DataPrepperPlugin; +import org.opensearch.dataprepper.model.annotations.DataPrepperPluginConstructor; +import org.opensearch.dataprepper.model.codec.OutputCodec; +import org.opensearch.dataprepper.model.event.Event; +import org.opensearch.dataprepper.model.event.EventHandle; +import org.opensearch.dataprepper.model.event.JacksonEvent; +import org.opensearch.dataprepper.model.processor.AbstractProcessor; +import org.opensearch.dataprepper.model.processor.Processor; +import org.opensearch.dataprepper.model.record.Record; +import org.opensearch.dataprepper.model.types.ByteCount; +import org.opensearch.dataprepper.plugins.lambda.common.accumlator.Buffer; +import org.opensearch.dataprepper.plugins.lambda.common.accumlator.BufferFactory; +import org.opensearch.dataprepper.plugins.lambda.common.accumlator.InMemoryBufferFactory; +import org.opensearch.dataprepper.plugins.lambda.common.codec.LambdaJsonCodec; +import org.opensearch.dataprepper.plugins.lambda.common.config.BatchOptions; +import org.opensearch.dataprepper.plugins.lambda.common.util.ThresholdCheck; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import software.amazon.awssdk.awscore.exception.AwsServiceException; +import software.amazon.awssdk.core.SdkBytes; +import software.amazon.awssdk.core.exception.SdkClientException; +import software.amazon.awssdk.services.lambda.LambdaClient; +import software.amazon.awssdk.services.lambda.model.InvokeResponse; + +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.time.Duration; +import java.util.ArrayList; +import java.util.Collection; +import java.util.LinkedList; +import java.util.List; +import java.util.concurrent.atomic.AtomicLong; +import java.util.concurrent.atomic.AtomicReference; + +@DataPrepperPlugin(name = "aws_lambda", pluginType = Processor.class, pluginConfigurationType = LambdaProcessorConfig.class) +public class LambdaProcessor extends AbstractProcessor, Record> { + + public static final String NUMBER_OF_RECORDS_FLUSHED_TO_LAMBDA_SUCCESS = "lambdaProcessorObjectsEventsSucceeded"; + public static final String NUMBER_OF_RECORDS_FLUSHED_TO_LAMBDA_FAILED = "lambdaProcessorObjectsEventsFailed"; + public static final String LAMBDA_LATENCY_METRIC = "lambdaLatency"; + public static final String REQUEST_PAYLOAD_SIZE = "requestPayloadSize"; + public static final String RESPONSE_PAYLOAD_SIZE = "responsePayloadSize"; + private static final String SYNC_INVOCATION_TYPE = "RequestResponse"; + private static final String ASYNC_INVOCATION_TYPE = "Event"; + private static final Logger LOG = LoggerFactory.getLogger(LambdaProcessor.class); + + private final String functionName; + private final String whenCondition; + private final ExpressionEvaluator expressionEvaluator; + private final Counter numberOfRecordsSuccessCounter; + private final Counter numberOfRecordsFailedCounter; + private final Timer lambdaLatencyMetric; + private final String invocationType; + private final Collection bufferedEventHandles; + private final List events; + private final BatchOptions batchOptions; + private final ObjectMapper objectMapper = new ObjectMapper(); + private final BufferFactory bufferFactory; + private final LambdaClient lambdaClient; + private final Boolean isBatchEnabled; + private final String batchKey; + Buffer currentBuffer; + private final AtomicLong requestPayload; + private final AtomicLong responsePayload; + private int maxEvents = 0; + private ByteCount maxBytes = null; + private Duration maxCollectionDuration = null; + private int maxRetries = 0; + private String mode = null; + private OutputCodec codec = null; + + @DataPrepperPluginConstructor + public LambdaProcessor(final PluginMetrics pluginMetrics, final LambdaProcessorConfig lambdaProcessorConfig, final AwsCredentialsSupplier awsCredentialsSupplier, final ExpressionEvaluator expressionEvaluator) { + super(pluginMetrics); + this.expressionEvaluator = expressionEvaluator; + this.numberOfRecordsSuccessCounter = pluginMetrics.counter(NUMBER_OF_RECORDS_FLUSHED_TO_LAMBDA_SUCCESS); + this.numberOfRecordsFailedCounter = pluginMetrics.counter(NUMBER_OF_RECORDS_FLUSHED_TO_LAMBDA_FAILED); + this.lambdaLatencyMetric = pluginMetrics.timer(LAMBDA_LATENCY_METRIC); + this.requestPayload = pluginMetrics.gauge(REQUEST_PAYLOAD_SIZE, new AtomicLong()); + this.responsePayload = pluginMetrics.gauge(RESPONSE_PAYLOAD_SIZE, new AtomicLong()); + functionName = lambdaProcessorConfig.getFunctionName(); + whenCondition = lambdaProcessorConfig.getWhenCondition(); + maxRetries = lambdaProcessorConfig.getMaxConnectionRetries(); + batchOptions = lambdaProcessorConfig.getBatchOptions(); + if (batchOptions != null) { + maxEvents = batchOptions.getThresholdOptions().getEventCount(); + maxBytes = batchOptions.getThresholdOptions().getMaximumSize(); + maxCollectionDuration = batchOptions.getThresholdOptions().getEventCollectTimeOut(); + batchKey = batchOptions.getBatchKey(); + isBatchEnabled = true; + LOG.info("maxEvents:" + maxEvents + " maxbytes:" + maxBytes + " maxDuration:" + maxCollectionDuration); + } else { + batchKey = null; + isBatchEnabled = false; + } + mode = lambdaProcessorConfig.getMode(); + // TODO - Support for Async mode to be added. + if (mode != null && mode.equalsIgnoreCase(LambdaProcessorConfig.SYNCHRONOUS_MODE)) { + invocationType = SYNC_INVOCATION_TYPE; + } else { + throw new RuntimeException("mode has to be synchronous or asynchronous"); + } + + codec = new LambdaJsonCodec(batchKey); + bufferedEventHandles = new LinkedList<>(); + events = new ArrayList(); + + lambdaClient = LambdaClientFactory.createLambdaClient(lambdaProcessorConfig, awsCredentialsSupplier); + + this.bufferFactory = new InMemoryBufferFactory(); + try { + currentBuffer = this.bufferFactory.getBuffer(lambdaClient, functionName, invocationType); + } catch (IOException e) { + throw new RuntimeException(e); + } + + } + + @Override + public Collection> doExecute(Collection> records) { + if (records.isEmpty()) { + return records; + } + + //lambda mutates event + List> resultRecords = new ArrayList<>(); + + for (Record record : records) { + final Event event = record.getData(); + + if (whenCondition != null && !expressionEvaluator.evaluateConditional(whenCondition, event)) { + continue; + } + + try { + if (currentBuffer.getEventCount() == 0) { + codec.start(currentBuffer.getOutputStream(), event, null); + } + codec.writeEvent(event, currentBuffer.getOutputStream()); + int count = currentBuffer.getEventCount() + 1; + currentBuffer.setEventCount(count); + + // flush to lambda and update result record + flushToLambdaIfNeeded(resultRecords); + } catch (Exception e) { + numberOfRecordsFailedCounter.increment(currentBuffer.getEventCount()); + LOG.error(EVENT, "There was an exception while processing Event [{}]" + ", number of events dropped={}", event, e, numberOfRecordsFailedCounter); + //reset buffer + try { + currentBuffer = bufferFactory.getBuffer(lambdaClient, functionName, invocationType); + } catch (IOException ex) { + throw new RuntimeException(ex); + } + } + } + return resultRecords; + } + + @Override + public void prepareForShutdown() { + + } + + @Override + public boolean isReadyForShutdown() { + return false; + } + + @Override + public void shutdown() { + + } + + void flushToLambdaIfNeeded(List> resultRecords) throws InterruptedException, IOException { + + LOG.info("Flush to Lambda check: currentBuffer.size={}, currentBuffer.events={}, currentBuffer.duration={}", currentBuffer.getSize(), currentBuffer.getEventCount(), currentBuffer.getDuration()); + final AtomicReference errorMsgObj = new AtomicReference<>(); + + if (ThresholdCheck.checkThresholdExceed(currentBuffer, maxEvents, maxBytes, maxCollectionDuration, isBatchEnabled)) { + codec.complete(currentBuffer.getOutputStream()); + LOG.info("Writing {} to Lambda with {} events and size of {} bytes.", functionName, currentBuffer.getEventCount(), currentBuffer.getSize()); + LambdaResult lambdaResult = retryFlushToLambda(currentBuffer, errorMsgObj); + + if (lambdaResult.getIsUploadedToLambda()) { + LOG.info("Successfully flushed to Lambda {}.", functionName); + numberOfRecordsSuccessCounter.increment(currentBuffer.getEventCount()); + lambdaLatencyMetric.record(currentBuffer.getFlushLambdaSyncLatencyMetric()); + + requestPayload.set(currentBuffer.getPayloadRequestSyncSize()); + responsePayload.set(currentBuffer.getPayloadResponseSyncSize()); + + InvokeResponse lambdaResponse = lambdaResult.getLambdaResponse(); + Event lambdaEvent = convertLambdaResponseToEvent(lambdaResponse); + resultRecords.add(new Record<>(lambdaEvent)); + //reset buffer after flush + currentBuffer = bufferFactory.getBuffer(lambdaClient, functionName, invocationType); + } else { + LOG.error("Failed to save to Lambda {}", functionName); + numberOfRecordsFailedCounter.increment(currentBuffer.getEventCount()); + } + } + } + + LambdaResult retryFlushToLambda(Buffer currentBuffer, final AtomicReference errorMsgObj) throws InterruptedException { + boolean isUploadedToLambda = Boolean.FALSE; + int retryCount = maxRetries; + do { + + try { + InvokeResponse resp = currentBuffer.flushToLambdaSync(); + isUploadedToLambda = Boolean.TRUE; + LambdaResult lambdaResult = LambdaResult.builder().withIsUploadedToLambda(isUploadedToLambda).withLambdaResponse(resp).build(); + return lambdaResult; + } catch (AwsServiceException | SdkClientException e) { + errorMsgObj.set(e.getMessage()); + LOG.error("Exception occurred while uploading records to lambda. Retry countdown : {} | exception:", retryCount, e); + --retryCount; + if (retryCount == 0) { + LambdaResult lambdaResult = LambdaResult.builder().withIsUploadedToLambda(isUploadedToLambda).withLambdaResponse(null).build(); + return lambdaResult; + } + Thread.sleep(5000); + } + } while (!isUploadedToLambda); + + LambdaResult lambdaResult = LambdaResult.builder().withIsUploadedToLambda(false).withLambdaResponse(null).build(); + return lambdaResult; + } + + Event convertLambdaResponseToEvent(InvokeResponse lambdaResponse) { + try { + int statusCode = lambdaResponse.statusCode(); + if (statusCode < 200 || statusCode >= 300) { + throw new RuntimeException("Lambda invocation failed with status code: " + statusCode); + } + + SdkBytes payload = lambdaResponse.payload(); + if (payload != null) { + String payloadJsonString = payload.asString(StandardCharsets.UTF_8); + + JsonNode jsonNode = null; + try { + jsonNode = objectMapper.readTree(payloadJsonString); + } catch (JsonParseException e) { + throw new RuntimeException("payload output is not json formatted"); + } + return JacksonEvent.builder().withEventType("event").withData(jsonNode).build(); + } + } catch (Exception e) { + LOG.error("Error converting Lambda response to Event", e); + throw new RuntimeException("Error converting Lambda response to Event"); + } + return null; + } +} \ No newline at end of file diff --git a/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/processor/LambdaProcessorConfig.java b/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/processor/LambdaProcessorConfig.java new file mode 100644 index 0000000000..7a8abdf4a4 --- /dev/null +++ b/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/processor/LambdaProcessorConfig.java @@ -0,0 +1,63 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.dataprepper.plugins.lambda.processor; + +import com.fasterxml.jackson.annotation.JsonProperty; +import jakarta.validation.Valid; +import jakarta.validation.constraints.NotEmpty; +import jakarta.validation.constraints.NotNull; +import jakarta.validation.constraints.Size; +import org.opensearch.dataprepper.plugins.lambda.common.config.AwsAuthenticationOptions; +import org.opensearch.dataprepper.plugins.lambda.common.config.BatchOptions; + +public class LambdaProcessorConfig { + + public static final String SYNCHRONOUS_MODE = "RequestResponse"; + public static final String ASYNCHRONOUS_MODE = "Event"; + private static final int DEFAULT_CONNECTION_RETRIES = 3; + + @JsonProperty("aws") + @NotNull + @Valid + private AwsAuthenticationOptions awsAuthenticationOptions; + + @JsonProperty("function_name") + @NotEmpty + @NotNull + @Size(min = 3, max = 500, message = "function name length should be at least 3 characters") + private String functionName; + + @JsonProperty("max_retries") + private int maxConnectionRetries = DEFAULT_CONNECTION_RETRIES; + + @JsonProperty("mode") + private String mode = SYNCHRONOUS_MODE; + + @JsonProperty("batch") + private BatchOptions batchOptions; + + @JsonProperty("lambda_when") + private String whenCondition; + + public AwsAuthenticationOptions getAwsAuthenticationOptions() { + return awsAuthenticationOptions; + } + + public BatchOptions getBatchOptions(){return batchOptions;} + + public String getFunctionName() { + return functionName; + } + + public int getMaxConnectionRetries() { + return maxConnectionRetries; + } + + public String getMode(){return mode;} + + public String getWhenCondition() { + return whenCondition; + } +} \ No newline at end of file diff --git a/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/processor/LambdaResult.java b/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/processor/LambdaResult.java new file mode 100644 index 0000000000..d7f7b86168 --- /dev/null +++ b/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/processor/LambdaResult.java @@ -0,0 +1,18 @@ +package org.opensearch.dataprepper.plugins.lambda.processor; + +import lombok.AllArgsConstructor; +import lombok.Builder; +import lombok.Getter; +import lombok.NoArgsConstructor; +import software.amazon.awssdk.services.lambda.model.InvokeResponse; + +@Builder(setterPrefix = "with") +@Getter +@AllArgsConstructor +@NoArgsConstructor +public class LambdaResult { + + private InvokeResponse lambdaResponse; + + private Boolean isUploadedToLambda; +} diff --git a/data-prepper-plugins/aws-lambda/src/test/java/org/opensearch/dataprepper/plugins/lambda/processor/LambdaClientFactoryTest.java b/data-prepper-plugins/aws-lambda/src/test/java/org/opensearch/dataprepper/plugins/lambda/processor/LambdaClientFactoryTest.java new file mode 100644 index 0000000000..4af09e759d --- /dev/null +++ b/data-prepper-plugins/aws-lambda/src/test/java/org/opensearch/dataprepper/plugins/lambda/processor/LambdaClientFactoryTest.java @@ -0,0 +1,93 @@ +package org.opensearch.dataprepper.plugins.lambda.processor; + +import static org.hamcrest.CoreMatchers.notNullValue; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.equalTo; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; +import org.mockito.ArgumentCaptor; +import static org.mockito.ArgumentMatchers.any; +import org.mockito.Mock; +import org.mockito.MockedStatic; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.mockStatic; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; +import org.mockito.junit.jupiter.MockitoExtension; +import org.opensearch.dataprepper.aws.api.AwsCredentialsOptions; +import org.opensearch.dataprepper.aws.api.AwsCredentialsSupplier; +import org.opensearch.dataprepper.plugins.lambda.common.config.AwsAuthenticationOptions; +import static org.opensearch.dataprepper.plugins.lambda.processor.LambdaClientFactory.createLambdaClient; +import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; +import software.amazon.awssdk.core.client.config.ClientOverrideConfiguration; +import software.amazon.awssdk.regions.Region; +import software.amazon.awssdk.services.lambda.LambdaClient; +import software.amazon.awssdk.services.lambda.LambdaClientBuilder; + +import java.util.Map; +import java.util.UUID; + +@ExtendWith(MockitoExtension.class) +class LambdaClientFactoryTest { + @Mock + private LambdaProcessorConfig lambdaProcessorConfig; + @Mock + private AwsCredentialsSupplier awsCredentialsSupplier; + + @Mock + private AwsAuthenticationOptions awsAuthenticationOptions; + + @BeforeEach + void setUp() { + when(lambdaProcessorConfig.getAwsAuthenticationOptions()).thenReturn(awsAuthenticationOptions); + } + + @Test + void createLambdaClient_with_real_LambdaClient() { + when(awsAuthenticationOptions.getAwsRegion()).thenReturn(Region.US_EAST_1); + final LambdaClient lambdaClient = createLambdaClient(lambdaProcessorConfig, awsCredentialsSupplier); + + assertThat(lambdaClient, notNullValue()); + } + + @ParameterizedTest + @ValueSource(strings = {"us-east-1", "us-west-2", "eu-central-1"}) + void createlambdaClient_provides_correct_inputs(final String regionString) { + final Region region = Region.of(regionString); + final String stsRoleArn = UUID.randomUUID().toString(); + final Map stsHeaderOverrides = Map.of(UUID.randomUUID().toString(), UUID.randomUUID().toString()); + when(awsAuthenticationOptions.getAwsRegion()).thenReturn(region); + when(awsAuthenticationOptions.getAwsStsRoleArn()).thenReturn(stsRoleArn); + when(awsAuthenticationOptions.getAwsStsHeaderOverrides()).thenReturn(stsHeaderOverrides); + + final AwsCredentialsProvider expectedCredentialsProvider = mock(AwsCredentialsProvider.class); + when(awsCredentialsSupplier.getProvider(any())).thenReturn(expectedCredentialsProvider); + + final LambdaClientBuilder lambdaClientBuilder = mock(LambdaClientBuilder.class); + when(lambdaClientBuilder.region(region)).thenReturn(lambdaClientBuilder); + when(lambdaClientBuilder.credentialsProvider(any())).thenReturn(lambdaClientBuilder); + when(lambdaClientBuilder.overrideConfiguration(any(ClientOverrideConfiguration.class))).thenReturn(lambdaClientBuilder); + try (final MockedStatic lambdaClientMockedStatic = mockStatic(LambdaClient.class)) { + lambdaClientMockedStatic.when(LambdaClient::builder).thenReturn(lambdaClientBuilder); + createLambdaClient(lambdaProcessorConfig, awsCredentialsSupplier); + } + + final ArgumentCaptor credentialsProviderArgumentCaptor = ArgumentCaptor.forClass(AwsCredentialsProvider.class); + verify(lambdaClientBuilder).credentialsProvider(credentialsProviderArgumentCaptor.capture()); + + final AwsCredentialsProvider actualCredentialsProvider = credentialsProviderArgumentCaptor.getValue(); + + assertThat(actualCredentialsProvider, equalTo(expectedCredentialsProvider)); + + final ArgumentCaptor optionsArgumentCaptor = ArgumentCaptor.forClass(AwsCredentialsOptions.class); + verify(awsCredentialsSupplier).getProvider(optionsArgumentCaptor.capture()); + + final AwsCredentialsOptions actualCredentialsOptions = optionsArgumentCaptor.getValue(); + assertThat(actualCredentialsOptions.getRegion(), equalTo(region)); + assertThat(actualCredentialsOptions.getStsRoleArn(), equalTo(stsRoleArn)); + assertThat(actualCredentialsOptions.getStsHeaderOverrides(), equalTo(stsHeaderOverrides)); + } +} \ No newline at end of file diff --git a/data-prepper-plugins/aws-lambda/src/test/java/org/opensearch/dataprepper/plugins/lambda/processor/LambdaProcessorConfigTest.java b/data-prepper-plugins/aws-lambda/src/test/java/org/opensearch/dataprepper/plugins/lambda/processor/LambdaProcessorConfigTest.java new file mode 100644 index 0000000000..3f94ccf6fe --- /dev/null +++ b/data-prepper-plugins/aws-lambda/src/test/java/org/opensearch/dataprepper/plugins/lambda/processor/LambdaProcessorConfigTest.java @@ -0,0 +1,32 @@ +package org.opensearch.dataprepper.plugins.lambda.processor; + +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.dataformat.yaml.YAMLFactory; +import com.fasterxml.jackson.dataformat.yaml.YAMLGenerator; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.equalTo; +import org.junit.jupiter.api.Test; +import software.amazon.awssdk.regions.Region; + +public class LambdaProcessorConfigTest { + + public static final int DEFAULT_MAX_RETRIES = 3; + private final ObjectMapper objectMapper = new ObjectMapper(new YAMLFactory().enable(YAMLGenerator.Feature.USE_PLATFORM_LINE_BREAKS)); + + @Test + void lambda_processor_default_max_connection_retries_test() { + assertThat(new LambdaProcessorConfig().getMaxConnectionRetries(), equalTo(DEFAULT_MAX_RETRIES)); + } + + @Test + public void testAwsAuthenticationOptionsNotNull() throws JsonProcessingException { + final String config = " function_name: test_function\n" + " aws:\n" + " region: ap-south-1\n" + " sts_role_arn: arn:aws:iam::524239988912:role/app-test\n" + " sts_header_overrides: {\"test\":\"test\"}\n" + " max_retries: 10\n"; + final LambdaProcessorConfig lambdaProcessorConfig = objectMapper.readValue(config, LambdaProcessorConfig.class); + + assertThat(lambdaProcessorConfig.getMaxConnectionRetries(), equalTo(10)); + assertThat(lambdaProcessorConfig.getAwsAuthenticationOptions().getAwsRegion(), equalTo(Region.AP_SOUTH_1)); + assertThat(lambdaProcessorConfig.getAwsAuthenticationOptions().getAwsStsRoleArn(), equalTo("arn:aws:iam::524239988912:role/app-test")); + assertThat(lambdaProcessorConfig.getAwsAuthenticationOptions().getAwsStsHeaderOverrides().get("test"), equalTo("test")); + } +} \ No newline at end of file diff --git a/data-prepper-plugins/aws-lambda/src/test/java/org/opensearch/dataprepper/plugins/lambda/processor/LambdaProcessorTest.java b/data-prepper-plugins/aws-lambda/src/test/java/org/opensearch/dataprepper/plugins/lambda/processor/LambdaProcessorTest.java new file mode 100644 index 0000000000..aa6be490d8 --- /dev/null +++ b/data-prepper-plugins/aws-lambda/src/test/java/org/opensearch/dataprepper/plugins/lambda/processor/LambdaProcessorTest.java @@ -0,0 +1,281 @@ +package org.opensearch.dataprepper.plugins.lambda.processor; + +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.dataformat.yaml.YAMLFactory; +import com.fasterxml.jackson.dataformat.yaml.YAMLGenerator; +import io.micrometer.core.instrument.Counter; +import io.micrometer.core.instrument.Timer; +import org.junit.jupiter.api.AfterEach; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import org.mockito.Mock; +import org.mockito.MockedStatic; +import org.mockito.Mockito; +import static org.mockito.Mockito.lenient; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; +import org.mockito.MockitoAnnotations; +import org.mockito.Spy; +import org.mockito.junit.jupiter.MockitoExtension; +import org.opensearch.dataprepper.aws.api.AwsCredentialsSupplier; +import org.opensearch.dataprepper.expression.ExpressionEvaluator; +import org.opensearch.dataprepper.metrics.PluginMetrics; +import org.opensearch.dataprepper.model.event.Event; +import org.opensearch.dataprepper.model.event.JacksonEvent; +import org.opensearch.dataprepper.model.record.Record; +import org.opensearch.dataprepper.model.types.ByteCount; +import org.opensearch.dataprepper.plugins.lambda.common.accumlator.Buffer; +import org.opensearch.dataprepper.plugins.lambda.common.accumlator.BufferFactory; +import org.opensearch.dataprepper.plugins.lambda.common.config.AwsAuthenticationOptions; +import org.opensearch.dataprepper.plugins.lambda.common.config.BatchOptions; +import org.opensearch.dataprepper.plugins.lambda.common.config.ThresholdOptions; +import static org.opensearch.dataprepper.plugins.lambda.processor.LambdaProcessor.LAMBDA_LATENCY_METRIC; +import static org.opensearch.dataprepper.plugins.lambda.processor.LambdaProcessor.NUMBER_OF_RECORDS_FLUSHED_TO_LAMBDA_FAILED; +import static org.opensearch.dataprepper.plugins.lambda.processor.LambdaProcessor.NUMBER_OF_RECORDS_FLUSHED_TO_LAMBDA_SUCCESS; +import static org.opensearch.dataprepper.plugins.lambda.processor.LambdaProcessor.REQUEST_PAYLOAD_SIZE; +import static org.opensearch.dataprepper.plugins.lambda.processor.LambdaProcessor.RESPONSE_PAYLOAD_SIZE; +import software.amazon.awssdk.core.SdkBytes; +import software.amazon.awssdk.regions.Region; +import software.amazon.awssdk.services.lambda.LambdaClient; +import software.amazon.awssdk.services.lambda.model.InvokeRequest; +import software.amazon.awssdk.services.lambda.model.InvokeResponse; + +import java.io.IOException; +import java.time.Duration; +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.List; +import java.util.concurrent.atomic.AtomicLong; + +@ExtendWith(MockitoExtension.class) +public class LambdaProcessorTest { + private static final String PROCESSOR_PLUGIN_NAME = "aws_lambda"; + private static final String PROCESSOR_PIPELINE_NAME = "lambda-processor-pipeline"; + private static final String RESPONSE_PAYLOAD = "{\"k1\":\"v1\",\"k2\":\"v2\"}"; + private static MockedStatic lambdaClientFactoryMockedStatic; + private final ObjectMapper objectMapper = new ObjectMapper(new YAMLFactory().enable(YAMLGenerator.Feature.USE_PLATFORM_LINE_BREAKS)); + +// @Mock +// private PluginSetting pluginSetting; + + @Mock + private PluginMetrics pluginMetrics; + + @Mock + private ExpressionEvaluator expressionEvaluator; + + @Mock + private LambdaProcessorConfig lambdaProcessorConfig; + + @Mock + private AwsCredentialsSupplier awsCredentialsSupplier; + + @Mock + private LambdaClient lambdaClient; + + @Spy + private BufferFactory bufferFactory; + + @Mock + private Buffer buffer; + + @Mock + private Counter numberOfRecordsSuccessCounter; + + @Mock + private Counter numberOfRecordsFailedCounter; + + @Mock + private Counter numberOfRecordsDroppedCounter; + + @Mock + private Timer lambdaLatencyMetric; + + @Mock + private AtomicLong requestPayload; + + @Mock + private AtomicLong responsePayload; + + private LambdaProcessor createObjectUnderTest() { + return new LambdaProcessor(pluginMetrics, lambdaProcessorConfig, awsCredentialsSupplier, expressionEvaluator); + } + + @BeforeEach + public void setUp() throws IOException { + MockitoAnnotations.openMocks(this); + + BatchOptions batchOptions = mock(BatchOptions.class); + ThresholdOptions thresholdOptions = mock(ThresholdOptions.class); + AwsAuthenticationOptions awsAuthenticationOptions = mock(AwsAuthenticationOptions.class); + + lenient().when(lambdaProcessorConfig.getFunctionName()).thenReturn("test-function1"); + lenient().when(lambdaProcessorConfig.getMaxConnectionRetries()).thenReturn(3); + lenient().when(lambdaProcessorConfig.getMode()).thenReturn("requestresponse"); + + lenient().when(thresholdOptions.getEventCount()).thenReturn(10); + lenient().when(thresholdOptions.getMaximumSize()).thenReturn(ByteCount.ofBytes(6)); + lenient().when(thresholdOptions.getEventCollectTimeOut()).thenReturn(Duration.ofSeconds(5)); + + lenient().when(batchOptions.getThresholdOptions()).thenReturn(thresholdOptions); + lenient().when(batchOptions.getBatchKey()).thenReturn("key"); + lenient().when(lambdaProcessorConfig.getBatchOptions()).thenReturn(batchOptions); + + lenient().when(lambdaProcessorConfig.getAwsAuthenticationOptions()).thenReturn(awsAuthenticationOptions); + lenient().when(awsAuthenticationOptions.getAwsRegion()).thenReturn(Region.of("test-region")); + + lenient().when(pluginMetrics.counter(NUMBER_OF_RECORDS_FLUSHED_TO_LAMBDA_SUCCESS)).thenReturn(numberOfRecordsDroppedCounter); + lenient().when(pluginMetrics.counter(NUMBER_OF_RECORDS_FLUSHED_TO_LAMBDA_FAILED)).thenReturn(numberOfRecordsFailedCounter); + lenient().when(pluginMetrics.timer(LAMBDA_LATENCY_METRIC)).thenReturn(lambdaLatencyMetric); + lenient().when(pluginMetrics.gauge(eq(REQUEST_PAYLOAD_SIZE), any(AtomicLong.class))).thenReturn(requestPayload); + lenient().when(pluginMetrics.gauge(eq(RESPONSE_PAYLOAD_SIZE), any(AtomicLong.class))).thenReturn(responsePayload); + + InvokeResponse resp = InvokeResponse.builder().statusCode(200).payload(SdkBytes.fromUtf8String(RESPONSE_PAYLOAD)).build(); + lambdaClientFactoryMockedStatic = Mockito.mockStatic(LambdaClientFactory.class); + when(LambdaClientFactory.createLambdaClient(any(LambdaProcessorConfig.class), any(AwsCredentialsSupplier.class))).thenReturn(lambdaClient); + lenient().when(lambdaClient.invoke(any(InvokeRequest.class))).thenReturn(resp); + try { + lenient().when(bufferFactory.getBuffer(lambdaClient, lambdaProcessorConfig.getFunctionName(), "RequestResponse")).thenReturn(buffer); + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + @AfterEach + public void cleanup() { + lambdaClientFactoryMockedStatic.close(); + } + + @Test + public void testDoExecuteWithEmptyRecords() { + Collection> records = Collections.emptyList(); + LambdaProcessor lambdaProcessor = createObjectUnderTest(); + Collection> result = lambdaProcessor.doExecute(records); + + assertTrue(result.isEmpty()); + } + + @Test + public void testDoExecute() throws JsonProcessingException { + Event event = JacksonEvent.builder().withEventType("event").withData("{\"status\":true}").build(); + Record record = new Record<>(event); + Collection> records = List.of(record); + + InvokeResponse invokeResponse = InvokeResponse.builder().statusCode(200).payload(SdkBytes.fromUtf8String(RESPONSE_PAYLOAD)).build(); + + LambdaProcessor lambdaProcessor = createObjectUnderTest(); + Collection> resultRecords = lambdaProcessor.doExecute(records); + + assertEquals(1, resultRecords.size()); + Record resultRecord = resultRecords.iterator().next(); + + ObjectMapper objectMapper = new ObjectMapper(); + JsonNode responseJsonNode = objectMapper.readTree(RESPONSE_PAYLOAD); + assertEquals(responseJsonNode, resultRecord.getData().getJsonNode()); + } + + @Test + public void testDoExecute_withException() { + List> records = new ArrayList<>(); + Event event = mock(Event.class); + records.add(new Record<>(event)); + + lenient().when(buffer.getOutputStream()).thenThrow(new RuntimeException("Test exception")); + + LambdaProcessor lambdaProcessor = createObjectUnderTest(); + Collection> result = lambdaProcessor.doExecute(records); + + assertEquals(1, result.size()); + verify(buffer, times(0)).flushToLambdaSync(); + } + + @Test + public void testFlushToLambdaIfNeeded_withThresholdNotExceeded() throws Exception { + lenient().when(buffer.getSize()).thenReturn(100L); + lenient().when(buffer.getEventCount()).thenReturn(1); + lenient().when(buffer.getDuration()).thenReturn(Duration.ofSeconds(1)); + + LambdaProcessor lambdaProcessor = createObjectUnderTest(); + List> records = mock(ArrayList.class); + lambdaProcessor.flushToLambdaIfNeeded(records); + verify(buffer, times(0)).flushToLambdaSync(); + verify(records, times(0)).add(any(Record.class)); + } + + @Test + public void testConvertLambdaResponseToEvent_withNon200StatusCode() { + InvokeResponse response = InvokeResponse.builder().statusCode(500).payload(SdkBytes.fromUtf8String(RESPONSE_PAYLOAD)).build(); + lenient().when(lambdaClient.invoke(any(InvokeRequest.class))).thenReturn(response); + + LambdaProcessor lambdaProcessor = createObjectUnderTest(); + + Exception exception = assertThrows(RuntimeException.class, () -> { + lambdaProcessor.convertLambdaResponseToEvent(response); + }); + assertEquals("Error converting Lambda response to Event", exception.getMessage()); + } + + @Test + public void testDoExecute_withNonSuccessfulStatusCode() { + InvokeResponse response = InvokeResponse.builder().statusCode(500).payload(SdkBytes.fromUtf8String(RESPONSE_PAYLOAD)).build(); + lenient().when(lambdaClient.invoke(any(InvokeRequest.class))).thenReturn(response); + + LambdaProcessor lambdaProcessor = createObjectUnderTest(); + + List> records = new ArrayList<>(); + Event event = mock(Event.class); + records.add(new Record<>(event)); + List> resultRecords = (List>) lambdaProcessor.doExecute(records); + + verify(lambdaClient, times(1)).invoke(any(InvokeRequest.class)); + + //event should be dropped on failure + assertEquals(resultRecords.size(), 0); + verify(numberOfRecordsFailedCounter, times(1)).increment(1); + //check if buffer is reset + assertEquals(buffer.getSize(), 0); + } + + @Test + public void testConvertLambdaResponseToEvent() throws JsonProcessingException { + InvokeResponse response = InvokeResponse.builder().statusCode(200).payload(SdkBytes.fromUtf8String(RESPONSE_PAYLOAD)).build(); + lenient().when(lambdaClient.invoke(any(InvokeRequest.class))).thenReturn(response); + + LambdaProcessor lambdaProcessor = createObjectUnderTest(); + Event eventResponse = lambdaProcessor.convertLambdaResponseToEvent(response); + + ObjectMapper objectMapper = new ObjectMapper(); + JsonNode jsonNode = objectMapper.readTree(RESPONSE_PAYLOAD); + Event event = JacksonEvent.builder().withEventType("event").withData(jsonNode).build(); + assertEquals(event.getJsonNode(), eventResponse.getJsonNode()); + } + + @Test + public void testDoExecute_WithConfig() throws JsonProcessingException { + final String config = " function_name: test_function\n" + " mode: requestresponse\n" + " aws:\n" + " region: us-east-1\n" + " sts_role_arn: arn:aws:iam::524239988912:role/app-test\n" + " sts_header_overrides: {\"test\":\"test\"}\n" + " max_retries: 3\n"; + + this.lambdaProcessorConfig = objectMapper.readValue(config, LambdaProcessorConfig.class); + + Event event = JacksonEvent.builder().withEventType("event").withData("{\"status\":true}").build(); + Record record = new Record<>(event); + Collection> records = List.of(record); + + InvokeResponse invokeResponse = InvokeResponse.builder().statusCode(200).payload(SdkBytes.fromUtf8String(RESPONSE_PAYLOAD)).build(); + + LambdaProcessor lambdaProcessor = createObjectUnderTest(); + Collection> resultRecords = lambdaProcessor.doExecute(records); + verify(lambdaClient, times(1)).invoke(any(InvokeRequest.class)); + assertEquals(resultRecords.size(), 1); + } +}