refactor: replace torch.no_grad with torch.inference_mode (where possible) #3601

Merged · 5 commits · Nov 23, 2022
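Note on the change: torch.inference_mode() is a drop-in replacement for torch.no_grad() at pure-inference call sites. Both disable gradient tracking, but inference_mode additionally skips autograd record-keeping (version counters, view tracking), which makes forward passes slightly cheaper. A minimal sketch of the two context managers, for illustration only (the model and tensors below are made up, not taken from this PR):

import torch

model = torch.nn.Linear(4, 2)
x = torch.randn(8, 4)

# Old pattern: disables gradient tracking only.
with torch.no_grad():
    scores_old = model(x)

# New pattern: same interface, but also skips autograd bookkeeping
# (version counters, view tracking) for a cheaper forward pass.
with torch.inference_mode():
    scores_new = model(x)

print(scores_old.requires_grad, scores_new.requires_grad)  # False False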
2 changes: 1 addition & 1 deletion haystack/document_stores/memory.py
@@ -237,7 +237,7 @@ def get_scores_torch(self, query_emb: np.ndarray, document_to_search: List[Docum
while curr_pos < len(doc_embeds):
doc_embeds_slice = doc_embeds[curr_pos : curr_pos + self.scoring_batch_size]
doc_embeds_slice = doc_embeds_slice.to(self.main_device)
- with torch.no_grad():
+ with torch.inference_mode():
slice_scores = torch.matmul(doc_embeds_slice, query_emb.T).cpu()
slice_scores = slice_scores.squeeze(dim=1)
slice_scores = slice_scores.numpy().tolist()
2 changes: 1 addition & 1 deletion haystack/modeling/data_handler/data_silo.py
@@ -771,7 +771,7 @@ def _pass_batches(
teacher_outputs: List[List[Tuple[torch.Tensor, ...]]],
tensor_names: List[str],
):
- with torch.no_grad():
+ with torch.inference_mode():
batch_transposed = zip(*batch) # transpose dimensions (from batch, features, ... to features, batch, ...)
batch_transposed_list = [torch.stack(b) for b in batch_transposed] # create tensors for each feature
batch_dict = {
2 changes: 1 addition & 1 deletion haystack/modeling/evaluation/eval.py
@@ -77,7 +77,7 @@ def eval(
else:
module = model

- with torch.no_grad():
+ with torch.inference_mode():
if isinstance(module, AdaptiveModel):
logits = model.forward(
input_ids=batch.get("input_ids", None),
4 changes: 2 additions & 2 deletions haystack/modeling/infer.py
@@ -365,7 +365,7 @@ def _get_predictions(self, dataset: Dataset, tensor_names: List, baskets):
batch_samples = samples[i * self.batch_size : (i + 1) * self.batch_size]

# get logits
- with torch.no_grad():
+ with torch.inference_mode():
logits = self.model.forward(**batch)
preds = self.model.formatted_preds(
logits=logits, samples=batch_samples, padding_mask=batch.get("padding_mask", None)
@@ -402,7 +402,7 @@ def _get_predictions_and_aggregate(self, dataset: Dataset, tensor_names: List, b
batch = {key: batch[key].to(self.devices[0]) for key in batch}

# get logits
- with torch.no_grad():
+ with torch.inference_mode():
# Aggregation works on preds, not logits. We want as much processing happening in one batch + on GPU
# So we transform logits to preds here as well
logits = self.model.forward(
2 changes: 1 addition & 1 deletion haystack/modeling/model/adaptive_model.py
@@ -778,7 +778,7 @@ def forward(self, **kwargs):
:param kwargs: All arguments that need to be passed on to the model.
:return: All logits as torch.tensor or multiple tensors.
"""
- with torch.no_grad():
+ with torch.inference_mode():
if self.language_model_class == "Bert":
input_to_onnx = {
"input_ids": numpy.ascontiguousarray(kwargs["input_ids"].cpu().numpy()),
4 changes: 2 additions & 2 deletions haystack/nodes/ranker/sentence_transformers.py
@@ -129,7 +129,7 @@ def predict(self, query: str, documents: List[Document], top_k: Optional[int] =
# 1. the logit as similarity score/answerable classification
# 2. the logits as answerable classification (no_answer / has_answer)
# https://www.sbert.net/docs/pretrained-models/ce-msmarco.html#usage-with-transformers
- with torch.no_grad():
+ with torch.inference_mode():
similarity_scores = self.transformer_model(**features).logits

logits_dim = similarity_scores.shape[1] # [batch_size, logits_dim]
@@ -216,7 +216,7 @@ def predict_batch(
cur_queries, [doc.content for doc in cur_docs], padding=True, truncation=True, return_tensors="pt"
).to(self.devices[0])

- with torch.no_grad():
+ with torch.inference_mode():
similarity_scores = self.transformer_model(**features).logits
preds.extend(similarity_scores)
pb.update(len(cur_docs))
4 changes: 2 additions & 2 deletions haystack/nodes/retriever/_embedding_encoder.py
@@ -304,7 +304,7 @@ def embed_queries(self, queries: List[str]) -> np.ndarray:

for i, batch in enumerate(tqdm(dataloader, desc=f"Creating Embeddings", unit=" Batches", disable=disable_tqdm)):
batch = {key: batch[key].to(self.embedding_model.device) for key in batch}
- with torch.no_grad():
+ with torch.inference_mode():
q_reps = (
self.embedding_model.embed_questions(
input_ids=batch["input_ids"], attention_mask=batch["padding_mask"]
@@ -331,7 +331,7 @@ def embed_documents(self, docs: List[Document]) -> np.ndarray:

for i, batch in enumerate(tqdm(dataloader, desc=f"Creating Embeddings", unit=" Batches", disable=disable_tqdm)):
batch = {key: batch[key].to(self.embedding_model.device) for key in batch}
- with torch.no_grad():
+ with torch.inference_mode():
q_reps = (
self.embedding_model.embed_answers(
input_ids=batch["input_ids"], attention_mask=batch["padding_mask"]
4 changes: 2 additions & 2 deletions haystack/nodes/retriever/dense.py
@@ -528,7 +528,7 @@ def _get_predictions(self, dicts: List[Dict[str, Any]]) -> Dict[str, np.ndarray]
batch = {key: raw_batch[key].to(self.devices[0]) for key in raw_batch}

# get logits
- with torch.no_grad():
+ with torch.inference_mode():
query_embeddings, passage_embeddings = self.model.forward(
query_input_ids=batch.get("query_input_ids", None),
query_segment_ids=batch.get("query_segment_ids", None),
@@ -1171,7 +1171,7 @@ def _get_predictions(self, dicts: List[Dict[str, Any]]) -> Dict[str, np.ndarray]
batch = {key: batch[key].to(self.devices[0]) for key in batch}

# get logits
- with torch.no_grad():
+ with torch.inference_mode():
query_embeddings, passage_embeddings = self.model.forward(**batch)[0]
if query_embeddings is not None:
query_embeddings_batched.append(query_embeddings.cpu().numpy())
6 changes: 3 additions & 3 deletions haystack/utils/augment_squad.py
@@ -77,7 +77,7 @@ def load_glove(
id_word_mapping[i] = split[0]
vector_list.append(torch.tensor([float(x) for x in split[1:]]))
vectors = torch.stack(vector_list)
- with torch.no_grad():
+ with torch.inference_mode():
vectors = vectors.to(device)
vectors = F.normalize(vectors, dim=1)
return word_id_mapping, id_word_mapping, vectors
@@ -132,7 +132,7 @@ def get_replacements(
inputs.append((input_ids_, subword_index))

# doing batched forward pass
- with torch.no_grad():
+ with torch.inference_mode():
prediction_list = []
while len(inputs) != 0:
batch_list, token_indices = tuple(zip(*inputs[:batch_size]))
@@ -165,7 +165,7 @@ def get_replacements(
elif word in glove_word_id_mapping: # word was split into subwords so we use glove instead
word_id = glove_word_id_mapping[word]
glove_vector = glove_vectors[word_id]
- with torch.no_grad():
+ with torch.inference_mode():
word_similarities = torch.mm(glove_vectors, glove_vector.unsqueeze(1)).squeeze(1)
ranking = torch.argsort(word_similarities, descending=True)[: word_possibilities + 1]
possible_words.append([glove_id_word_mapping[int(id_)] for id_ in ranking.cpu()])
6 changes: 3 additions & 3 deletions test/modeling/test_dpr.py
@@ -829,7 +829,7 @@ def test_dpr_processor_save_load_non_bert_tokenizer(tmp_path: Path, query_and_pa
batch = {key: batch[key].to(device) for key in batch}

# get logits
- with torch.no_grad():
+ with torch.inference_mode():
query_embeddings, passage_embeddings = model.forward(
query_input_ids=batch.get("query_input_ids", None),
query_segment_ids=batch.get("query_segment_ids", None),
@@ -863,7 +863,7 @@ def test_dpr_processor_save_load_non_bert_tokenizer(tmp_path: Path, query_and_pa
batch = {key: batch[key].to(device) for key in batch}

# get logits
- with torch.no_grad():
+ with torch.inference_mode():
query_embeddings, passage_embeddings = loaded_model.forward(
query_input_ids=batch.get("query_input_ids", None),
query_segment_ids=batch.get("query_segment_ids", None),
@@ -952,7 +952,7 @@ def test_dpr_processor_save_load_non_bert_tokenizer(tmp_path: Path, query_and_pa
batch = {key: batch[key].to(device) for key in batch}

# get logits
- with torch.no_grad():
+ with torch.inference_mode():
query_embeddings, passage_embeddings = loaded_model.forward(
query_input_ids=batch.get("query_input_ids", None),
query_segment_ids=batch.get("query_segment_ids", None),
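Why "(where possible)": tensors created under inference_mode stay flagged as inference tensors. They cannot later be updated in place outside the context or pulled back into autograd, so call sites whose outputs feed a later backward pass or in-place mutation have to keep torch.no_grad(). A small illustrative sketch of that limitation (not code from this PR):

import torch

with torch.inference_mode():
    emb = torch.randn(2, 3)  # 'emb' is an inference tensor

print(emb.is_inference())  # True

# Re-entering autograd with an inference tensor is not allowed.
try:
    emb.requires_grad_(True)
except RuntimeError as err:
    print("requires_grad_:", err)

# In-place updates outside inference_mode are not allowed either.
try:
    emb.add_(1.0)
except RuntimeError as err:
    print("add_:", err)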