Incorrect device specified when using HF transformer pipeline object #3160

Closed
1 task done
sjrl opened this issue Sep 5, 2022 · 7 comments · Fixed by #3161 or #3184
Labels
topic:pipeline type:bug Something isn't working

Comments

sjrl commented Sep 5, 2022

Describe the bug
After the integration of PR #3062 we get an error when running Haystack nodes that use HuggingFace's pipeline object.

For example, the code

from haystack.nodes import EntityExtractor
ner_node = EntityExtractor()
ner_node.extract("This is a test. I live in Berlin")

results in the error

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Input In [11], in <cell line: 1>()
----> 1 ner_node.extract("This is a test. I live in Berlin")

File ~/code/haystack/haystack/nodes/extractor/entity.py:179, in EntityExtractor.extract(self, text)
    175 def extract(self, text):
    176     """
    177     This function can be called to perform entity extraction when using the node in isolation.
    178     """
--> 179     entities = self.extractor_pipeline(text)
    180     return entities

File ~/miniconda3/envs/haystack_src/lib/python3.8/site-packages/transformers/pipelines/token_classification.py:191, in TokenClassificationPipeline.__call__(self, inputs, **kwargs)
    188 if offset_mapping:
    189     kwargs["offset_mapping"] = offset_mapping
--> 191 return super().__call__(inputs, **kwargs)

File ~/miniconda3/envs/haystack_src/lib/python3.8/site-packages/transformers/pipelines/base.py:1067, in Pipeline.__call__(self, inputs, num_workers, batch_size, *args, **kwargs)
   1065     return self.iterate(inputs, preprocess_params, forward_params, postprocess_params)
   1066 else:
-> 1067     return self.run_single(inputs, preprocess_params, forward_params, postprocess_params)

File ~/miniconda3/envs/haystack_src/lib/python3.8/site-packages/transformers/pipelines/base.py:1074, in Pipeline.run_single(self, inputs, preprocess_params, forward_params, postprocess_params)
   1072 def run_single(self, inputs, preprocess_params, forward_params, postprocess_params):
   1073     model_inputs = self.preprocess(inputs, **preprocess_params)
-> 1074     model_outputs = self.forward(model_inputs, **forward_params)
   1075     outputs = self.postprocess(model_outputs, **postprocess_params)
   1076     return outputs

File ~/miniconda3/envs/haystack_src/lib/python3.8/site-packages/transformers/pipelines/base.py:975, in Pipeline.forward(self, model_inputs, **forward_params)
    974 def forward(self, model_inputs, **forward_params):
--> 975     with self.device_placement():
    976         if self.framework == "tf":
    977             model_inputs["training"] = False

File ~/miniconda3/envs/haystack_src/lib/python3.8/contextlib.py:113, in _GeneratorContextManager.__enter__(self)
    111 del self.args, self.kwds, self.func
    112 try:
--> 113     return next(self.gen)
    114 except StopIteration:
    115     raise RuntimeError("generator didn't yield") from None

File ~/miniconda3/envs/haystack_src/lib/python3.8/site-packages/transformers/pipelines/base.py:864, in Pipeline.device_placement(self)
    862 else:
    863     if self.device.type == "cuda":
--> 864         torch.cuda.set_device(self.device)
    866     yield

File ~/miniconda3/envs/haystack_src/lib/python3.8/site-packages/torch/cuda/__init__.py:312, in set_device(device)
    302 def set_device(device: _device_t) -> None:
    303     r"""Sets the current device.
    304 
    305     Usage of this function is discouraged in favor of :any:`device`. In most
   (...)
    310             if this argument is negative.
    311     """
--> 312     device = _get_device_index(device)
    313     if device >= 0:
    314         torch._C._cuda_setDevice(device)

File ~/miniconda3/envs/haystack_src/lib/python3.8/site-packages/torch/cuda/_utils.py:34, in _get_device_index(device, optional, allow_cpu)
     32     if isinstance(device, torch.cuda.device):
     33         return device.idx
---> 34 return _torch_get_device_index(device, optional, allow_cpu)

File ~/miniconda3/envs/haystack_src/lib/python3.8/site-packages/torch/_utils.py:540, in _get_device_index(device, optional, allow_cpu)
    538             device_idx = _get_current_device_index()
    539     else:
--> 540         raise ValueError('Expected a torch.device with a specified index '
    541                          'or an integer, but got:{}'.format(device))
    542 return device_idx

ValueError: Expected a torch.device with a specified index or an integer, but got:cuda

The node works as expected if you initialize the node using the following code

import torch
ner_node = EntityExtractor(
    devices=[torch.device('cuda:0')]
)

Adding the index :0 fixes the torch error. Without it, the device is auto-determined as torch.device('cuda'), which causes the error at run time.
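The normalization can be sketched without torch, using plain device strings; the helper name below is illustrative, not an actual Haystack function:

```python
def normalize_device(device: str) -> str:
    """Append a default index to a bare 'cuda' device string.

    transformers' Pipeline.device_placement() calls
    torch.cuda.set_device(), which requires an explicit index,
    so 'cuda' must become 'cuda:0' before reaching the pipeline.
    """
    if device == "cuda":
        return "cuda:0"
    return device
```

With this, `normalize_device("cuda")` yields `"cuda:0"`, while indexed devices and `"cpu"` pass through unchanged.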

Expected behavior
For the device to be correctly provided.

Additional context

  • This error only occurs when running with a GPU available.
  • This error does not occur at initialization. It only occurs when running the node.
  • This appears to affect any node that uses the pipeline object from HF. For example, Tutorial 14 also fails with this error when you reach the cell that runs the transformer_keyword_classifier.

To Reproduce
Run tutorial 14 or the provided code snippet in an environment with a GPU.

System:

  • OS: Linux
  • GPU/CPU: GPU
  • Haystack version (commit or version number): main branch e1f3992
@sjrl sjrl added type:bug Something isn't working topic:pipeline labels Sep 5, 2022

sjrl commented Sep 5, 2022

@vblagoje I'm not sure if this is actually a bug in the Transformers library, since they only recently added support for torch.device objects. But to be on the safe side, it may be smart to add a default index (:0) whenever we pass a device to the pipeline object from the Transformers library. What do you think?


sjrl commented Sep 5, 2022

@vblagoje One option would be to update this line

devices_to_use = [torch.device("cuda")]

to

devices_to_use = [torch.device("cuda:0")]

I think this would be the cleanest option, and I don't think it should cause any problems. The perhaps safer option would be to normalize the device we pass to the pipeline object. So something like this:

      device = self.devices[0]
      if device == torch.device("cuda"):
          device = torch.device("cuda:0")
      self.model = pipeline(
          "ner",
          model=token_classifier,
          tokenizer=tokenizer,
          aggregation_strategy="simple",
          device=device,
          use_auth_token=use_auth_token,
      )

what do you think?


vblagoje commented Sep 5, 2022

What if we handle this case by replacing all instances of torch.device('cuda') with torch.device('cuda:0') in initialize_device_settings?


sjrl commented Sep 5, 2022

Ahh yes, sorry I wasn't clear. That is what my first suggestion above was meant to say.


sjrl commented Sep 6, 2022

Leaving open since we aren't sure if we should also handle the case where the user passes torch.device("cuda") to the HF pipeline function or if we should investigate if this is a bug on HF's side.

@sjrl sjrl reopened this Sep 6, 2022

vblagoje commented Sep 8, 2022

@sjrl I'd vote to make an easy fix and replace any instance of torch.device("cuda") with torch.device("cuda:0") in initialize_device_settings and be done with this issue. If HF fixes it, great - if not, we have it handled already.
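The substitution proposed here, applied to the device list that initialize_device_settings returns, can be sketched like this (torch-free strings stand in for torch.device objects, and the helper name is hypothetical):

```python
def with_default_index(devices):
    """Replace any bare 'cuda' entry with 'cuda:0' so that downstream
    torch.cuda.set_device() calls always receive an explicit index.
    All other entries (indexed GPUs, 'cpu') are left untouched."""
    return ["cuda:0" if d == "cuda" else d for d in devices]
```

Applied once at the end of initialize_device_settings, every consumer of the returned device list would then receive fully indexed devices, regardless of whether HF fixes the underlying issue.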


sjrl commented Sep 8, 2022

@vblagoje Yeah that sounds good to me, just so we can stay on the safe side.
