Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Python] Hugging Face pipeline support #27399

Merged
merged 35 commits into from
Aug 2, 2023
Merged
Show file tree
Hide file tree
Changes from 28 commits
Commits
Show all changes
35 commits
Select commit Hold shift + click to select a range
06af6b7
automodel first pass
riteshghorse May 10, 2023
416166d
new model
riteshghorse May 16, 2023
6f063e5
updated model handler api
riteshghorse Jun 21, 2023
df87366
add model_class param
riteshghorse Jun 23, 2023
4da7edd
Merge branch 'master' of https://github.com/apache/beam into hf-model…
riteshghorse Jun 23, 2023
025cc52
update doc comments
riteshghorse Jun 26, 2023
8cf7a01
Merge branch 'master' of https://github.com/apache/beam into hf-model…
riteshghorse Jun 26, 2023
2c671ab
updated integration test and example
riteshghorse Jun 26, 2023
abaeb2a
unit test, modified params
riteshghorse Jun 27, 2023
d5e1cf3
add test setup for hugging face tests
riteshghorse Jun 27, 2023
4177c09
fix lints
riteshghorse Jun 27, 2023
6324752
fix import order
riteshghorse Jun 27, 2023
30029d3
refactor, doc, lints
riteshghorse Jun 28, 2023
c60d312
refactor, doc comments
riteshghorse Jun 29, 2023
a52536f
change test file
riteshghorse Jun 29, 2023
496d205
update types
riteshghorse Jul 7, 2023
8dd0ff2
add hugging face pipeline support
riteshghorse Jul 7, 2023
c670ada
integration test for pipeline
riteshghorse Jul 10, 2023
09e64a4
add doc, gs link
riteshghorse Jul 11, 2023
504b161
test raises exception
riteshghorse Jul 11, 2023
4ece137
fix python lints
riteshghorse Jul 18, 2023
250a2d5
add inference fn
riteshghorse Jul 24, 2023
4d6b6b2
Merge branch 'master', remote-tracking branch 'origin' into hf-pipeline
riteshghorse Jul 24, 2023
4787635
update doc
riteshghorse Jul 24, 2023
c9fa0d5
merge master
riteshghorse Jul 24, 2023
c592f91
docs, lint
riteshghorse Jul 24, 2023
db99ad0
docs, lint
riteshghorse Jul 24, 2023
b539d32
remove optional from inference_fn
riteshghorse Jul 24, 2023
e912d35
add enum for tasks
riteshghorse Jul 26, 2023
ba5e31f
update pydoc
riteshghorse Jul 26, 2023
6963a5d
update pydoc
riteshghorse Jul 26, 2023
44916b9
doc, formatting changes
riteshghorse Aug 1, 2023
4d3fdd0
fix doc
riteshghorse Aug 1, 2023
9b64975
fix optional in doc
riteshghorse Aug 1, 2023
7db987b
pin model version
riteshghorse Aug 2, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

""""A pipeline that uses RunInference to perform Question Answering using the
model from Hugging Face Models Hub.

This pipeline takes questions and context from a custom text file separated by
a semicolon. These are converted to SquadExamples by using the utility provided
by transformers.QuestionAnsweringPipeline and passed to the model handler.
We just provide the model name here because the model repository specifies the
task that it will do. The pipeline then writes the prediction to an output
file in which users can then compare against the original context.
"""

import argparse
import logging
from typing import Iterable
from typing import Tuple

import apache_beam as beam
from apache_beam.ml.inference.base import KeyedModelHandler
from apache_beam.ml.inference.base import PredictionResult
from apache_beam.ml.inference.base import RunInference
from apache_beam.ml.inference.huggingface_inference import HuggingFacePipelineModelHandler
from apache_beam.options.pipeline_options import PipelineOptions
from apache_beam.options.pipeline_options import SetupOptions
from apache_beam.runners.runner import PipelineResult
from transformers import QuestionAnsweringPipeline


class PostProcessor(beam.DoFn):
"""Processes the PredictionResult to get the predicted answer.

Hugging Face Pipeline for Question Answering returns a dictionary
with score, start and end index of answer and the answer.
"""
def process(self, result: Tuple[str, PredictionResult]) -> Iterable[str]:
text, prediction = result
predicted_answer = prediction.inference['answer']
yield text + ';' + predicted_answer


def preprocess(text):
if len(text.strip()) > 0:
question, context = text.split(';')
yield (question, context)


def create_squad_example(text):
question, context = text
yield question, QuestionAnsweringPipeline.create_sample(question, context)
riteshghorse marked this conversation as resolved.
Show resolved Hide resolved


def parse_known_args(argv):
"""Parses args for the workflow."""
parser = argparse.ArgumentParser()
parser.add_argument(
'--input',
dest='input',
help='Path of file containing question and context separated by semicolon'
)
parser.add_argument(
'--output',
dest='output',
required=True,
help='Path of file in which to save the output predictions.')
parser.add_argument(
'--model_name',
dest='model_name',
default="deepset/roberta-base-squad2",
help='Model repository-id from Hugging Face Models Hub.')
return parser.parse_known_args(argv)


def run(
argv=None, save_main_session=True, test_pipeline=None) -> PipelineResult:
"""
Args:
argv: Command line arguments defined for this example.
save_main_session: Used for internal testing.
test_pipeline: Used for internal testing.
"""
known_args, pipeline_args = parse_known_args(argv)
pipeline_options = PipelineOptions(pipeline_args)
pipeline_options.view_as(SetupOptions).save_main_session = save_main_session

pipeline = test_pipeline
if not test_pipeline:
pipeline = beam.Pipeline(options=pipeline_options)

model_handler = HuggingFacePipelineModelHandler(model=known_args.model_name, )
if not known_args.input:
text = (
pipeline | 'CreateSentences' >> beam.Create([
"What does Apache Beam do?;"
"Apache Beam enables batch and streaming data processing.",
"What is the capital of France?;The capital of France is Paris .",
"Where was beam summit?;Apache Beam Summit 2023 was in NYC.",
]))
else:
text = (
pipeline | 'ReadSentences' >> beam.io.ReadFromText(known_args.input))
processed_text = (
text
| 'PreProcess' >> beam.ParDo(preprocess)
| 'SquadExample' >> beam.ParDo(create_squad_example))
output = (
processed_text
| 'RunInference' >> RunInference(KeyedModelHandler(model_handler))
| 'ProcessOutput' >> beam.ParDo(PostProcessor()))
_ = output | "WriteOutput" >> beam.io.WriteToText(
known_args.output, shard_name_template='', append_trailing_newlines=True)

result = pipeline.run()
result.wait_until_finish()
return result


if __name__ == '__main__':
logging.getLogger().setLevel(logging.INFO)
run()
151 changes: 145 additions & 6 deletions sdks/python/apache_beam/ml/inference/huggingface_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,13 +35,16 @@
from apache_beam.ml.inference.base import PredictionResult
from apache_beam.ml.inference.pytorch_inference import _convert_to_device
from transformers import AutoModel
from transformers import Pipeline
from transformers import TFAutoModel
from transformers import pipeline

_LOGGER = logging.getLogger(__name__)

__all__ = [
"HuggingFaceModelHandlerTensor",
"HuggingFaceModelHandlerKeyedTensor",
"HuggingFacePipelineModelHandler",
]

TensorInferenceFn = Callable[[
Expand All @@ -59,10 +62,13 @@
Union[AutoModel, TFAutoModel],
str,
Optional[Dict[str, Any]],
Optional[str],
Optional[str]
],
Iterable[PredictionResult],
]
Iterable[PredictionResult]]

PipelineInferenceFn = Callable[
[Sequence[str], Pipeline, Optional[Dict[str, Any]]],
Iterable[PredictionResult]]


def _validate_constructor_args(model_uri, model_class):
Expand Down Expand Up @@ -109,6 +115,13 @@ def is_gpu_available_tensorflow(device):
return True


def _validate_constructor_args_hf_pipeline(task, model):
if not task and not model:
raise RuntimeError(
'Please provide both task and model to HuggingFacePipelineModelHandler.'
'If the model already defines the task, no need to specify the task.')
riteshghorse marked this conversation as resolved.
Show resolved Hide resolved


def _run_inference_torch_keyed_tensor(
batch: Sequence[Dict[str, torch.Tensor]],
model: AutoModel,
Expand Down Expand Up @@ -447,7 +460,7 @@ def run_inference(
else:
self._framework = "pt"

if (self._framework == 'pt' and self._device == "GPU" and
if (self._framework == "pt" and self._device == "GPU" and
is_gpu_available_torch()):
model.to(torch.device("cuda"))

Expand All @@ -462,6 +475,9 @@ def run_inference(
return _default_inference_fn_torch(
batch, model, self._device, inference_args, self._model_uri)

def update_model_path(self, model_path: Optional[str] = None):
self._model_uri = model_path if model_path else self._model_uri

def get_num_bytes(
self, batch: Sequence[Union[tf.Tensor, torch.Tensor]]) -> int:
"""
Expand All @@ -483,6 +499,129 @@ def share_model_across_processes(self) -> bool:
def get_metrics_namespace(self) -> str:
"""
Returns:
A namespace for metrics collected by the RunInference transform.
A namespace for metrics collected by the RunInference transform.
"""
return 'BeamML_HuggingFaceModelHandler_Tensor'


def _convert_to_result(
batch: Iterable,
predictions: Union[Iterable, Dict[Any, Iterable]],
model_id: Optional[str] = None,
) -> Iterable[PredictionResult]:
return [
PredictionResult(x, y, model_id) for x, y in zip(batch, [predictions])
]


def _default_pipeline_inference_fn(
batch, model, inference_args) -> Iterable[PredictionResult]:
riteshghorse marked this conversation as resolved.
Show resolved Hide resolved
predicitons = model(batch, **inference_args)
return predicitons


class HuggingFacePipelineModelHandler(ModelHandler[str,
riteshghorse marked this conversation as resolved.
Show resolved Hide resolved
PredictionResult,
Pipeline]):
def __init__(
self,
task: str = "",
model=None,
*,
inference_fn: PipelineInferenceFn = _default_pipeline_inference_fn,
load_model_args: Optional[Dict[str, Any]] = None,
riteshghorse marked this conversation as resolved.
Show resolved Hide resolved
inference_args: Optional[Dict[str, Any]] = None,
min_batch_size: Optional[int] = None,
max_batch_size: Optional[int] = None,
large_model: bool = False,
**kwargs):
"""
Implementation of the ModelHandler interface for Hugging Face Pipelines.

**Note:** To specify which device to use (CPU/GPU),
use the load_model_args with key-value as you would do in the usual
Hugging Face pipeline. Ex: load_model_args={'device':0})

Example Usage model::
pcoll | RunInference(HuggingFacePipelineModelHandler(
task="fill-mask"))

Args:
task (str): task supported by HuggingFace Pipelines.
riteshghorse marked this conversation as resolved.
Show resolved Hide resolved
model : path to pretrained model on Hugging Face Models Hub to use custom
model for the chosen task. If the model already defines the task then
no need to specify the task parameter.
inference_fn: the inference function to use during RunInference.
Default is _default_pipeline_inference_fn.
load_model_args (Dict[str, Any]): keyword arguments to provide load
options while loading models from Hugging Face Hub. Defaults to None.
inference_args (Dict[str, Any]): Non-batchable arguments
required as inputs to the model's inference function.
Defaults to None.
min_batch_size: the minimum batch size to use when batching inputs.
max_batch_size: the maximum batch size to use when batching inputs.
large_model: set to true if your model is large enough to run into
memory pressure if you load multiple copies. Given a model that
consumes N memory and a machine with W cores and M memory, you should
set this to True if N*W > M.
kwargs: 'env_vars' can be used to set environment variables
before loading the model.

**Supported Versions:** HuggingFacePipelineModelHandler supports
transformers>=4.18.0.
"""
return "BeamML_HuggingFaceModelHandler_Tensor"
self._task = task
self._model = model
self._inference_fn = inference_fn
self._load_model_args = load_model_args if load_model_args else {}
self._inference_args = inference_args if inference_args else {}
self._batching_kwargs = {}
self._framework = "torch"
self._env_vars = kwargs.get('env_vars', {})
if min_batch_size is not None:
self._batching_kwargs['min_batch_size'] = min_batch_size
if max_batch_size is not None:
self._batching_kwargs['max_batch_size'] = max_batch_size
self._large_model = large_model
_validate_constructor_args_hf_pipeline(self._task, self._model)

def load_model(self):
return pipeline(task=self._task, model=self._model, **self._load_model_args)
riteshghorse marked this conversation as resolved.
Show resolved Hide resolved

def run_inference(
self,
batch: Sequence[str],
model: Pipeline,
inference_args: Optional[Dict[str, Any]] = None
) -> Iterable[PredictionResult]:
"""
Runs inferences on a batch of examples passed as a string resource.
These can either be string sentences, or string path to images or
audio files.

Args:
batch: A sequence of strings resources.
model: A Hugging Face Pipeline.
inference_args: Non-batchable arguments required as inputs to the model's
inference function.
Returns:
An Iterable of type PredictionResult.
"""
inference_args = {} if not inference_args else inference_args
predictions = self._inference_fn(batch, model, inference_args)
return _convert_to_result(batch, predictions)

def update_model_path(self, model_path: Optional[str] = None):
self._model = model_path if model_path else self._model
riteshghorse marked this conversation as resolved.
Show resolved Hide resolved

def get_num_bytes(self, batch: Sequence[str]) -> int:
return sum(sys.getsizeof(element) for element in batch)

def batch_elements_kwargs(self):
return self._batching_kwargs

def share_model_across_processes(self) -> bool:
return self._large_model

def get_metrics_namespace(self) -> str:
return 'BeamML_HuggingFacePipelineModelHandler'
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@

try:
from apache_beam.examples.inference import huggingface_language_modeling
from apache_beam.examples.inference import huggingface_question_answering
from apache_beam.ml.inference import pytorch_inference_it_test
except ImportError:
raise unittest.SkipTest(
Expand Down Expand Up @@ -74,6 +75,37 @@ def test_hf_language_modeling(self):
predicted_predicted_text = predictions_dict[text]
self.assertEqual(actual_predicted_text, predicted_predicted_text)

def test_hf_pipeline(self):
test_pipeline = TestPipeline(is_integration_test=True)
# Path to text file containing some questions and context
input_file = 'gs://apache-beam-ml/datasets/custom/questions.txt'
output_file_dir = 'gs://apache-beam-ml/hf/testing/predictions'
riteshghorse marked this conversation as resolved.
Show resolved Hide resolved
output_file = '/'.join([output_file_dir, str(uuid.uuid4()), 'result.txt'])
extra_opts = {
'input': input_file,
'output': output_file,
}
huggingface_question_answering.run(
test_pipeline.get_full_options_as_args(**extra_opts),
save_main_session=False)
self.assertEqual(FileSystems().exists(output_file), True)
predictions = pytorch_inference_it_test.process_outputs(
filepath=output_file)
actuals_file = (
'gs://apache-beam-ml/testing/expected_outputs/'
'test_hf_pipeline_answers.txt')
actuals = pytorch_inference_it_test.process_outputs(filepath=actuals_file)

predictions_dict = {}
for prediction in predictions:
text, predicted_text = prediction.split(';')
predictions_dict[text] = predicted_text.strip()

for actual in actuals:
text, actual_predicted_text = actual.split(';')
predicted_predicted_text = predictions_dict[text]
self.assertEqual(actual_predicted_text, predicted_predicted_text)
riteshghorse marked this conversation as resolved.
Show resolved Hide resolved


if __name__ == '__main__':
logging.getLogger().setLevel(logging.DEBUG)
Expand Down
Loading