Incorrect device specified when using HF transformer pipeline object #3160
@vblagoje I'm not sure if this is actually a bug in the Transformers library, since they just added support for
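@vblagoje One option would be to update this line (haystack/haystack/modeling/utils.py, line 88 at commit e1f3992) to `torch.device('cuda:0')`. I think this would be the cleanest option, and I don't think it should cause any problems. Alternatively, the perhaps safer option would be to edit the device we pass to the `pipeline` call:

```python
import torch
from transformers import pipeline

# Inside the node's __init__: a bare torch.device('cuda') carries no index,
# so pin it to the first GPU before handing it to the HF pipeline.
device = self.devices[0]
if device == torch.device("cuda"):
    device = torch.device("cuda:0")
self.model = pipeline(
    "ner",
    model=token_classifier,
    tokenizer=tokenizer,
    aggregation_strategy="simple",
    device=device,  # pass the normalized device, not self.devices[0]
    use_auth_token=use_auth_token,
)
```

What do you think?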
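A minimal sketch of the "replace every bare `cuda` device" idea discussed in this thread, assuming the node keeps its devices in a list like `self.devices`. The helper name `normalize_devices` is hypothetical and not part of Haystack, and it assumes GPU index 0 is the intended default.

```python
from typing import List

import torch


def normalize_devices(devices: List[torch.device]) -> List[torch.device]:
    """Hypothetical helper: give every index-less CUDA device an explicit index.

    torch.device("cuda") has index None, which per this issue seems to trigger
    the error in the HF pipeline; mapping it to cuda:0 mirrors the proposed fix.
    """
    normalized = []
    for device in devices:
        if device.type == "cuda" and device.index is None:
            # Assumption: the first GPU is the intended default.
            normalized.append(torch.device("cuda", 0))
        else:
            # CPU devices and already-indexed CUDA devices pass through unchanged.
            normalized.append(device)
    return normalized


if __name__ == "__main__":
    # Constructing torch.device objects is pure metadata and needs no GPU.
    print(normalize_devices([torch.device("cuda"), torch.device("cpu")]))
```

Normalizing once, where the device list is built, means every node that later reads `self.devices[0]` receives an indexed device without needing its own special case.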
What if we handle this case by replacing all instances of
Ah yes, sorry I wasn't clear. That is what my first suggestion above was meant to say.
Leaving open since we aren't sure if we should also handle the case where the user passes
@sjrl I'd vote to make an easy fix and replace any instance of
@vblagoje Yeah, that sounds good to me, just so we can stay on the safe side.
Describe the bug
After the integration of PR #3062, we get an error when running Haystack nodes that use Hugging Face's `pipeline` object.
For example, the code
results in the error
The node works as expected if you initialize the node using the following code
Adding the index `:0` fixes the torch error. Without the `:0` suffix, the device is automatically determined to be `torch.device('cuda')`, which seems to cause the error at run time.
Expected behavior
For the device to be correctly provided.
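The difference between the two devices mentioned above can be checked without a GPU. A small sketch, assuming only that PyTorch is installed: a bare `torch.device('cuda')` has no index, so it does not compare equal to `torch.device('cuda:0')`, which is consistent with the need to add `:0` explicitly.

```python
import torch

# Constructing device objects is pure metadata: no GPU is required.
bare = torch.device("cuda")
indexed = torch.device("cuda:0")

# A bare CUDA device has no index, so it is not equal to cuda:0
# even though both refer to the same device type.
print(bare.index)                 # None
print(indexed.index)              # 0
print(bare == indexed)            # False
print(bare.type == indexed.type)  # True
```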
Additional context
`pipeline` object from HF. For example, Tutorial 14 also fails with this error when you reach the cell that runs the `transformer_keyword_classifier`.
To Reproduce
Run Tutorial 14 or the provided code snippet in an environment with a GPU.
FAQ Check
System: