Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Reduce GPU to CPU copies at inference #3127

Merged
merged 5 commits into from
Sep 7, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 14 additions & 11 deletions haystack/modeling/model/prediction_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -495,6 +495,8 @@ def logits_to_preds(
sorted_candidates = torch.cat((start_indices, end_indices), dim=2)

# Get the n_best candidate answers for each sample
sorted_candidates = sorted_candidates.cpu().numpy()
start_end_matrix = start_end_matrix.cpu().numpy()
for sample_idx in range(batch_size):
sample_top_n = self.get_top_candidates(
sorted_candidates[sample_idx],
Expand All @@ -519,24 +521,21 @@ def get_top_candidates(self, sorted_candidates, start_end_matrix, sample_idx: in
start_idx_candidates = set()
end_idx_candidates = set()

start_matrix_softmax_start = torch.softmax(start_matrix[:, 0], dim=-1)
end_matrix_softmax_end = torch.softmax(end_matrix[0, :], dim=-1)
start_matrix_softmax_start = torch.softmax(start_matrix[:, 0], dim=-1).cpu().numpy()
end_matrix_softmax_end = torch.softmax(end_matrix[0, :], dim=-1).cpu().numpy()
# Iterate over all candidates and break when we have all our n_best candidates
for candidate_idx in range(n_candidates):
if len(top_candidates) == self.n_best_per_sample:
break

# Retrieve candidate's indices
start_idx = sorted_candidates[candidate_idx, 0].item()
end_idx = sorted_candidates[candidate_idx, 1].item()
start_idx = sorted_candidates[candidate_idx, 0]
end_idx = sorted_candidates[candidate_idx, 1]
# Ignore no_answer scores which will be extracted later in this method
if start_idx == 0 and end_idx == 0:
continue
if self.duplicate_filtering > -1 and (start_idx in start_idx_candidates or end_idx in end_idx_candidates):
continue
score = start_end_matrix[start_idx, end_idx].item()
score = start_end_matrix[start_idx, end_idx]
confidence = (
(start_matrix_softmax_start[start_idx].item() + end_matrix_softmax_end[end_idx].item()) / 2
(start_matrix_softmax_start[start_idx] + end_matrix_softmax_end[end_idx]) / 2
if score > -500
else np.exp(score / 10) # disqualify answers according to scores in logits_to_preds()
)
Expand All @@ -559,8 +558,12 @@ def get_top_candidates(self, sorted_candidates, start_end_matrix, sample_idx: in
end_idx_candidates.add(end_idx + i)
end_idx_candidates.add(end_idx - i)

no_answer_score = start_end_matrix[0, 0].item()
no_answer_confidence = (start_matrix_softmax_start[0].item() + end_matrix_softmax_end[0].item()) / 2
# Only check if we have enough candidates after adding new candidate to the list
if len(top_candidates) == self.n_best_per_sample:
break

no_answer_score = start_end_matrix[0, 0]
no_answer_confidence = (start_matrix_softmax_start[0] + end_matrix_softmax_end[0]) / 2
top_candidates.append(
QACandidate(
offset_answer_start=0,
Expand Down