diff --git a/haystack/nodes/reader/table.py b/haystack/nodes/reader/table.py index 78a5477acd..aea3ef0f1b 100644 --- a/haystack/nodes/reader/table.py +++ b/haystack/nodes/reader/table.py @@ -1,5 +1,11 @@ +from abc import abstractmethod from typing import List, Optional, Tuple, Dict, Union +try: + from typing import Literal +except ImportError: + from typing_extensions import Literal # type: ignore + import logging from statistics import mean import torch @@ -98,7 +104,7 @@ def __init__( or commit hash. :param tokenizer: Name of the tokenizer (usually the same as model) :param use_gpu: Whether to use GPU or CPU. Falls back on CPU if no GPU is available. - :param top_k: The maximum number of answers to return + :param top_k: The maximum number of answers to return. :param top_k_per_candidate: How many answers to extract for each candidate table that is coming from the retriever. :param return_no_answer: Whether to include no_answer predictions in the results. @@ -128,27 +134,40 @@ def __init__( super().__init__() self.devices, _ = initialize_device_settings(devices=devices, use_cuda=use_gpu, multi_gpu=False) - config = TapasConfig.from_pretrained(model_name_or_path, use_auth_token=use_auth_token) if len(self.devices) > 1: logger.warning( f"Multiple devices are not supported in {self.__class__.__name__} inference, " f"using the first device {self.devices[0]}." ) - if config.architectures[0] == "TapasForScoredQA": - self.model = self.TapasForScoredQA.from_pretrained( - model_name_or_path, revision=model_version, use_auth_token=use_auth_token + config = TapasConfig.from_pretrained(model_name_or_path, use_auth_token=use_auth_token) + self.table_encoder: Union[_TapasEncoder, _TapasScoredEncoder] + if config.architectures[0] == "TapasForQuestionAnswering": + self.table_encoder = _TapasEncoder( + device=self.devices[0], + model_name_or_path=model_name_or_path, + model_version=model_version, + tokenizer=tokenizer, + max_seq_len=max_seq_len, + use_auth_token=use_auth_token, ) - else: - self.model = TapasForQuestionAnswering.from_pretrained( - model_name_or_path, revision=model_version, use_auth_token=use_auth_token + elif config.architectures[0] == "TapasForScoredQA": + self.table_encoder = _TapasScoredEncoder( + device=self.devices[0], + model_name_or_path=model_name_or_path, + model_version=model_version, + tokenizer=tokenizer, + top_k_per_candidate=top_k_per_candidate, + return_no_answer=return_no_answer, + max_seq_len=max_seq_len, + use_auth_token=use_auth_token, ) - self.model.to(str(self.devices[0])) - - if tokenizer is None: - self.tokenizer = TapasTokenizer.from_pretrained(model_name_or_path, use_auth_token=use_auth_token) else: - self.tokenizer = TapasTokenizer.from_pretrained(tokenizer, use_auth_token=use_auth_token) + logger.error( + "Unrecognized model architecture %s. 
Only the architectures TapasForQuestionAnswering and TapasForScoredQA are supported", + config.architectures[0], + ) + self.table_encoder.model.to(str(self.devices[0])) self.top_k = top_k self.top_k_per_candidate = top_k_per_candidate @@ -172,9 +191,93 @@ def predict(self, query: str, documents: List[Document], top_k: Optional[int] = """ if top_k is None: top_k = self.top_k + return self.table_encoder.predict(query=query, documents=documents, top_k=top_k) - answers = [] - no_answer_score = 1.0 + def predict_batch( + self, + queries: List[str], + documents: Union[List[Document], List[List[Document]]], + top_k: Optional[int] = None, + batch_size: Optional[int] = None, + ): + """ + Use loaded TableQA model to find answers for the supplied queries in the supplied Documents + of content_type ``'table'``. + + Returns dictionary containing query and list of Answer objects sorted by (desc.) score. + WARNING: The answer scores are not reliable, as they are always extremely high, even if + a question cannot be answered by a given table. + + - If you provide a list containing a single query... + - ... and a single list of Documents, the query will be applied to each Document individually. + - ... and a list of lists of Documents, the query will be applied to each list of Documents and the Answers + will be aggregated per Document list. + + - If you provide a list of multiple queries... + - ... and a single list of Documents, each query will be applied to each Document individually. + - ... and a list of lists of Documents, each query will be applied to its corresponding list of Documents + and the Answers will be aggregated per query-Document pair. + + :param queries: Single query string or list of queries. + :param documents: Single list of Documents or list of lists of Documents in which to search for the answers. + Documents should be of content_type ``'table'``. + :param top_k: The maximum number of answers to return per query. + :param batch_size: Not applicable. 
+ """ + results: Dict = {"queries": queries, "answers": []} + + single_doc_list = False + # Docs case 1: single list of Documents -> apply each query to all Documents + if len(documents) > 0 and isinstance(documents[0], Document): + single_doc_list = True + for query in queries: + for doc in documents: + if not isinstance(doc, Document): + raise HaystackError(f"doc was of type {type(doc)}, but expected a Document.") + preds = self.predict(query=query, documents=[doc], top_k=top_k) + results["answers"].append(preds["answers"]) + + # Docs case 2: list of lists of Documents -> apply each query to corresponding list of Documents, if queries + # contains only one query, apply it to each list of Documents + elif len(documents) > 0 and isinstance(documents[0], list): + if len(queries) == 1: + queries = queries * len(documents) + if len(queries) != len(documents): + raise HaystackError("Number of queries must be equal to number of provided Document lists.") + for query, cur_docs in zip(queries, documents): + if not isinstance(cur_docs, list): + raise HaystackError(f"cur_docs was of type {type(cur_docs)}, but expected a list of Documents.") + preds = self.predict(query=query, documents=cur_docs, top_k=top_k) + results["answers"].append(preds["answers"]) + + # Group answers by question in case of multiple queries and single doc list + if single_doc_list and len(queries) > 1: + answers_per_query = int(len(results["answers"]) / len(queries)) + answers = [] + for i in range(0, len(results["answers"]), answers_per_query): + answer_group = results["answers"][i : i + answers_per_query] + answers.append(answer_group) + results["answers"] = answers + + return results + + +class _BaseTapasEncoder: + @staticmethod + def _calculate_answer_offsets(answer_coordinates: List[Tuple[int, int]], table: pd.DataFrame) -> List[Span]: + """ + Calculates the answer cell offsets of the linearized table based on the answer cell coordinates. 
+ """ + answer_offsets = [] + n_rows, n_columns = table.shape + for coord in answer_coordinates: + answer_cell_offset = (coord[0] * n_columns) + coord[1] + answer_offsets.append(Span(start=answer_cell_offset, end=answer_cell_offset + 1)) + return answer_offsets + + @staticmethod + def _check_documents(documents: List[Document]) -> List[Document]: + table_documents = [] for document in documents: if document.content_type != "table": logger.warning("Skipping document with id '%s' in TableReader as it is not of type table.", document.id) @@ -186,59 +289,62 @@ def predict(self, query: str, documents: List[Document], top_k: Optional[int] = "Skipping document with id '%s' in TableReader as it does not contain any rows.", document.id ) continue - # Tokenize query and current table - inputs = self.tokenizer( - table=table, queries=query, max_length=self.max_seq_len, return_tensors="pt", truncation=True - ) - inputs.to(self.devices[0]) - - if isinstance(self.model, TapasForQuestionAnswering): - current_answer = self._predict_tapas_for_qa(inputs, document) - answers.append(current_answer) - elif isinstance(self.model, self.TapasForScoredQA): - current_answers, current_no_answer_score = self._predict_tapas_for_scored_qa(inputs, document) - answers.extend(current_answers) - if current_no_answer_score < no_answer_score: - no_answer_score = current_no_answer_score - - if self.return_no_answer and isinstance(self.model, self.TapasForScoredQA): - answers.append( - Answer( - answer="", - type="extractive", - score=no_answer_score, - context=None, - offsets_in_context=[Span(start=0, end=0)], - offsets_in_document=[Span(start=0, end=0)], - document_id=None, - meta=None, - ) - ) - answers = sorted(answers, reverse=True) - answers = answers[:top_k] - results = {"query": query, "answers": answers} + table_documents.append(document) + return table_documents - return results + @staticmethod + def _preprocess(query: str, table: pd.DataFrame, tokenizer, max_seq_len) -> BatchEncoding: + """Tokenize the query and table.""" + model_inputs = tokenizer( + table=table, queries=query, max_length=max_seq_len, return_tensors="pt", truncation=True + ) + return model_inputs + + @abstractmethod + def predict(self, query: str, documents: List[Document], top_k: int) -> Dict: + pass + + +class _TapasEncoder(_BaseTapasEncoder): + def __init__( + self, + device: torch.device, + model_name_or_path: str = "google/tapas-base-finetuned-wtq", + model_version: Optional[str] = None, + tokenizer: Optional[str] = None, + max_seq_len: int = 256, + use_auth_token: Optional[Union[str, bool]] = None, + ): + self.model = TapasForQuestionAnswering.from_pretrained( + model_name_or_path, revision=model_version, use_auth_token=use_auth_token + ) + if tokenizer is None: + self.tokenizer = TapasTokenizer.from_pretrained(model_name_or_path, use_auth_token=use_auth_token) + else: + self.tokenizer = TapasTokenizer.from_pretrained(tokenizer, use_auth_token=use_auth_token) + self.max_seq_len = max_seq_len + self.device = device - def _predict_tapas_for_qa(self, inputs: BatchEncoding, document: Document) -> Answer: + def _predict_tapas(self, inputs: BatchEncoding, document: Document) -> Answer: table: pd.DataFrame = document.content # Forward query and table through model and convert logits to predictions - outputs = self.model(**inputs) + with torch.no_grad(): + outputs = self.model(**inputs) + inputs.to("cpu") - if self.model.config.num_aggregation_labels > 0: - aggregation_logits = outputs.logits_aggregation.cpu().detach() - else: - aggregation_logits = 
None + outputs_logits = outputs.logits.cpu() - predicted_output = self.tokenizer.convert_logits_to_predictions( - inputs, outputs.logits.cpu().detach(), aggregation_logits - ) - if len(predicted_output) == 1: - predicted_answer_coordinates = predicted_output[0] + if self.model.config.num_aggregation_labels > 0: + aggregation_logits = outputs.logits_aggregation.cpu() + predicted_answer_coordinates, predicted_aggregation_indices = self.tokenizer.convert_logits_to_predictions( + inputs, outputs_logits, logits_agg=aggregation_logits, cell_classification_threshold=0.5 + ) else: - predicted_answer_coordinates, predicted_aggregation_indices = predicted_output + predicted_answer_coordinates = self.tokenizer.convert_logits_to_predictions( + inputs, outputs_logits, logits_agg=None, cell_classification_threshold=0.5 + ) # Get cell values current_answer_coordinates = predicted_answer_coordinates[0] @@ -253,7 +359,7 @@ def _predict_tapas_for_qa(self, inputs: BatchEncoding, document: Document) -> An current_aggregation_operator = "NONE" # Calculate answer score - current_score = self._calculate_answer_score(outputs.logits.cpu().detach(), inputs, current_answer_coordinates) + current_score = self._calculate_answer_score(outputs_logits, inputs, current_answer_coordinates) if current_aggregation_operator == "NONE": answer_str = ", ".join(current_answer_cells) @@ -272,19 +378,127 @@ def _predict_tapas_for_qa(self, inputs: BatchEncoding, document: Document) -> An document_id=document.id, meta={"aggregation_operator": current_aggregation_operator, "answer_cells": current_answer_cells}, ) - return answer - def _predict_tapas_for_scored_qa(self, inputs: BatchEncoding, document: Document) -> Tuple[List[Answer], float]: + def _calculate_answer_score( + self, logits: torch.Tensor, inputs: BatchEncoding, answer_coordinates: List[Tuple[int, int]] + ) -> float: + # Calculate answer score + # Values over 88.72284 will overflow when passed through exponential, so logits are truncated. + logits[logits < -88.7] = -88.7 + token_probabilities = 1 / (1 + np.exp(-logits)) * inputs.attention_mask + token_types = [ + "segment_ids", + "column_ids", + "row_ids", + "prev_labels", + "column_ranks", + "inv_column_ranks", + "numeric_relations", + ] + + segment_ids = inputs.token_type_ids[0, :, token_types.index("segment_ids")].tolist() + column_ids = inputs.token_type_ids[0, :, token_types.index("column_ids")].tolist() + row_ids = inputs.token_type_ids[0, :, token_types.index("row_ids")].tolist() + all_cell_probabilities = self.tokenizer._get_mean_cell_probs( + token_probabilities[0].tolist(), segment_ids, row_ids, column_ids + ) + # _get_mean_cell_probs seems to index cells by (col, row). DataFrames are, however, indexed by (row, col). 
+ all_cell_probabilities = {(row, col): prob for (col, row), prob in all_cell_probabilities.items()} + answer_cell_probabilities = [all_cell_probabilities[coord] for coord in answer_coordinates] + + return np.mean(answer_cell_probabilities) + + @staticmethod + def _aggregate_answers(agg_operator: Literal["COUNT", "SUM", "AVERAGE"], answer_cells: List[str]) -> str: + if agg_operator == "COUNT": + return str(len(answer_cells)) + + # No aggregation needed as only one cell selected as answer_cells + if len(answer_cells) == 1: + return answer_cells[0] + # Return empty string if model did not select any cell as answer + if len(answer_cells) == 0: + return "" + + # Parse answer cells in order to aggregate numerical values + parsed_answer_cells = [parser.parse(cell) for cell in answer_cells] + # Check if all cells contain at least one numerical value and that all values share the same unit + try: + if all(parsed_answer_cells) and all( + cell[0].unit.name == parsed_answer_cells[0][0].unit.name for cell in parsed_answer_cells + ): + numerical_values = [cell[0].value for cell in parsed_answer_cells] + unit = parsed_answer_cells[0][0].unit.symbols[0] if parsed_answer_cells[0][0].unit.symbols else "" + + if agg_operator == "SUM": + answer_value = sum(numerical_values) + elif agg_operator == "AVERAGE": + answer_value = mean(numerical_values) + else: + raise ValueError("unknown aggregator") + + return f"{answer_value}{' ' + unit if unit else ''}" + + except ValueError as e: + if "unknown aggregator" in str(e): + pass + + # Not all selected answer cells contain a numerical value or answer cells don't share the same unit + return f"{agg_operator} > {', '.join(answer_cells)}" + + def predict(self, query: str, documents: List[Document], top_k: int) -> Dict: + answers = [] + table_documents = self._check_documents(documents) + for document in table_documents: + table: pd.DataFrame = document.content + model_inputs = self._preprocess(query, table, self.tokenizer, self.max_seq_len) + model_inputs.to(self.device) + + current_answer = self._predict_tapas(model_inputs, document) + answers.append(current_answer) + + answers = sorted(answers, reverse=True) + results = {"query": query, "answers": answers[:top_k]} + return results + + +class _TapasScoredEncoder(_BaseTapasEncoder): + def __init__( + self, + device: torch.device, + model_name_or_path: str = "deepset/tapas-large-nq-hn-reader", + model_version: Optional[str] = None, + tokenizer: Optional[str] = None, + top_k_per_candidate: int = 3, + return_no_answer: bool = False, + max_seq_len: int = 256, + use_auth_token: Optional[Union[str, bool]] = None, + ): + self.model = self._TapasForScoredQA.from_pretrained( + model_name_or_path, revision=model_version, use_auth_token=use_auth_token + ) + if tokenizer is None: + self.tokenizer = TapasTokenizer.from_pretrained(model_name_or_path, use_auth_token=use_auth_token) + else: + self.tokenizer = TapasTokenizer.from_pretrained(tokenizer, use_auth_token=use_auth_token) + self.max_seq_len = max_seq_len + self.device = device + self.top_k_per_candidate = top_k_per_candidate + self.return_no_answer = return_no_answer + + def _predict_tapas_scored(self, inputs: BatchEncoding, document: Document) -> Tuple[List[Answer], float]: table: pd.DataFrame = document.content # Forward pass through model - outputs = self.model.tapas(**inputs) + with torch.no_grad(): + outputs = self.model.tapas(**inputs) # Get general table score table_score = self.model.classifier(outputs.pooler_output) table_score_softmax = 
torch.nn.functional.softmax(table_score, dim=1) table_relevancy_prob = table_score_softmax[0][1].item() + no_answer_score = table_score_softmax[0][0].item() # Get possible answer spans token_types = [ @@ -302,21 +516,31 @@ def _predict_tapas_for_scored_qa(self, inputs: BatchEncoding, document: Document possible_answer_spans: List[ Tuple[int, int, int, int] ] = [] # List of tuples: (row_idx, col_idx, start_token, end_token) - current_start_idx = -1 + current_start_token_idx = -1 current_column_id = -1 - for idx, (row_id, column_id) in enumerate(zip(row_ids, column_ids)): + for token_idx, (row_id, column_id) in enumerate(zip(row_ids, column_ids)): if row_id == 0 or column_id == 0: continue # Beginning of new cell if column_id != current_column_id: - if current_start_idx != -1: + if current_start_token_idx != -1: possible_answer_spans.append( - (row_ids[current_start_idx] - 1, column_ids[current_start_idx] - 1, current_start_idx, idx - 1) + ( + row_ids[current_start_token_idx] - 1, + column_ids[current_start_token_idx] - 1, + current_start_token_idx, + token_idx - 1, + ) ) - current_start_idx = idx + current_start_token_idx = token_idx current_column_id = column_id possible_answer_spans.append( - (row_ids[current_start_idx] - 1, column_ids[current_start_idx] - 1, current_start_idx, len(row_ids) - 1) + ( + row_ids[current_start_token_idx] - 1, + column_ids[current_start_token_idx] - 1, + current_start_token_idx, + len(row_ids) - 1, + ) ) # Concat logits of start token and end token of possible answer spans @@ -358,160 +582,41 @@ def _predict_tapas_for_scored_qa(self, inputs: BatchEncoding, document: Document ) ) - no_answer_score = 1 - table_relevancy_prob - return answers, no_answer_score - def _calculate_answer_score( - self, logits: torch.Tensor, inputs: BatchEncoding, answer_coordinates: List[Tuple[int, int]] - ) -> float: - """ - Calculates the answer score by computing each cell's probability of being part of the answer - and taking the mean probability of the answer cells. - """ - # Calculate answer score - # Values over 88.72284 will overflow when passed through exponential, so logits are truncated. - logits[logits < -88.7] = -88.7 - token_probabilities = 1 / (1 + np.exp(-logits)) * inputs.attention_mask - - segment_ids = inputs.token_type_ids[0, :, 0].tolist() - column_ids = inputs.token_type_ids[0, :, 1].tolist() - row_ids = inputs.token_type_ids[0, :, 2].tolist() - all_cell_probabilities = self.tokenizer._get_mean_cell_probs( - token_probabilities[0].tolist(), segment_ids, row_ids, column_ids - ) - # _get_mean_cell_probs seems to index cells by (col, row). DataFrames are, however, indexed by (row, col). 
- all_cell_probabilities = {(row, col): prob for (col, row), prob in all_cell_probabilities.items()} - answer_cell_probabilities = [all_cell_probabilities[coord] for coord in answer_coordinates] - - return np.mean(answer_cell_probabilities) - - @staticmethod - def _aggregate_answers(agg_operator: str, answer_cells: List[str]) -> str: - if agg_operator == "COUNT": - return str(len(answer_cells)) - - # No aggregation needed as only one cell selected as answer_cells - if len(answer_cells) == 1: - return answer_cells[0] - # Return empty string if model did not select any cell as answer - if len(answer_cells) == 0: - return "" - - # Parse answer cells in order to aggregate numerical values - parsed_answer_cells = [parser.parse(cell) for cell in answer_cells] - # Check if all cells contain at least one numerical value and that all values share the same unit - try: - if all(parsed_answer_cells) and all( - cell[0].unit.name == parsed_answer_cells[0][0].unit.name for cell in parsed_answer_cells - ): - numerical_values = [cell[0].value for cell in parsed_answer_cells] - unit = parsed_answer_cells[0][0].unit.symbols[0] if parsed_answer_cells[0][0].unit.symbols else "" - - if agg_operator == "SUM": - answer_value = sum(numerical_values) - elif agg_operator == "AVERAGE": - answer_value = mean(numerical_values) - else: - raise KeyError("unknown aggregator") - - return f"{answer_value}{' ' + unit if unit else ''}" - - except KeyError as e: - if "unknown aggregator" in str(e): - pass - - # Not all selected answer cells contain a numerical value or answer cells don't share the same unit - return f"{agg_operator} > {', '.join(answer_cells)}" - - @staticmethod - def _calculate_answer_offsets(answer_coordinates: List[Tuple[int, int]], table: pd.DataFrame) -> List[Span]: - """ - Calculates the answer cell offsets of the linearized table based on the - answer cell coordinates. - """ - answer_offsets = [] - n_rows, n_columns = table.shape - for coord in answer_coordinates: - answer_cell_offset = (coord[0] * n_columns) + coord[1] - answer_offsets.append(Span(start=answer_cell_offset, end=answer_cell_offset + 1)) - - return answer_offsets - - def predict_batch( - self, - queries: List[str], - documents: Union[List[Document], List[List[Document]]], - top_k: Optional[int] = None, - batch_size: Optional[int] = None, - ): - """ - Use loaded TableQA model to find answers for the supplied queries in the supplied Documents - of content_type ``'table'``. - - Returns dictionary containing query and list of Answer objects sorted by (desc.) score. - - WARNING: The answer scores are not reliable, as they are always extremely high, even if - a question cannot be answered by a given table. - - - If you provide a list containing a single query... - - - ... and a single list of Documents, the query will be applied to each Document individually. - - ... and a list of lists of Documents, the query will be applied to each list of Documents and the Answers - will be aggregated per Document list. - - - If you provide a list of multiple queries... - - - ... and a single list of Documents, each query will be applied to each Document individually. - - ... and a list of lists of Documents, each query will be applied to its corresponding list of Documents - and the Answers will be aggregated per query-Document pair. - - :param queries: Single query string or list of queries. - :param documents: Single list of Documents or list of lists of Documents in which to search for the answers. - Documents should be of content_type ``'table'``. 
- :param top_k: The maximum number of answers to return per query. - :param batch_size: Not applicable. - """ - # TODO: This method currently just calls the predict method multiple times, so there is room for improvement. - - results: Dict = {"queries": queries, "answers": []} - - single_doc_list = False - # Docs case 1: single list of Documents -> apply each query to all Documents - if len(documents) > 0 and isinstance(documents[0], Document): - single_doc_list = True - for query in queries: - for doc in documents: - if not isinstance(doc, Document): - raise HaystackError(f"doc was of type {type(doc)}, but expected a Document.") - preds = self.predict(query=query, documents=[doc], top_k=top_k) - results["answers"].append(preds["answers"]) + def predict(self, query: str, documents: List[Document], top_k: int) -> Dict: + answers = [] + no_answer_score = 1.0 + table_documents = self._check_documents(documents) + for document in table_documents: + table: pd.DataFrame = document.content + model_inputs = self._preprocess(query, table, self.tokenizer, self.max_seq_len) + model_inputs.to(self.device) - # Docs case 2: list of lists of Documents -> apply each query to corresponding list of Documents, if queries - # contains only one query, apply it to each list of Documents - elif len(documents) > 0 and isinstance(documents[0], list): - if len(queries) == 1: - queries = queries * len(documents) - if len(queries) != len(documents): - raise HaystackError("Number of queries must be equal to number of provided Document lists.") - for query, cur_docs in zip(queries, documents): - if not isinstance(cur_docs, list): - raise HaystackError(f"cur_docs was of type {type(cur_docs)}, but expected a list of Documents.") - preds = self.predict(query=query, documents=cur_docs, top_k=top_k) - results["answers"].append(preds["answers"]) + current_answers, current_no_answer_score = self._predict_tapas_scored(model_inputs, document) + answers.extend(current_answers) + if current_no_answer_score < no_answer_score: + no_answer_score = current_no_answer_score - # Group answers by question in case of multiple queries and single doc list - if single_doc_list and len(queries) > 1: - answers_per_query = int(len(results["answers"]) / len(queries)) - answers = [] - for i in range(0, len(results["answers"]), answers_per_query): - answer_group = results["answers"][i : i + answers_per_query] - answers.append(answer_group) - results["answers"] = answers + if self.return_no_answer: + answers.append( + Answer( + answer="", + type="extractive", + score=no_answer_score, + context=None, + offsets_in_context=[Span(start=0, end=0)], + offsets_in_document=[Span(start=0, end=0)], + document_id=None, + meta=None, + ) + ) + answers = sorted(answers, reverse=True) + results = {"query": query, "answers": answers[:top_k]} return results - class TapasForScoredQA(TapasPreTrainedModel): + class _TapasForScoredQA(TapasPreTrainedModel): def __init__(self, config): super().__init__(config) diff --git a/test/conftest.py b/test/conftest.py index 7f0a62eacc..54dbce2316 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -684,10 +684,14 @@ def reader(request): ) -@pytest.fixture(params=["tapas", "rci"]) +@pytest.fixture(params=["tapas_small", "tapas_base", "tapas_scored", "rci"]) def table_reader(request): - if request.param == "tapas": + if request.param == "tapas_small": + return TableReader(model_name_or_path="google/tapas-small-finetuned-wtq") + elif request.param == "tapas_base": return 
TableReader(model_name_or_path="google/tapas-base-finetuned-wtq") + elif request.param == "tapas_scored": + return TableReader(model_name_or_path="deepset/tapas-large-nq-hn-reader") elif request.param == "rci": return RCIReader( row_model_name_or_path="michaelrglass/albert-base-rci-wikisql-row", diff --git a/test/nodes/test_table_reader.py b/test/nodes/test_table_reader.py index 4e8358bd25..44c4ca206c 100644 --- a/test/nodes/test_table_reader.py +++ b/test/nodes/test_table_reader.py @@ -7,6 +7,7 @@ from haystack.pipelines.base import Pipeline +@pytest.mark.parametrize("table_reader", ["tapas_small", "rci", "tapas_scored"], indirect=True) def test_table_reader(table_reader): data = { "actors": ["brad pitt", "leonardo di caprio", "george clooney"], @@ -23,6 +24,7 @@ def test_table_reader(table_reader): assert prediction["answers"][0].offsets_in_context[0].end == 8 +@pytest.mark.parametrize("table_reader", ["tapas_small", "rci"], indirect=True) def test_table_reader_batch_single_query_single_doc_list(table_reader): data = { "actors": ["brad pitt", "leonardo di caprio", "george clooney"], @@ -41,6 +43,7 @@ def test_table_reader_batch_single_query_single_doc_list(table_reader): assert len(prediction["answers"]) == 1 # Predictions for 5 docs +@pytest.mark.parametrize("table_reader", ["tapas_small", "rci"], indirect=True) def test_table_reader_batch_single_query_multiple_doc_lists(table_reader): data = { "actors": ["brad pitt", "leonardo di caprio", "george clooney"], @@ -61,6 +64,7 @@ def test_table_reader_batch_single_query_multiple_doc_lists(table_reader): assert len(prediction["answers"]) == 1 # Predictions for 1 collection of docs +@pytest.mark.parametrize("table_reader", ["tapas_small", "rci"], indirect=True) def test_table_reader_batch_multiple_queries_single_doc_list(table_reader): data = { "actors": ["brad pitt", "leonardo di caprio", "george clooney"], @@ -82,6 +86,7 @@ def test_table_reader_batch_multiple_queries_single_doc_list(table_reader): assert len(prediction["answers"]) == 2 # Predictions for 2 queries +@pytest.mark.parametrize("table_reader", ["tapas_small", "rci"], indirect=True) def test_table_reader_batch_multiple_queries_multiple_doc_lists(table_reader): data = { "actors": ["brad pitt", "leonardo di caprio", "george clooney"], @@ -103,6 +108,7 @@ def test_table_reader_batch_multiple_queries_multiple_doc_lists(table_reader): assert len(prediction["answers"]) == 2 # Predictions for 2 collections of documents +@pytest.mark.parametrize("table_reader", ["tapas_small", "rci"], indirect=True) def test_table_reader_in_pipeline(table_reader): pipeline = Pipeline() pipeline.add_node(table_reader, "TableReader", ["Query"]) @@ -123,7 +129,7 @@ def test_table_reader_in_pipeline(table_reader): assert prediction["answers"][0].offsets_in_context[0].end == 8 -@pytest.mark.parametrize("table_reader", ["tapas"], indirect=True) +@pytest.mark.parametrize("table_reader", ["tapas_base"], indirect=True) def test_table_reader_aggregation(table_reader): data = { "Mountain": ["Mount Everest", "K2", "Kangchenjunga", "Lhotse", "Makalu"], @@ -144,6 +150,7 @@ def test_table_reader_aggregation(table_reader): assert prediction["answers"][0].meta["answer_cells"] == ["8848m", "8,611 m", "8 586m", "8 516 m", "8,485m"] +@pytest.mark.parametrize("table_reader", ["tapas_small", "rci"], indirect=True) def test_table_without_rows(caplog, table_reader): # empty DataFrame table = pd.DataFrame() @@ -154,6 +161,7 @@ def test_table_without_rows(caplog, table_reader): assert len(predictions["answers"]) == 0 
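The fixtures and tests above exercise both TAPAS variants end to end. As a quick editorial illustration (not part of the diff), the snippet below sketches how the refactored TableReader would be called directly, assuming the same public checkpoints used in the fixtures; names such as the query string and the small example table are made up for the sketch.

import pandas as pd

from haystack.schema import Document
from haystack.nodes import TableReader

# Small illustrative table wrapped in a Document of content_type "table".
table = pd.DataFrame(
    {
        "actors": ["brad pitt", "leonardo di caprio", "george clooney"],
        "age": ["58", "47", "60"],
    }
)
document = Document(content=table, content_type="table", id="actors_table")

# A TapasForQuestionAnswering checkpoint is routed to the _TapasEncoder path.
reader = TableReader(model_name_or_path="google/tapas-small-finetuned-wtq")
prediction = reader.predict(query="Who is 47 years old?", documents=[document], top_k=1)
print(prediction["answers"][0].answer, prediction["answers"][0].score)

# A TapasForScoredQA checkpoint is routed to the _TapasScoredEncoder path instead.
scored_reader = TableReader(model_name_or_path="deepset/tapas-large-nq-hn-reader")
print(scored_reader.predict(query="Who is 47 years old?", documents=[document], top_k=1)["answers"])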
+@pytest.mark.parametrize("table_reader", ["tapas_small", "rci"], indirect=True)
 def test_text_document(caplog, table_reader):
     document = Document(content="text", id="text_doc")
     with caplog.at_level(logging.WARNING):