Skip to content

Commit

Permalink
BigQuery trim schema with selected fields (#32514)
Browse files Browse the repository at this point in the history
* BigQuery trim schema with selected fields

Trim BQ schema directly instead of converting to avro schema and back

* Add support for nested fields

* Adapt tests

* Apply review comments
  • Loading branch information
RustedBones authored Sep 27, 2024
1 parent 4309675 commit 271ea43
Show file tree
Hide file tree
Showing 11 changed files with 117 additions and 110 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,6 @@
import java.util.List;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.apache.avro.Conversions;
import org.apache.avro.LogicalType;
import org.apache.avro.LogicalTypes;
Expand Down Expand Up @@ -176,41 +174,6 @@ private static String formatTime(long timeMicros) {
return LocalTime.ofNanoOfDay(timeMicros * 1000).format(formatter);
}

static TableSchema trimBigQueryTableSchema(TableSchema inputSchema, Schema avroSchema) {
List<TableFieldSchema> subSchemas =
inputSchema.getFields().stream()
.flatMap(fieldSchema -> mapTableFieldSchema(fieldSchema, avroSchema))
.collect(Collectors.toList());

return new TableSchema().setFields(subSchemas);
}

private static Stream<TableFieldSchema> mapTableFieldSchema(
TableFieldSchema fieldSchema, Schema avroSchema) {
Field avroFieldSchema = avroSchema.getField(fieldSchema.getName());
if (avroFieldSchema == null) {
return Stream.empty();
} else if (avroFieldSchema.schema().getType() != Type.RECORD) {
return Stream.of(fieldSchema);
}

List<TableFieldSchema> subSchemas =
fieldSchema.getFields().stream()
.flatMap(subSchema -> mapTableFieldSchema(subSchema, avroFieldSchema.schema()))
.collect(Collectors.toList());

TableFieldSchema output =
new TableFieldSchema()
.setCategories(fieldSchema.getCategories())
.setDescription(fieldSchema.getDescription())
.setFields(subSchemas)
.setMode(fieldSchema.getMode())
.setName(fieldSchema.getName())
.setType(fieldSchema.getType());

return Stream.of(output);
}

/**
* Utility function to convert from an Avro {@link GenericRecord} to a BigQuery {@link TableRow}.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1275,8 +1275,12 @@ public PCollection<T> expand(PBegin input) {

Schema beamSchema = null;
if (getTypeDescriptor() != null && getToBeamRowFn() != null && getFromBeamRowFn() != null) {
beamSchema = sourceDef.getBeamSchema(bqOptions);
beamSchema = getFinalSchema(beamSchema, getSelectedFields());
TableSchema tableSchema = sourceDef.getTableSchema(bqOptions);
ValueProvider<List<String>> selectedFields = getSelectedFields();
if (selectedFields != null && selectedFields.isAccessible()) {
tableSchema = BigQueryUtils.trimSchema(tableSchema, selectedFields.get());
}
beamSchema = BigQueryUtils.fromTableSchema(tableSchema);
}

final Coder<T> coder = inferCoder(p.getCoderRegistry());
Expand Down Expand Up @@ -1441,24 +1445,6 @@ void cleanup(PassThroughThenCleanup.ContextContainer c) throws Exception {
return rows;
}

private static Schema getFinalSchema(
Schema beamSchema, ValueProvider<List<String>> selectedFields) {
List<Schema.Field> flds =
beamSchema.getFields().stream()
.filter(
field -> {
if (selectedFields != null
&& selectedFields.isAccessible()
&& selectedFields.get() != null) {
return selectedFields.get().contains(field.getName());
} else {
return true;
}
})
.collect(Collectors.toList());
return Schema.builder().addFields(flds).build();
}

private PCollection<T> expandForDirectRead(
PBegin input, Coder<T> outputCoder, Schema beamSchema, BigQueryOptions bqOptions) {
ValueProvider<TableReference> tableProvider = getTableProvider();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@
import org.apache.beam.sdk.extensions.avro.io.AvroSource;
import org.apache.beam.sdk.io.gcp.bigquery.BigQueryResourceNaming.JobType;
import org.apache.beam.sdk.options.ValueProvider;
import org.apache.beam.sdk.schemas.Schema;
import org.apache.beam.sdk.transforms.SerializableFunction;
import org.checkerframework.checker.nullness.qual.Nullable;
import org.slf4j.Logger;
Expand Down Expand Up @@ -178,7 +177,7 @@ public <T> BigQuerySourceBase<T> toSource(

/** {@inheritDoc} */
@Override
public Schema getBeamSchema(BigQueryOptions bqOptions) {
public TableSchema getTableSchema(BigQueryOptions bqOptions) {
try {
JobStatistics stats =
BigQueryQueryHelper.dryRunQueryIfNeeded(
Expand All @@ -189,8 +188,7 @@ public Schema getBeamSchema(BigQueryOptions bqOptions) {
flattenResults,
useLegacySql,
location);
TableSchema tableSchema = stats.getQuery().getSchema();
return BigQueryUtils.fromTableSchema(tableSchema);
return stats.getQuery().getSchema();
} catch (IOException | InterruptedException | NullPointerException e) {
throw new BigQuerySchemaRetrievalException(
"Exception while trying to retrieve schema of query", e);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
import java.io.Serializable;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.extensions.avro.io.AvroSource;
import org.apache.beam.sdk.schemas.Schema;
import org.apache.beam.sdk.transforms.SerializableFunction;

/**
Expand All @@ -46,11 +45,11 @@ <T> BigQuerySourceBase<T> toSource(
boolean useAvroLogicalTypes);

/**
* Extract the Beam {@link Schema} corresponding to this source.
* Extract the {@link TableSchema} corresponding to this source.
*
* @param bqOptions BigQueryOptions
* @return Beam schema of the source
* @return table schema of the source
* @throws BigQuerySchemaRetrievalException if schema retrieval fails
*/
Schema getBeamSchema(BigQueryOptions bqOptions);
TableSchema getTableSchema(BigQueryOptions bqOptions);
}
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,7 @@
import com.google.cloud.bigquery.storage.v1.ReadStream;
import java.io.IOException;
import java.util.List;
import org.apache.avro.Schema;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.extensions.arrow.ArrowConversion;
import org.apache.beam.sdk.extensions.avro.schemas.utils.AvroUtils;
import org.apache.beam.sdk.io.BoundedSource;
import org.apache.beam.sdk.io.gcp.bigquery.BigQueryServices.StorageClient;
import org.apache.beam.sdk.metrics.Lineage;
Expand Down Expand Up @@ -126,17 +123,16 @@ public List<BigQueryStorageStreamSource<T>> split(
}
}

if (selectedFieldsProvider != null || rowRestrictionProvider != null) {
ReadSession.TableReadOptions.Builder tableReadOptionsBuilder =
ReadSession.TableReadOptions.newBuilder();
if (selectedFieldsProvider != null) {
tableReadOptionsBuilder.addAllSelectedFields(selectedFieldsProvider.get());
}
if (rowRestrictionProvider != null) {
tableReadOptionsBuilder.setRowRestriction(rowRestrictionProvider.get());
}
readSessionBuilder.setReadOptions(tableReadOptionsBuilder);
ReadSession.TableReadOptions.Builder tableReadOptionsBuilder =
ReadSession.TableReadOptions.newBuilder();
if (selectedFieldsProvider != null && selectedFieldsProvider.isAccessible()) {
tableReadOptionsBuilder.addAllSelectedFields(selectedFieldsProvider.get());
}
if (rowRestrictionProvider != null && rowRestrictionProvider.isAccessible()) {
tableReadOptionsBuilder.setRowRestriction(rowRestrictionProvider.get());
}
readSessionBuilder.setReadOptions(tableReadOptionsBuilder);

if (format != null) {
readSessionBuilder.setDataFormat(format);
}
Expand Down Expand Up @@ -182,30 +178,18 @@ public List<BigQueryStorageStreamSource<T>> split(
LOG.info("Read session returned {} streams", readSession.getStreamsList().size());
}

Schema sessionSchema;
if (readSession.getDataFormat() == DataFormat.ARROW) {
org.apache.arrow.vector.types.pojo.Schema schema =
ArrowConversion.arrowSchemaFromInput(
readSession.getArrowSchema().getSerializedSchema().newInput());
org.apache.beam.sdk.schemas.Schema beamSchema =
ArrowConversion.ArrowSchemaTranslator.toBeamSchema(schema);
sessionSchema = AvroUtils.toAvroSchema(beamSchema);
} else if (readSession.getDataFormat() == DataFormat.AVRO) {
sessionSchema = new Schema.Parser().parse(readSession.getAvroSchema().getSchema());
} else {
throw new IllegalArgumentException(
"data is not in a supported dataFormat: " + readSession.getDataFormat());
// TODO: this is inconsistent with method above, where it can be null
Preconditions.checkStateNotNull(targetTable);
TableSchema tableSchema = targetTable.getSchema();
if (selectedFieldsProvider != null && selectedFieldsProvider.isAccessible()) {
tableSchema = BigQueryUtils.trimSchema(tableSchema, selectedFieldsProvider.get());
}

Preconditions.checkStateNotNull(
targetTable); // TODO: this is inconsistent with method above, where it can be null
TableSchema trimmedSchema =
BigQueryAvroUtils.trimBigQueryTableSchema(targetTable.getSchema(), sessionSchema);
List<BigQueryStorageStreamSource<T>> sources = Lists.newArrayList();
for (ReadStream readStream : readSession.getStreamsList()) {
sources.add(
BigQueryStorageStreamSource.create(
readSession, readStream, trimmedSchema, parseFn, outputCoder, bqServices));
readSession, readStream, tableSchema, parseFn, outputCoder, bqServices));
}

return ImmutableList.copyOf(sources);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
import org.apache.beam.sdk.extensions.avro.io.AvroSource;
import org.apache.beam.sdk.io.gcp.bigquery.BigQueryServices.DatasetService;
import org.apache.beam.sdk.options.ValueProvider;
import org.apache.beam.sdk.schemas.Schema;
import org.apache.beam.sdk.transforms.SerializableFunction;
import org.apache.beam.sdk.util.Preconditions;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Strings;
Expand Down Expand Up @@ -102,13 +101,12 @@ public <T> BigQuerySourceBase<T> toSource(

/** {@inheritDoc} */
@Override
public Schema getBeamSchema(BigQueryOptions bqOptions) {
public TableSchema getTableSchema(BigQueryOptions bqOptions) {
try {
try (DatasetService datasetService = bqServices.getDatasetService(bqOptions)) {
TableReference tableRef = getTableReference(bqOptions);
Table table = datasetService.getTable(tableRef);
TableSchema tableSchema = Preconditions.checkStateNotNull(table).getSchema();
return BigQueryUtils.fromTableSchema(tableSchema);
return Preconditions.checkStateNotNull(table).getSchema();
}
} catch (Exception e) {
throw new BigQuerySchemaRetrievalException("Exception while trying to retrieve schema", e);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,9 @@
import java.util.function.Function;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.Stream;
import org.apache.avro.Conversions;
import org.apache.avro.LogicalTypes;
import org.apache.avro.generic.GenericData;
Expand Down Expand Up @@ -1039,6 +1041,48 @@ private static Object convertAvroNumeric(Object value) {
return tableSpec;
}

static TableSchema trimSchema(TableSchema schema, @Nullable List<String> selectedFields) {
if (selectedFields == null || selectedFields.isEmpty()) {
return schema;
}

List<TableFieldSchema> trimmedFields =
schema.getFields().stream()
.flatMap(f -> trimField(f, selectedFields))
.collect(Collectors.toList());
return new TableSchema().setFields(trimmedFields);
}

private static Stream<TableFieldSchema> trimField(
TableFieldSchema field, List<String> selectedFields) {
String name = field.getName();
if (selectedFields.contains(name)) {
return Stream.of(field);
}

if (field.getFields() != null) {
// record
List<String> selectedChildren =
selectedFields.stream()
.filter(sf -> sf.startsWith(name + "."))
.map(sf -> sf.substring(name.length() + 1))
.collect(toList());

if (!selectedChildren.isEmpty()) {
List<TableFieldSchema> trimmedChildren =
field.getFields().stream()
.flatMap(c -> trimField(c, selectedChildren))
.collect(toList());

if (!trimmedChildren.isEmpty()) {
return Stream.of(field.clone().setFields(trimmedChildren));
}
}
}

return Stream.empty();
}

private static @Nullable ServiceCallMetric callMetricForMethod(
@Nullable TableReference tableReference, String method) {
if (tableReference != null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -381,7 +381,8 @@ private void doQuerySourceInitialSplit(
.setParent("projects/" + options.getProject())
.setReadSession(
ReadSession.newBuilder()
.setTable(BigQueryHelpers.toTableResourceName(tempTableReference)))
.setTable(BigQueryHelpers.toTableResourceName(tempTableReference))
.setReadOptions(ReadSession.TableReadOptions.newBuilder()))
.setMaxStreamCount(requestedStreamCount)
.build();

Expand Down Expand Up @@ -482,7 +483,8 @@ public void testQuerySourceInitialSplit_NoReferencedTables() throws Exception {
.setParent("projects/" + options.getProject())
.setReadSession(
ReadSession.newBuilder()
.setTable(BigQueryHelpers.toTableResourceName(tempTableReference)))
.setTable(BigQueryHelpers.toTableResourceName(tempTableReference))
.setReadOptions(ReadSession.TableReadOptions.newBuilder()))
.setMaxStreamCount(1024)
.build();

Expand Down Expand Up @@ -652,7 +654,8 @@ public void testQuerySourceInitialSplitWithBigQueryProject_EmptyResult() throws
.setReadSession(
ReadSession.newBuilder()
.setTable(BigQueryHelpers.toTableResourceName(tempTableReference))
.setDataFormat(DataFormat.AVRO))
.setDataFormat(DataFormat.AVRO)
.setReadOptions(ReadSession.TableReadOptions.newBuilder()))
.setMaxStreamCount(10)
.build();

Expand Down Expand Up @@ -724,7 +727,8 @@ public void testQuerySourceInitialSplit_EmptyResult() throws Exception {
.setParent("projects/" + options.getProject())
.setReadSession(
ReadSession.newBuilder()
.setTable(BigQueryHelpers.toTableResourceName(tempTableReference)))
.setTable(BigQueryHelpers.toTableResourceName(tempTableReference))
.setReadOptions(ReadSession.TableReadOptions.newBuilder()))
.setMaxStreamCount(10)
.build();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -458,7 +458,8 @@ private void doTableSourceInitialSplitTest(long bundleSize, int streamCount) thr
.setParent("projects/project-id")
.setReadSession(
ReadSession.newBuilder()
.setTable("projects/foo.com:project/datasets/dataset/tables/table"))
.setTable("projects/foo.com:project/datasets/dataset/tables/table")
.setReadOptions(ReadSession.TableReadOptions.newBuilder()))
.setMaxStreamCount(streamCount)
.build();

Expand Down Expand Up @@ -551,7 +552,8 @@ public void testTableSourceInitialSplit_WithDefaultProject() throws Exception {
.setParent("projects/project-id")
.setReadSession(
ReadSession.newBuilder()
.setTable("projects/project-id/datasets/dataset/tables/table"))
.setTable("projects/project-id/datasets/dataset/tables/table")
.setReadOptions(ReadSession.TableReadOptions.newBuilder()))
.setMaxStreamCount(1024)
.build();

Expand Down Expand Up @@ -599,7 +601,8 @@ public void testTableSourceInitialSplit_EmptyTable() throws Exception {
.setParent("projects/project-id")
.setReadSession(
ReadSession.newBuilder()
.setTable("projects/foo.com:project/datasets/dataset/tables/table"))
.setTable("projects/foo.com:project/datasets/dataset/tables/table")
.setReadOptions(ReadSession.TableReadOptions.newBuilder()))
.setMaxStreamCount(1024)
.build();

Expand Down Expand Up @@ -1482,7 +1485,8 @@ public void testReadFromBigQueryIO() throws Exception {
.setReadSession(
ReadSession.newBuilder()
.setTable("projects/foo.com:project/datasets/dataset/tables/table")
.setDataFormat(DataFormat.AVRO))
.setDataFormat(DataFormat.AVRO)
.setReadOptions(ReadSession.TableReadOptions.newBuilder()))
.setMaxStreamCount(10)
.build();

Expand Down Expand Up @@ -1693,7 +1697,8 @@ public void testReadFromBigQueryIOArrow() throws Exception {
.setReadSession(
ReadSession.newBuilder()
.setTable("projects/foo.com:project/datasets/dataset/tables/table")
.setDataFormat(DataFormat.ARROW))
.setDataFormat(DataFormat.ARROW)
.setReadOptions(ReadSession.TableReadOptions.newBuilder()))
.setMaxStreamCount(10)
.build();

Expand Down
Loading

0 comments on commit 271ea43

Please sign in to comment.