Adding translator with many generic input parameter support #782

Merged 22 commits on Feb 12, 2021

Commits (the file changes below are shown from 16 of the 22 commits)
- `9e0951f` Adding translator with many generic input parameter support (lalitpagaria, Jan 27, 2021)
- `8acf7ef` Making dict_key as generic (lalitpagaria, Jan 27, 2021)
- `ea3386d` Fixing mypy issue (lalitpagaria, Jan 27, 2021)
- `b8a39eb` Adding pipeline and using opus models (lalitpagaria, Feb 8, 2021)
- `184aac9` Add latest docstring and tutorial changes (github-actions[bot], Feb 8, 2021)
- `2c2e38b` Adding test cases for end-to-end translation for generator, summerize… (lalitpagaria, Feb 8, 2021)
- `2cebfd6` raise error join and merge nodes (lalitpagaria, Feb 8, 2021)
- `1d2fd4c` Fix test failure (lalitpagaria, Feb 8, 2021)
- `6eee6f1` add docstrings. add usage documentation. rm skip_special_tokens param (tholor, Feb 10, 2021)
- `07f25d7` Add latest docstring and tutorial changes (github-actions[bot], Feb 10, 2021)
- `61a243e` fix code snippets in md (tholor, Feb 10, 2021)
- `4e8140b` Merge branch 'translation' of github.com:lalitpagaria/haystack into t… (tholor, Feb 10, 2021)
- `9c7516f` Adding few extra configuration parameters and fixing tests (lalitpagaria, Feb 10, 2021)
- `6ec6699` Fixingmypy issue and updating usage document (lalitpagaria, Feb 10, 2021)
- `ad801ae` fix for mypy issue in pipeline.py (lalitpagaria, Feb 10, 2021)
- `8934c4a` reverting renaming of pytest_collection_modifyitems method (lalitpagaria, Feb 10, 2021)
- `43829d3` Addressing review comments (lalitpagaria, Feb 11, 2021)
- `3b85b8d` setting skip_special_tokens to True (lalitpagaria, Feb 11, 2021)
- `d6d6bed` removing model_max_length argument as None type is not supported to m… (lalitpagaria, Feb 11, 2021)
- `ded600d` Merge remote-tracking branch 'upstream/master' into translation (tholor, Feb 12, 2021)
- `997c35a` Merge branch 'translation' of github.com:lalitpagaria/haystack into t… (tholor, Feb 12, 2021)
- `2a77dfa` Removing padding parameter. Better to leave it as default otherwise i… (lalitpagaria, Feb 12, 2021)
30 changes: 28 additions & 2 deletions docs/_src/api/api/pipelines.md
@@ -5,7 +5,7 @@
## Pipeline Objects

```python
-class Pipeline()
+class Pipeline(ABC)
```

Pipeline brings together building blocks to build a complex search pipeline with Haystack & user-defined components.
@@ -131,7 +131,7 @@ variable 'MYDOCSTORE_PARAMS_INDEX=documents-2021' can be set. Note that an
## BaseStandardPipeline Objects

```python
-class BaseStandardPipeline()
+class BaseStandardPipeline(ABC)
```

<a name="pipeline.BaseStandardPipeline.add_node"></a>
@@ -316,6 +316,32 @@ Initialize a Pipeline for finding similar FAQs using semantic document search.

- `retriever`: Retriever instance

<a name="pipeline.TranslationWrapperPipeline"></a>
## TranslationWrapperPipeline Objects

```python
class TranslationWrapperPipeline(BaseStandardPipeline)
```

Takes an existing search pipeline and adds one "input translation" node after the Query and one
"output translation" node just before returning the results.

<a name="pipeline.TranslationWrapperPipeline.__init__"></a>
#### \_\_init\_\_

```python
| __init__(input_translator: BaseTranslator, output_translator: BaseTranslator, pipeline: BaseStandardPipeline)
```

Wrap a given `pipeline` with the `input_translator` and `output_translator`.

**Arguments**:

- `input_translator`: A Translator node that shall translate the input query from language A to B
- `output_translator`: A Translator node that shall translate the pipeline results from language B to A
- `pipeline`: The pipeline object (e.g. ExtractiveQAPipeline) you want to "wrap".
Note that pipelines with split or merge nodes are currently not supported.
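
For illustration, a minimal sketch of how this wrapper might be constructed around an `ExtractiveQAPipeline` (`my_reader` and `my_retriever` are placeholders for the reader and retriever you already use; the OPUS model names mirror the usage documentation added in this PR):

```python
from haystack.pipeline import ExtractiveQAPipeline, TranslationWrapperPipeline
from haystack.translator import TransformersTranslator

# Translate French queries to English, run extractive QA on the English corpus,
# then translate the answers back to French
in_translator = TransformersTranslator(model_name_or_path="Helsinki-NLP/opus-mt-fr-en")
out_translator = TransformersTranslator(model_name_or_path="Helsinki-NLP/opus-mt-en-fr")

qa_pipeline = ExtractiveQAPipeline(reader=my_reader, retriever=my_retriever)
qa_with_translation = TranslationWrapperPipeline(input_translator=in_translator,
                                                 output_translator=out_translator,
                                                 pipeline=qa_pipeline)
```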

<a name="pipeline.JoinDocuments"></a>
## JoinDocuments Objects

58 changes: 58 additions & 0 deletions docs/_src/usage/usage/translator.md
@@ -0,0 +1,58 @@
<!---
title: "Translator"
metaTitle: "Translator"
metaDescription: ""
slug: "/docs/translator"
date: "2021-02-10"
id: "translatormd"
--->

# Translator

Texts come in many different languages, and search is no exception: there are plenty of options for dealing with this.
One of them is to translate the incoming query, the documents, or the search results.

Let's imagine you have an English corpus of technical docs, but the mother tongue of many of your users is French.
You can use a Translator node in your pipeline to
1. Translate the incoming query from French to English
2. Search in your English corpus for the right document / answer
3. Translate the results back from English to French

<div class="recommendation">

**Example (Stand-alone Translator)**

You can use the Translator component directly to translate your query or document(s):
```python
DOCS = [
Document(
text="""Heinz von Foerster was an Austrian American scientist
combining physics and philosophy, and widely attributed
as the originator of Second-order cybernetics."""
)
]
translator = TransformersTranslator(model_name_or_path="Helsinki-NLP/opus-mt-en-fr")
res = translator.translate(documents=DOCS, query=None)
```
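
If `translate` returns `Document` objects for `Document` inputs (as the `BaseTranslator.translate` signature added in this PR suggests), the translated text could be inspected with a sketch like this:

```python
# res should hold the input document(s) with their text translated into French
for doc in res:
    print(doc.text)
```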

**Example (Wrapping another Pipeline)**

You can also wrap one of your existing pipelines and "add" the translation nodes at the beginning and at the end of your pipeline.
For example, let's translate the incoming query from French to English, then do our document retrieval, and then translate the results back from English to French:

```python
from haystack.pipeline import TranslationWrapperPipeline, DocumentSearchPipeline
from haystack.translator import TransformersTranslator

pipeline = DocumentSearchPipeline(retriever=my_dpr_retriever)

in_translator = TransformersTranslator(model_name_or_path="Helsinki-NLP/opus-mt-fr-en")
out_translator = TransformersTranslator(model_name_or_path="Helsinki-NLP/opus-mt-en-fr")

pipeline_with_translation = TranslationWrapperPipeline(input_translator=in_translator,
output_translator=out_translator,
pipeline=pipeline)
```
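
Once wrapped, the pipeline can be queried directly in French. A sketch of a run (keyword arguments are forwarded as-is to the wrapped pipeline, so the accepted parameters depend on the pipeline you wrap):

```python
# French query in, French results out; the underlying corpus stays English
res = pipeline_with_translation.run(query="Qui était Heinz von Foerster ?")
```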


</div>
58 changes: 53 additions & 5 deletions haystack/pipeline.py
@@ -1,3 +1,4 @@
from abc import ABC
import os
from copy import deepcopy
from pathlib import Path
@@ -13,9 +14,10 @@
from haystack.reader.base import BaseReader
from haystack.retriever.base import BaseRetriever
from haystack.summarizer.base import BaseSummarizer
from haystack.translator.base import BaseTranslator


-class Pipeline:
+class Pipeline(ABC):
"""
Pipeline brings together building blocks to build a complex search pipeline with Haystack & user-defined components.

@@ -45,7 +47,7 @@ def add_node(self, component, name: str, inputs: List[str]):
In cases when the predecessor node has multiple outputs, e.g., a "QueryClassifier", the output
must be specified explicitly as "QueryClassifier.output_2".
"""
-self.graph.add_node(name, component=component)
+self.graph.add_node(name, component=component, inputs=inputs)

for i in inputs:
if "." in i:
@@ -93,7 +95,7 @@ def run(self, **kwargs):
while has_next_node:
output_dict, stream_id = self.graph.nodes[current_node_id]["component"].run(**input_dict)
input_dict = output_dict
-next_nodes = self._get_next_nodes(current_node_id, stream_id)
+next_nodes = self.get_next_nodes(current_node_id, stream_id)

if len(next_nodes) > 1:
join_node_id = list(nx.neighbors(self.graph, next_nodes[0]))[0]
@@ -114,7 +116,7 @@

return output_dict

-def _get_next_nodes(self, node_id: str, stream_id: str):
+def get_next_nodes(self, node_id: str, stream_id: str):
current_node_edges = self.graph.edges(node_id, data=True)
next_nodes = [
next_node
@@ -259,7 +261,7 @@ def _overwrite_with_env_variables(cls, definition: dict):
definition["params"][param_name] = value


-class BaseStandardPipeline:
+class BaseStandardPipeline(ABC):
pipeline: Pipeline

def add_node(self, component, name: str, inputs: List[str]):
@@ -452,6 +454,52 @@ def run(self, query: str, filters: Optional[Dict] = None, top_k_retriever: int =
return results


class TranslationWrapperPipeline(BaseStandardPipeline):

"""
Takes an existing search pipeline and adds one "input translation" node after the Query and one
"output translation" node just before returning the results.
"""

def __init__(
self,
input_translator: BaseTranslator,
output_translator: BaseTranslator,
pipeline: BaseStandardPipeline
):
"""
Wrap a given `pipeline` with the `input_translator` and `output_translator`.

:param input_translator: A Translator node that shall translate the input query from language A to B
:param output_translator: A Translator node that shall translate the pipeline results from language B to A
:param pipeline: The pipeline object (e.g. ExtractiveQAPipeline) you want to "wrap".
Note that pipelines with split or merge nodes are currently not supported.
"""

self.pipeline = Pipeline()
self.pipeline.add_node(component=input_translator, name="InputTranslator", inputs=["Query"])

graph = pipeline.pipeline.graph
previous_node_name = ["InputTranslator"]
# Traverse in BFS
for node in graph.nodes:
if node == "Query":
continue

# TODO: Does not work properly for Join nodes and the Answer format
if graph.nodes[node]["inputs"] and len(graph.nodes[node]["inputs"]) > 1:
raise AttributeError("Split and merge nodes are not supported currently")

self.pipeline.add_node(name=node, component=graph.nodes[node]["component"], inputs=previous_node_name)
previous_node_name = [node]

self.pipeline.add_node(component=output_translator, name="OutputTranslator", inputs=previous_node_name)

def run(self, **kwargs):
output = self.pipeline.run(**kwargs)
return output


class QueryNode:
outgoing_edges = 1

2 changes: 0 additions & 2 deletions haystack/retriever/dense.py
@@ -6,8 +6,6 @@
from tqdm import tqdm

from haystack.document_store.base import BaseDocumentStore
-from haystack.document_store.elasticsearch import ElasticsearchDocumentStore
-from haystack.document_store.memory import InMemoryDocumentStore
from haystack import Document
from haystack.retriever.base import BaseRetriever

3 changes: 1 addition & 2 deletions haystack/summarizer/transformers.py
@@ -1,9 +1,8 @@
import logging
-from typing import Any, Dict, List, Optional
+from typing import List, Optional

from transformers import pipeline
from transformers.models.auto.modeling_auto import AutoModelForSeq2SeqLM
from transformers import AutoTokenizer

from haystack import Document
from haystack.summarizer.base import BaseSummarizer
1 change: 1 addition & 0 deletions haystack/translator/__init__.py
@@ -0,0 +1 @@
from haystack.translator.transformers import TransformersTranslator
58 changes: 58 additions & 0 deletions haystack/translator/base.py
@@ -0,0 +1,58 @@
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Mapping, Optional, Union

from haystack import Document


class BaseTranslator(ABC):
"""
Abstract class for a Translator component that translates either a query or a doc from language A to language B.
"""

outgoing_edges = 1

@abstractmethod
def translate(
self,
query: Optional[str] = None,
documents: Optional[Union[List[Document], List[str], List[Dict[str, Any]]]] = None,
dict_key: Optional[str] = None,
**kwargs
) -> Union[str, List[Document], List[str], List[Dict[str, Any]]]:
"""
Translate the passed query or a list of documents from language A to B.
"""
pass

def run(
self,
query: Optional[str] = None,
documents: Optional[Union[List[Document], List[str], List[Dict[str, Any]]]] = None,
answers: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None,
dict_key: Optional[str] = None,
**kwargs
):
"""Method that gets executed when this class is used as a Node in a Haystack Pipeline"""

results: Dict = {
**kwargs
}

# This will cover input query stage
if query:
results["query"] = self.translate(query=query)
# This will cover retriever and summarizer
if documents:
dict_key = dict_key or "text"
results["documents"] = self.translate(documents=documents, dict_key=dict_key)

if answers:
dict_key = dict_key or "answer"
if isinstance(answers, Mapping):
# This will cover reader
results["answers"] = self.translate(documents=answers["answers"], dict_key=dict_key)
else:
# This will cover generator
results["answers"] = self.translate(documents=answers, dict_key=dict_key)

return results, "output_1"
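
For illustration only (this sketch is not part of the PR's diff), a toy subclass shows how `translate` might be implemented and how the `run` method above then routes queries, documents and answers; the upper-casing below stands in for a real translation model:

```python
from typing import Any, Dict, List, Optional, Union

from haystack import Document
from haystack.translator.base import BaseTranslator


class UppercaseTranslator(BaseTranslator):
    """Toy translator that upper-cases text instead of calling a real MT model."""

    def translate(
        self,
        query: Optional[str] = None,
        documents: Optional[Union[List[Document], List[str], List[Dict[str, Any]]]] = None,
        dict_key: Optional[str] = None,
        **kwargs
    ) -> Union[str, List[Document], List[str], List[Dict[str, Any]]]:
        if query is not None:
            return query.upper()
        dict_key = dict_key or "text"
        translated: List[Any] = []
        for doc in documents or []:
            if isinstance(doc, Document):
                doc.text = doc.text.upper()
                translated.append(doc)
            elif isinstance(doc, dict):
                doc[dict_key] = doc[dict_key].upper()
                translated.append(doc)
            else:  # plain string
                translated.append(doc.upper())
        return translated


translator = UppercaseTranslator()
out, edge = translator.run(query="heinz von foerster")           # translates the query
out, edge = translator.run(documents=[Document(text="hello")])   # translates document text
```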