diff --git a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaReadSchemaTransformProvider.java b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaReadSchemaTransformProvider.java
index 053cd3ff76ffe..3c8472c794ec6 100644
--- a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaReadSchemaTransformProvider.java
+++ b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaReadSchemaTransformProvider.java
@@ -26,6 +26,7 @@
 import java.nio.channels.WritableByteChannel;
 import java.nio.file.Files;
 import java.nio.file.Path;
+import java.util.Arrays;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
@@ -35,20 +36,27 @@
 import org.apache.beam.sdk.extensions.avro.coders.AvroCoder;
 import org.apache.beam.sdk.extensions.avro.schemas.utils.AvroUtils;
 import org.apache.beam.sdk.io.FileSystems;
+import org.apache.beam.sdk.metrics.Counter;
+import org.apache.beam.sdk.metrics.Metrics;
 import org.apache.beam.sdk.schemas.Schema;
 import org.apache.beam.sdk.schemas.transforms.Convert;
 import org.apache.beam.sdk.schemas.transforms.SchemaTransform;
 import org.apache.beam.sdk.schemas.transforms.SchemaTransformProvider;
 import org.apache.beam.sdk.schemas.transforms.TypedSchemaTransformProvider;
 import org.apache.beam.sdk.schemas.utils.JsonUtils;
-import org.apache.beam.sdk.transforms.MapElements;
+import org.apache.beam.sdk.transforms.DoFn;
+import org.apache.beam.sdk.transforms.DoFn.FinishBundle;
+import org.apache.beam.sdk.transforms.DoFn.ProcessElement;
 import org.apache.beam.sdk.transforms.PTransform;
+import org.apache.beam.sdk.transforms.ParDo;
 import org.apache.beam.sdk.transforms.SerializableFunction;
 import org.apache.beam.sdk.transforms.Values;
 import org.apache.beam.sdk.values.PCollection;
 import org.apache.beam.sdk.values.PCollectionRowTuple;
+import org.apache.beam.sdk.values.PCollectionTuple;
 import org.apache.beam.sdk.values.Row;
-import org.apache.beam.sdk.values.TypeDescriptors;
+import org.apache.beam.sdk.values.TupleTag;
+import org.apache.beam.sdk.values.TupleTagList;
 import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.annotations.VisibleForTesting;
 import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.MoreObjects;
 import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Strings;
@@ -67,6 +75,11 @@ public class KafkaReadSchemaTransformProvider
   private static final Logger LOG =
       LoggerFactory.getLogger(KafkaReadSchemaTransformProvider.class);
 
+  public static final TupleTag<Row> OUTPUT_TAG = new TupleTag<Row>() {};
+  public static final TupleTag<Row> ERROR_TAG = new TupleTag<Row>() {};
+  public static final Schema ERROR_SCHEMA =
+      Schema.builder().addStringField("error").addNullableByteArrayField("row").build();
+
   final Boolean isTest;
   final Integer testTimeoutSecs;
 
@@ -102,7 +115,37 @@ public List<String> inputCollectionNames() {
 
   @Override
   public List<String> outputCollectionNames() {
-    return Lists.newArrayList("output");
+    return Arrays.asList("output", "errors");
+  }
+
+  public static class ErrorFn extends DoFn<byte[], Row> {
+    private SerializableFunction<byte[], Row> valueMapper;
+    private Counter errorCounter;
+    private Long errorsInBundle = 0L;
+
+    public ErrorFn(String name, SerializableFunction<byte[], Row> valueMapper) {
+      this.errorCounter = Metrics.counter(KafkaReadSchemaTransformProvider.class, name);
+      this.valueMapper = valueMapper;
+    }
+
+    @ProcessElement
+    public void process(@DoFn.Element byte[] msg, MultiOutputReceiver receiver) {
+      try {
+        receiver.get(OUTPUT_TAG).output(valueMapper.apply(msg));
+      } catch (Exception e) {
+        errorsInBundle += 1;
+        LOG.warn("Error while parsing the element", e);
+        receiver
+            .get(ERROR_TAG)
+            .output(Row.withSchema(ERROR_SCHEMA).addValues(e.toString(), msg).build());
+      }
+    }
+
+    @FinishBundle
+    public void finish(FinishBundleContext c) {
+      errorCounter.inc(errorsInBundle);
+      errorsInBundle = 0L;
+    }
   }
 
   private static class KafkaReadSchemaTransform implements SchemaTransform {
@@ -160,14 +203,19 @@ public PCollectionRowTuple expand(PCollectionRowTuple input) {
              kafkaRead = kafkaRead.withMaxReadTime(Duration.standardSeconds(testTimeoutSeconds));
            }
 
+           PCollection<byte[]> kafkaValues =
+               input.getPipeline().apply(kafkaRead.withoutMetadata()).apply(Values.create());
+
+           PCollectionTuple outputTuple =
+               kafkaValues.apply(
+                   ParDo.of(new ErrorFn("Kafka-read-error-counter", valueMapper))
+                       .withOutputTags(OUTPUT_TAG, TupleTagList.of(ERROR_TAG)));
+
            return PCollectionRowTuple.of(
                "output",
-               input
-                   .getPipeline()
-                   .apply(kafkaRead.withoutMetadata())
-                   .apply(Values.create())
-                   .apply(MapElements.into(TypeDescriptors.rows()).via(valueMapper))
-                   .setRowSchema(beamSchema));
+               outputTuple.get(OUTPUT_TAG).setRowSchema(beamSchema),
+               "errors",
+               outputTuple.get(ERROR_TAG).setRowSchema(ERROR_SCHEMA));
          }
        };
      } else {
diff --git a/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaDlqTest.java b/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaDlqTest.java
new file mode 100644
index 0000000000000..48fe969bc9f35
--- /dev/null
+++ b/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaDlqTest.java
@@ -0,0 +1,115 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.io.kafka;
+
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.List;
+import org.apache.beam.sdk.io.kafka.KafkaReadSchemaTransformProvider.ErrorFn;
+import org.apache.beam.sdk.schemas.Schema;
+import org.apache.beam.sdk.schemas.utils.JsonUtils;
+import org.apache.beam.sdk.testing.PAssert;
+import org.apache.beam.sdk.testing.TestPipeline;
+import org.apache.beam.sdk.transforms.Count;
+import org.apache.beam.sdk.transforms.Create;
+import org.apache.beam.sdk.transforms.ParDo;
+import org.apache.beam.sdk.transforms.SerializableFunction;
+import org.apache.beam.sdk.values.PCollection;
+import org.apache.beam.sdk.values.PCollectionTuple;
+import org.apache.beam.sdk.values.Row;
+import org.apache.beam.sdk.values.TupleTag;
+import org.apache.beam.sdk.values.TupleTagList;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+@RunWith(JUnit4.class)
+public class KafkaDlqTest {
+
+  private static final TupleTag<Row> OUTPUTTAG = KafkaReadSchemaTransformProvider.OUTPUT_TAG;
+  private static final TupleTag<Row> ERRORTAG = KafkaReadSchemaTransformProvider.ERROR_TAG;
+
+  private static final Schema BEAMSCHEMA =
+      Schema.of(Schema.Field.of("name", Schema.FieldType.STRING));
+  private static final Schema ERRORSCHEMA = KafkaReadSchemaTransformProvider.ERROR_SCHEMA;
+
+  private static final List<Row> ROWS =
+      Arrays.asList(
+          Row.withSchema(BEAMSCHEMA).withFieldValue("name", "a").build(),
+          Row.withSchema(BEAMSCHEMA).withFieldValue("name", "b").build(),
+          Row.withSchema(BEAMSCHEMA).withFieldValue("name", "c").build());
+
+  private static List<byte[]> messages;
+
+  private static List<byte[]> messagesWithError;
+
+  final SerializableFunction<byte[], Row> valueMapper =
+      JsonUtils.getJsonBytesToRowFunction(BEAMSCHEMA);
+
+  @Rule public transient TestPipeline p = TestPipeline.create();
+
+  @Test
+  public void testKafkaErrorFnSuccess() throws Exception {
+    try {
+      messages =
+          Arrays.asList(
+              "{\"name\":\"a\"}".getBytes("UTF8"),
+              "{\"name\":\"b\"}".getBytes("UTF8"),
+              "{\"name\":\"c\"}".getBytes("UTF8"));
+    } catch (Exception e) {
+    }
+    PCollection<byte[]> input = p.apply(Create.of(messages));
+    PCollectionTuple output =
+        input.apply(
+            ParDo.of(new ErrorFn("Kafka-read-error-counter", valueMapper))
+                .withOutputTags(OUTPUTTAG, TupleTagList.of(ERRORTAG)));
+
+    output.get(OUTPUTTAG).setRowSchema(BEAMSCHEMA);
+    output.get(ERRORTAG).setRowSchema(ERRORSCHEMA);
+
+    PAssert.that(output.get(OUTPUTTAG)).containsInAnyOrder(ROWS);
+    p.run().waitUntilFinish();
+  }
+
+  @Test
+  public void testKafkaErrorFnFailure() throws Exception {
+    try {
+      messagesWithError =
+          Arrays.asList(
+              "{\"error\":\"a\"}".getBytes("UTF8"),
+              "{\"error\":\"b\"}".getBytes("UTF8"),
+              "{\"error\":\"c\"}".getBytes("UTF8"));
+    } catch (Exception e) {
+    }
+    PCollection<byte[]> input = p.apply(Create.of(messagesWithError));
+    PCollectionTuple output =
+        input.apply(
+            ParDo.of(new ErrorFn("Read-Error-Counter", valueMapper))
+                .withOutputTags(OUTPUTTAG, TupleTagList.of(ERRORTAG)));
+
+    output.get(OUTPUTTAG).setRowSchema(BEAMSCHEMA);
+    output.get(ERRORTAG).setRowSchema(ERRORSCHEMA);
+
+    PCollection<Long> count = output.get(ERRORTAG).apply("error_count", Count.globally());
+
+    PAssert.that(count).containsInAnyOrder(Collections.singletonList(3L));
+
+    p.run().waitUntilFinish();
+  }
+}
diff --git a/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaReadSchemaTransformProviderTest.java b/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaReadSchemaTransformProviderTest.java
index b90585ebd79d0..8fdbd12212df5 100644
--- a/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaReadSchemaTransformProviderTest.java
+++ b/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaReadSchemaTransformProviderTest.java
@@ -88,7 +88,7 @@ public void testFindTransformAndMakeItWork() {
             .filter(provider -> provider.getClass() == KafkaReadSchemaTransformProvider.class)
             .collect(Collectors.toList());
     SchemaTransformProvider kafkaProvider = providers.get(0);
-    assertEquals(kafkaProvider.outputCollectionNames(), Lists.newArrayList("output"));
+    assertEquals(kafkaProvider.outputCollectionNames(), Lists.newArrayList("output", "errors"));
     assertEquals(kafkaProvider.inputCollectionNames(), Lists.newArrayList());
     assertEquals(