Skip to content

Commit

Permalink
Refactor REST APIs to use Pipelines (#922)
Browse files Browse the repository at this point in the history
  • Loading branch information
oryx1729 authored Apr 7, 2021
1 parent 64ad953 commit 8c68699
Show file tree
Hide file tree
Showing 22 changed files with 429 additions and 965 deletions.
20 changes: 5 additions & 15 deletions docker-compose.yml
Original file line number Diff line number Diff line change
Expand Up @@ -7,21 +7,9 @@ services:
image: "deepset/haystack-cpu:latest"
ports:
- 8000:8000
volumes:
# Optional: mount your own models from disk into the container
- "./models:/home/user/models"
environment:
# See rest_api/config.py for more variables that you can configure here.
- DB_HOST=elasticsearch
- USE_GPU=False
- TOP_K_PER_SAMPLE=3 # how many answers can come from the same small passage (reduce value for more variety of answers)
# Load a model from transformers' model hub or a local path into the FARMReader.
- READER_MODEL_PATH=deepset/roberta-base-squad2
# - READER_MODEL_PATH=home/user/models/roberta-base-squad2
# Alternative: If you want to use the TransformersReader (e.g. for loading a local model in transformers format):
# - READER_TYPE=TransformersReader
# - READER_MODEL_PATH=/home/user/models/roberta-base-squad2
# - READER_TOKENIZER=/home/user/models/roberta-base-squad2
# See rest_api/pipelines.yaml for configurations of Search & Indexing Pipeline.
- ELASTICSEARCHDOCUMENTSTORE_PARAMS_HOST=elasticsearch
restart: always
depends_on:
- elasticsearch
Expand All @@ -36,7 +24,9 @@ services:
environment:
- discovery.type=single-node
ui:
image: "deepset/haystack-streamlit-ui"
build:
context: ui
dockerfile: Dockerfile
ports:
- 8501:8501
environment:
Expand Down
9 changes: 9 additions & 0 deletions docs/_src/api/api/file_converter.md
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,15 @@ supplied meta data like author, url, external IDs can be supplied as a dictionar

Validate that the language of the text is one of the valid languages.

<a name="base.FileTypeClassifier"></a>
## FileTypeClassifier Objects

```python
class FileTypeClassifier(BaseComponent)
```

Route files in an Indexing Pipeline to corresponding file converters.

<a name="txt"></a>
# Module txt

Expand Down
2 changes: 1 addition & 1 deletion docs/_src/api/api/pipelines.md
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ Add a new node to the pipeline.
#### get\_node

```python
| get_node(name: str)
| get_node(name: str) -> Optional[BaseComponent]
```

Get a node from the Pipeline.
Expand Down
3 changes: 2 additions & 1 deletion haystack/file_converter/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from haystack.file_converter.base import FileTypeClassifier
from haystack.file_converter.docx import DocxToTextConverter
from haystack.file_converter.markdown import MarkdownConverter
from haystack.file_converter.pdf import PDFToTextConverter
from haystack.file_converter.tika import TikaConverter
from haystack.file_converter.txt import TextConverter
from haystack.file_converter.markdown import MarkdownConverter
16 changes: 16 additions & 0 deletions haystack/file_converter/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,3 +89,19 @@ def run(self, file_path: Path, meta: Optional[Dict[str, str]] = None, remove_num

result = {"document": document, **kwargs}
return result, "output_1"


class FileTypeClassifier(BaseComponent):
    """
    Route files in an Indexing Pipeline to corresponding file converters.

    The node inspects the file name's extension and forwards its input on one
    of five outgoing edges (output_1 ... output_5), one per supported format,
    so each downstream converter only receives files it can handle.
    """

    # One outgoing edge per entry in SUPPORTED_TYPES (positional mapping).
    outgoing_edges = 5

    # Extension -> edge mapping is positional: "txt" -> output_1, "pdf" -> output_2, etc.
    SUPPORTED_TYPES = ["txt", "pdf", "md", "docx", "html"]

    def run(self, file_path: Path, **kwargs):  # type: ignore
        """
        Classify a single file by its extension and pick the matching edge.

        :param file_path: Path of the file to route.
        :return: Tuple of (input dict passed through unchanged, "output_<n>")
                 where <n> is the 1-based index of the extension in SUPPORTED_TYPES.
        :raises ValueError: If the file extension is not one of SUPPORTED_TYPES.
        """
        output = {"file_path": file_path, **kwargs}
        # Text after the last "." — so "archive.tar.gz" classifies as "gz";
        # a name without any dot yields the whole name and is rejected below.
        extension = file_path.name.split(".")[-1].lower()
        try:
            index = self.SUPPORTED_TYPES.index(extension) + 1
        except ValueError:
            # ValueError is more specific than the previous bare Exception while
            # remaining a subclass of it, so existing `except Exception` callers
            # keep working. `from None` drops the confusing .index() context.
            raise ValueError(f"Files with an extension '{extension}' are not supported.") from None
        return output, f"output_{index}"
9 changes: 5 additions & 4 deletions haystack/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,13 +85,14 @@ def add_node(self, component, name: str, inputs: List[str]):
input_edge_name = "output_1"
self.graph.add_edge(input_node_name, name, label=input_edge_name)

def get_node(self, name: str):
def get_node(self, name: str) -> Optional[BaseComponent]:
"""
Get a node from the Pipeline.
:param name: The name of the node.
"""
component = self.graph.nodes[name]["component"]
graph_node = self.graph.nodes.get(name)
component = graph_node["component"] if graph_node else None
return component

def set_node(self, name: str, component):
Expand Down Expand Up @@ -219,7 +220,7 @@ def load_from_yaml(cls, path: Path, pipeline_name: Optional[str] = None, overwri
else:
pipelines_in_yaml = list(filter(lambda p: p["name"] == pipeline_name, data["pipelines"]))
if not pipelines_in_yaml:
raise Exception(f"Cannot find any pipeline with name '{pipeline_name}' declared in the YAML file.")
raise KeyError(f"Cannot find any pipeline with name '{pipeline_name}' declared in the YAML file.")
pipeline_config = pipelines_in_yaml[0]

definitions = {} # definitions of each component from the YAML.
Expand Down Expand Up @@ -252,7 +253,7 @@ def _load_or_get_component(cls, name: str, definitions: dict, components: dict):
if name in components.keys(): # check if component is already loaded.
return components[name]

component_params = definitions[name]["params"]
component_params = definitions[name].get("params", {})
component_type = definitions[name]["type"]
logger.debug(f"Loading component `{name}` of type `{definitions[name]['type']}`")

Expand Down
25 changes: 22 additions & 3 deletions haystack/preprocessor/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,27 @@ def split(
) -> List[Dict[str, Any]]:
raise NotImplementedError

def run(self, document: dict, **kwargs): # type: ignore
documents = self.process(document)

def run(  # type: ignore
    self,
    document: dict,
    clean_whitespace: Optional[bool] = None,
    clean_header_footer: Optional[bool] = None,
    clean_empty_lines: Optional[bool] = None,
    split_by: Optional[str] = None,
    split_length: Optional[int] = None,
    split_overlap: Optional[int] = None,
    split_respect_sentence_boundary: Optional[bool] = None,
    **kwargs,
):
    """
    Pipeline node entry point: clean and split a single document.

    All cleaning/splitting options are forwarded verbatim to ``self.process``;
    ``None`` values presumably fall back to the preprocessor's configured
    defaults — behavior is defined by ``process`` (confirm against subclass).

    :param document: The document dict to preprocess.
    :return: Tuple of ({"documents": <processed docs>, **kwargs}, "output_1").
    """
    processing_options = dict(
        document=document,
        clean_whitespace=clean_whitespace,
        clean_header_footer=clean_header_footer,
        clean_empty_lines=clean_empty_lines,
        split_by=split_by,
        split_length=split_length,
        split_overlap=split_overlap,
        split_respect_sentence_boundary=split_respect_sentence_boundary,
    )
    output = {"documents": self.process(**processing_options)}
    # Incoming kwargs are passed through; on a key collision they win,
    # matching the original {"documents": ..., **kwargs} merge order.
    output.update(kwargs)
    return output, "output_1"
12 changes: 4 additions & 8 deletions rest_api/application.py
Original file line number Diff line number Diff line change
@@ -1,31 +1,27 @@
import logging

import uvicorn
from elasticapm.contrib.starlette import make_apm_client, ElasticAPM
from fastapi import FastAPI, HTTPException
from starlette.middleware.cors import CORSMiddleware

from rest_api.config import APM_SERVER, APM_SERVICE_NAME
from rest_api.controller.errors.http_error import http_error_handler
from rest_api.controller.router import router as api_router

logging.basicConfig(format="%(asctime)s %(message)s", datefmt="%m/%d/%Y %I:%M:%S %p")
logger = logging.getLogger(__name__)
logging.getLogger("elasticsearch").setLevel(logging.WARNING)
logging.getLogger("haystack").setLevel(logging.INFO)


def get_application() -> FastAPI:
application = FastAPI(title="Haystack-API", debug=True, version="0.1")

    # This middleware enables all cross-domain requests to the API from a browser. For production
    # deployments, it could be made more restrictive.
application.add_middleware(
CORSMiddleware, allow_origins=["*"], allow_credentials=True, allow_methods=["*"], allow_headers=["*"],
)

if APM_SERVER:
apm_config = {"SERVICE_NAME": APM_SERVICE_NAME, "SERVER_URL": APM_SERVER, "CAPTURE_BODY": "all"}
elasticapm = make_apm_client(apm_config)
application.add_middleware(ElasticAPM, client=elasticapm)

application.add_exception_handler(HTTPException, http_error_handler)

application.include_router(api_router)
Expand All @@ -38,7 +34,7 @@ def get_application() -> FastAPI:
logger.info("Open http://127.0.0.1:8000/docs to see Swagger API Documentation.")
logger.info(
"""
Or just try it out directly: curl --request POST --url 'http://127.0.0.1:8000/models/1/doc-qa' --data '{"questions": ["What is the capital of Germany?"]}'
Or just try it out directly: curl --request POST --url 'http://127.0.0.1:8000/query' --data '{"query": "Did Albus Dumbledore die?"}'
"""
)

Expand Down
73 changes: 4 additions & 69 deletions rest_api/config.py
Original file line number Diff line number Diff line change
@@ -1,74 +1,9 @@
import ast
import os

# FastAPI
PROJECT_NAME = os.getenv("PROJECT_NAME", "FastAPI")
PIPELINE_YAML_PATH = os.getenv("PIPELINE_YAML_PATH", "rest_api/pipelines.yaml")
QUERY_PIPELINE_NAME = os.getenv("QUERY_PIPELINE_NAME", "query")
INDEXING_PIPELINE_NAME = os.getenv("INDEXING_PIPELINE_NAME", "indexing")

# Resources / Computation
USE_GPU = os.getenv("USE_GPU", "True").lower() == "true"
GPU_NUMBER = int(os.getenv("GPU_NUMBER", 1))
MAX_PROCESSES = int(os.getenv("MAX_PROCESSES", 0))
BATCHSIZE = int(os.getenv("BATCHSIZE", 50))
CONCURRENT_REQUEST_PER_WORKER = int(os.getenv("CONCURRENT_REQUEST_PER_WORKER", 4))
FILE_UPLOAD_PATH = os.getenv("FILE_UPLOAD_PATH", "./file-upload")

# DB
DB_HOST = os.getenv("DB_HOST", "localhost")
DB_PORT = int(os.getenv("DB_PORT", 9200))
DB_USER = os.getenv("DB_USER", "")
DB_PW = os.getenv("DB_PW", "")
DB_INDEX = os.getenv("DB_INDEX", "document")
DB_INDEX_FEEDBACK = os.getenv("DB_INDEX_FEEDBACK", "label")
ES_CONN_SCHEME = os.getenv("ES_CONN_SCHEME", "http")
TEXT_FIELD_NAME = os.getenv("TEXT_FIELD_NAME", "text")
NAME_FIELD_NAME = os.getenv("NAME_FIELD_NAME", "name")
SEARCH_FIELD_NAME = os.getenv("SEARCH_FIELD_NAME", "text")
FAQ_QUESTION_FIELD_NAME = os.getenv("FAQ_QUESTION_FIELD_NAME", "question")
EMBEDDING_FIELD_NAME = os.getenv("EMBEDDING_FIELD_NAME", "embedding")
EMBEDDING_DIM = int(os.getenv("EMBEDDING_DIM", 768))
VECTOR_SIMILARITY_METRIC = os.getenv("VECTOR_SIMILARITY_METRIC", "dot_product")
CREATE_INDEX = os.getenv("CREATE_INDEX", "True").lower() == "true"
UPDATE_EXISTING_DOCUMENTS = os.getenv("UPDATE_EXISTING_DOCUMENTS", "False").lower() == "true"

# Reader
READER_MODEL_PATH = os.getenv("READER_MODEL_PATH", "deepset/roberta-base-squad2")
READER_TYPE = os.getenv("READER_TYPE", "FARMReader") # alternative: 'TransformersReader'
READER_TOKENIZER = os.getenv("READER_TOKENIZER", None)
CONTEXT_WINDOW_SIZE = int(os.getenv("CONTEXT_WINDOW_SIZE", 500))
DEFAULT_TOP_K_READER = int(os.getenv("DEFAULT_TOP_K_READER", 5)) # How many answers to return in total
TOP_K_PER_CANDIDATE = int(os.getenv("TOP_K_PER_CANDIDATE", 3)) # How many answers can come from one indexed doc
TOP_K_PER_SAMPLE = int(os.getenv("TOP_K_PER_SAMPLE", 1)) # How many answers can come from one passage that the reader processes at once (i.e. text of max_seq_len from the doc)
NO_ANS_BOOST = int(os.getenv("NO_ANS_BOOST", -10))
READER_CAN_HAVE_NO_ANSWER = os.getenv("READER_CAN_HAVE_NO_ANSWER", "True").lower() == "true"
DOC_STRIDE = int(os.getenv("DOC_STRIDE", 128))
MAX_SEQ_LEN = int(os.getenv("MAX_SEQ_LEN", 256))

# Retriever
RETRIEVER_TYPE = os.getenv("RETRIEVER_TYPE", "ElasticsearchRetriever") # alternatives: 'EmbeddingRetriever', 'ElasticsearchRetriever', 'ElasticsearchFilterOnlyRetriever', None
DEFAULT_TOP_K_RETRIEVER = int(os.getenv("DEFAULT_TOP_K_RETRIEVER", 5))
EXCLUDE_META_DATA_FIELDS = os.getenv("EXCLUDE_META_DATA_FIELDS", f"['question_emb','embedding']")
if EXCLUDE_META_DATA_FIELDS:
EXCLUDE_META_DATA_FIELDS = ast.literal_eval(EXCLUDE_META_DATA_FIELDS)
EMBEDDING_MODEL_PATH = os.getenv("EMBEDDING_MODEL_PATH", "deepset/sentence_bert")
EMBEDDING_MODEL_FORMAT = os.getenv("EMBEDDING_MODEL_FORMAT", "farm")

# File uploads
FILE_UPLOAD_PATH = os.getenv("FILE_UPLOAD_PATH", "file-uploads")
REMOVE_NUMERIC_TABLES = os.getenv("REMOVE_NUMERIC_TABLES", "True").lower() == "true"
VALID_LANGUAGES = os.getenv("VALID_LANGUAGES", None)
if VALID_LANGUAGES:
VALID_LANGUAGES = ast.literal_eval(VALID_LANGUAGES)

# Preprocessing
REMOVE_WHITESPACE = os.getenv("REMOVE_WHITESPACE", "True").lower() == "true"
REMOVE_EMPTY_LINES = os.getenv("REMOVE_EMPTY_LINES", "True").lower() == "true"
REMOVE_HEADER_FOOTER = os.getenv("REMOVE_HEADER_FOOTER", "True").lower() == "true"
SPLIT_BY = os.getenv("SPLIT_BY", "word")
SPLIT_LENGTH = os.getenv("SPLIT_LENGTH", 1_000)
SPLIT_OVERLAP = os.getenv("SPLIT_OVERLAP", None)
SPLIT_RESPECT_SENTENCE_BOUNDARY = os.getenv("SPLIT_RESPECT_SENTENCE_BOUNDARY", True)


# Monitoring
LOG_LEVEL = os.getenv("LOG_LEVEL", "INFO")
APM_SERVER = os.getenv("APM_SERVER", None)
APM_SERVICE_NAME = os.getenv("APM_SERVICE_NAME", "haystack-backend")
Loading

0 comments on commit 8c68699

Please sign in to comment.