Skip to content

Commit

Permalink
bug: change type of split_by to Literal including None (#3389)
Browse files Browse the repository at this point in the history
* change type of split_by

* fix mypy and update schema files

* change split_by type to Literal

* handle ImportError for Literal on Python < 3.8
  • Loading branch information
julian-risch authored Oct 19, 2022
1 parent f4a49f7 commit 16723bf
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 8 deletions.
5 changes: 5 additions & 0 deletions haystack/json-schemas/haystack-pipeline-1.10.0rc0.schema.json
Original file line number Diff line number Diff line change
Expand Up @@ -3993,6 +3993,11 @@
},
"split_by": {
"default": "word",
"enum": [
"word",
"sentence",
"passage"
],
"title": "Split By",
"type": "string"
},
Expand Down
5 changes: 5 additions & 0 deletions haystack/json-schemas/haystack-pipeline-main.schema.json
Original file line number Diff line number Diff line change
Expand Up @@ -3993,6 +3993,11 @@
},
"split_by": {
"default": "word",
"enum": [
"word",
"sentence",
"passage"
],
"title": "Split By",
"type": "string"
},
Expand Down
13 changes: 9 additions & 4 deletions haystack/nodes/preprocessor/base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
from typing import List, Optional, Union

try:
from typing import Literal
except ImportError:
from typing_extensions import Literal # type: ignore

from abc import abstractmethod

from haystack.nodes.base import BaseComponent
Expand All @@ -17,7 +22,7 @@ def process(
clean_header_footer: Optional[bool] = False,
clean_empty_lines: Optional[bool] = True,
remove_substrings: List[str] = [],
split_by: Optional[str] = "word",
split_by: Literal["word", "sentence", "passage", None] = "word",
split_length: Optional[int] = 1000,
split_overlap: Optional[int] = None,
split_respect_sentence_boundary: Optional[bool] = True,
Expand All @@ -44,7 +49,7 @@ def clean(
def split(
self,
document: Union[dict, Document],
split_by: str,
split_by: Literal["word", "sentence", "passage", None],
split_length: int,
split_overlap: int,
split_respect_sentence_boundary: bool,
Expand All @@ -57,7 +62,7 @@ def run( # type: ignore
clean_whitespace: Optional[bool] = None,
clean_header_footer: Optional[bool] = None,
clean_empty_lines: Optional[bool] = None,
split_by: Optional[str] = None,
split_by: Literal["word", "sentence", "passage", None] = None,
split_length: Optional[int] = None,
split_overlap: Optional[int] = None,
split_respect_sentence_boundary: Optional[bool] = None,
Expand All @@ -83,7 +88,7 @@ def run_batch( # type: ignore
clean_whitespace: Optional[bool] = None,
clean_header_footer: Optional[bool] = None,
clean_empty_lines: Optional[bool] = None,
split_by: Optional[str] = None,
split_by: Literal["word", "sentence", "passage", None] = None,
split_length: Optional[int] = None,
split_overlap: Optional[int] = None,
split_respect_sentence_boundary: Optional[bool] = None,
Expand Down
13 changes: 9 additions & 4 deletions haystack/nodes/preprocessor/preprocessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,11 @@
from functools import partial, reduce
from itertools import chain
from typing import List, Optional, Generator, Set, Union

try:
from typing import Literal
except ImportError:
from typing_extensions import Literal # type: ignore
import warnings
from pathlib import Path
from pickle import UnpicklingError
Expand Down Expand Up @@ -52,7 +57,7 @@ def __init__(
clean_header_footer: bool = False,
clean_empty_lines: bool = True,
remove_substrings: List[str] = [],
split_by: str = "word",
split_by: Literal["word", "sentence", "passage", None] = "word",
split_length: int = 200,
split_overlap: int = 0,
split_respect_sentence_boundary: bool = True,
Expand Down Expand Up @@ -124,7 +129,7 @@ def process(
clean_header_footer: Optional[bool] = None,
clean_empty_lines: Optional[bool] = None,
remove_substrings: List[str] = [],
split_by: Optional[str] = None,
split_by: Literal["word", "sentence", "passage", None] = None,
split_length: Optional[int] = None,
split_overlap: Optional[int] = None,
split_respect_sentence_boundary: Optional[bool] = None,
Expand Down Expand Up @@ -172,7 +177,7 @@ def _process_single(
clean_header_footer: Optional[bool] = None,
clean_empty_lines: Optional[bool] = None,
remove_substrings: List[str] = [],
split_by: Optional[str] = None,
split_by: Literal["word", "sentence", "passage", None] = None,
split_length: Optional[int] = None,
split_overlap: Optional[int] = None,
split_respect_sentence_boundary: Optional[bool] = None,
Expand Down Expand Up @@ -291,7 +296,7 @@ def clean(
def split(
self,
document: Union[dict, Document],
split_by: str,
split_by: Literal["word", "sentence", "passage", None],
split_length: int,
split_overlap: int,
split_respect_sentence_boundary: bool,
Expand Down

0 comments on commit 16723bf

Please sign in to comment.