
feat: Add option to use MultipleNegativesRankingLoss for EmbeddingRetriever training with sentence-transformers #3164

Merged · 4 commits · Sep 12, 2022
6 changes: 5 additions & 1 deletion docs/_src/api/api/retriever.md
@@ -1678,7 +1678,8 @@ def train(training_data: List[Dict[str, Any]],
learning_rate: float = 2e-5,
n_epochs: int = 1,
num_warmup_steps: int = None,
batch_size: int = 16) -> None
batch_size: int = 16,
train_loss: str = "mnrl") -> None
```

Trains/adapts the underlying embedding model.
@@ -1697,6 +1698,9 @@ Each training data example is a dictionary with the following keys:
- `n_epochs` (`int`): The number of epochs
- `num_warmup_steps` (`int`): The number of warmup steps
- `batch_size` (`int (optional)`): The batch size to use for the training, defaults to 16
- `train_loss` (`str (optional)`): The loss to use for training.
If you're using a sentence-transformers model as `embedding_model` (currently the only models that support training),
possible values are 'mnrl' (Multiple Negatives Ranking Loss) or 'margin_mse' (MarginMSE).

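As a rough usage sketch (not part of this PR; the document store setup and model name are assumptions), the new parameter would be used like this — `'mnrl'` only needs question/positive-document pairs, while `'margin_mse'` keeps the previous requirement of a negative document plus a score:

```python
from haystack.document_stores import InMemoryDocumentStore
from haystack.nodes import EmbeddingRetriever

# Assumed setup: any document store works, InMemoryDocumentStore is just the simplest to spin up
document_store = InMemoryDocumentStore()
retriever = EmbeddingRetriever(
    document_store=document_store,
    embedding_model="sentence-transformers/multi-qa-mpnet-base-dot-v1",
)

# 'mnrl' (the new default): question/positive-document pairs are enough
mnrl_data = [{"question": "What is Haystack?", "pos_doc": "Haystack is an open-source NLP framework."}]
retriever.train(training_data=mnrl_data, train_loss="mnrl", n_epochs=1, batch_size=16)

# 'margin_mse' (the previous behavior): additionally needs a negative document and a score
margin_mse_data = [
    {
        "question": "What is Haystack?",
        "pos_doc": "Haystack is an open-source NLP framework.",
        "neg_doc": "Paris is the capital of France.",
        "score": 0.87,
    }
]
retriever.train(training_data=margin_mse_data, train_loss="margin_mse")
```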
<a id="dense.EmbeddingRetriever.save"></a>

33 changes: 27 additions & 6 deletions haystack/nodes/retriever/_embedding_encoder.py
@@ -5,7 +5,7 @@

import numpy as np
import torch
from sentence_transformers import InputExample, losses
from sentence_transformers import InputExample
from torch.utils.data import DataLoader
from torch.utils.data.sampler import SequentialSampler
from tqdm.auto import tqdm
@@ -14,6 +14,7 @@
from haystack.modeling.data_handler.dataloader import NamedDataLoader
from haystack.modeling.data_handler.dataset import convert_features_to_dataset, flatten_rename
from haystack.modeling.infer import Inferencer
from haystack.nodes.retriever._losses import _TRAINING_LOSSES
from haystack.schema import Document

if TYPE_CHECKING:
@@ -195,14 +196,34 @@ def train(
n_epochs: int = 1,
num_warmup_steps: int = None,
batch_size: int = 16,
train_loss: str = "mnrl",
):

train_examples = [
InputExample(texts=[i["question"], i["pos_doc"], i["neg_doc"]], label=i["score"]) for i in training_data
]
logger.info(f"GPL training/adapting {self.embedding_model} with {len(train_examples)} examples")
if train_loss not in _TRAINING_LOSSES:
raise ValueError(f"Unrecognized train_loss {train_loss}. Should be one of: {_TRAINING_LOSSES.keys()}")

st_loss = _TRAINING_LOSSES[train_loss]

train_examples = []
for train_i in training_data:
missing_attrs = st_loss.required_attrs.difference(set(train_i.keys()))
if len(missing_attrs) > 0:
raise ValueError(
f"Some training examples don't contain the fields {missing_attrs} which are necessary when using the '{train_loss}' loss."
)

texts = [train_i["question"], train_i["pos_doc"]]
if "neg_doc" in train_i:
texts.append(train_i["neg_doc"])

if "score" in train_i:
train_examples.append(InputExample(texts=texts, label=train_i["score"]))
else:
train_examples.append(InputExample(texts=texts))

logger.info(f"Training/adapting {self.embedding_model} with {len(train_examples)} examples")
train_dataloader = DataLoader(train_examples, batch_size=batch_size, drop_last=True, shuffle=True)
train_loss = losses.MarginMSELoss(self.embedding_model)
train_loss = st_loss.loss(self.embedding_model)

# Tune the model
self.embedding_model.fit(
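For context, a minimal standalone sentence-transformers sketch (not Haystack code; the model name is an assumption) of what the `fit` call above boils down to in the `'mnrl'` case — every other pair in the batch acts as a negative, which is why the `DataLoader` shuffles and drops incomplete batches:

```python
from sentence_transformers import SentenceTransformer, InputExample, losses
from torch.utils.data import DataLoader

model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")

examples = [
    InputExample(texts=["What is Haystack?", "Haystack is an open-source NLP framework."]),
    InputExample(texts=["Where is Paris?", "Paris is the capital of France."]),
]

# shuffle + drop_last matter for MNRL: the other in-batch positives serve as negatives
loader = DataLoader(examples, batch_size=2, shuffle=True, drop_last=True)
loss = losses.MultipleNegativesRankingLoss(model)

model.fit(train_objectives=[(loader, loss)], epochs=1, warmup_steps=10)
```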
12 changes: 12 additions & 0 deletions haystack/nodes/retriever/_losses.py
@@ -0,0 +1,12 @@
from collections import namedtuple
from typing import Dict

from sentence_transformers import losses


SentenceTransformerLoss = namedtuple("SentenceTransformerLoss", "loss required_attrs")

_TRAINING_LOSSES: Dict[str, SentenceTransformerLoss] = {
"mnrl": SentenceTransformerLoss(losses.MultipleNegativesRankingLoss, {"question", "pos_doc"}),
"margin_mse": SentenceTransformerLoss(losses.MarginMSELoss, {"question", "pos_doc", "neg_doc", "score"}),
}
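Because `_TRAINING_LOSSES` pairs each loss class with the training-data fields it needs, additional sentence-transformers losses could be registered the same way. A purely hypothetical sketch, not part of this PR:

```python
# Hypothetical: CosineSimilarityLoss also trains on (question, pos_doc) pairs with a float score label
_TRAINING_LOSSES["cosine"] = SentenceTransformerLoss(
    losses.CosineSimilarityLoss, {"question", "pos_doc", "score"}
)
```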
6 changes: 6 additions & 0 deletions haystack/nodes/retriever/dense.py
@@ -1863,6 +1863,7 @@ def train(
n_epochs: int = 1,
num_warmup_steps: int = None,
batch_size: int = 16,
train_loss: str = "mnrl",
) -> None:
"""
Trains/adapts the underlying embedding model.
@@ -1885,13 +1886,18 @@ def train(
:type num_warmup_steps: int
:param batch_size: The batch size to use for the training, defaults to 16
:type batch_size: int (optional)
:param train_loss: The loss to use for training.
If you're using a sentence-transformers model as embedding_model (currently the only models that support training),
possible values are 'mnrl' (Multiple Negatives Ranking Loss) or 'margin_mse' (MarginMSE).
:type train_loss: str (optional)
"""
self.embedding_encoder.train(
training_data,
learning_rate=learning_rate,
n_epochs=n_epochs,
num_warmup_steps=num_warmup_steps,
batch_size=batch_size,
train_loss=train_loss,
)

def save(self, save_dir: Union[Path, str]) -> None: