Add pickling support for Python tokenizers (#73)
rth committed Jun 12, 2020
1 parent 025468e commit 172838c
Showing 11 changed files with 210 additions and 70 deletions.
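In user-facing terms, this commit gives each Python wrapper class `__getstate__`/`__setstate__` methods, so instances can round-trip through `pickle` (and therefore be handed to worker processes by tools that rely on pickling). A minimal sketch of the intended round trip, using the `VTextTokenizer(lang="en")` signature documented in the diff below; the sample sentence and the equality check are illustrative and not taken from the commit's own tests:

    import pickle

    from vtext.tokenize import VTextTokenizer

    tokenizer = VTextTokenizer(lang="en")
    payload = pickle.dumps(tokenizer)    # serialized via __getstate__
    restored = pickle.loads(payload)     # rebuilt via __setstate__

    # The restored tokenizer is expected to behave exactly like the original.
    text = "Flights can't depart on time."
    assert restored.tokenize(text) == tokenizer.tokenize(text)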
3 changes: 2 additions & 1 deletion Cargo.toml
@@ -12,7 +12,7 @@ keywords = [
"tokenization",
"tfidf",
"levenshtein",
"matching"
"text-processing"
]
edition = "2018"
exclude = [
@@ -41,6 +41,7 @@ lazy_static = "1.4.0"
seahash = "4.0.0"
itertools = "0.8"
ndarray = "0.13.0"
serde = { version = "1.0", features = ["derive"] }
sprs = {version = "0.7.1", default-features = false}
unicode-segmentation = "1.6.0"
hashbrown = { version = "0.7", features = ["rayon"] }
2 changes: 2 additions & 0 deletions python/Cargo.toml
@@ -9,10 +9,12 @@ crate-type = ["cdylib"]

[dependencies]
ndarray = "0.13"
serde = { version = "1.0", features = ["derive"] }
sprs = {version = "0.7.1", default-features = false}
vtext = {"path" = "../", features=["python", "rayon"]}
rust-stemmers = "1.1"
rayon = "1.3"
bincode = "1.2.1"

[dependencies.numpy]
version = "0.9.0"
1 change: 1 addition & 0 deletions python/src/lib.rs
@@ -18,6 +18,7 @@ use pyo3::wrap_pyfunction;
mod stem;
mod tokenize;
mod tokenize_sentence;
mod utils;
mod vectorize;

use vtext::metrics;
70 changes: 46 additions & 24 deletions python/src/stem.rs
@@ -4,8 +4,10 @@
// <http://apache.org/licenses/LICENSE-2.0>. This file may not be copied,
// modified, or distributed except according to those terms.

use crate::utils::{deserialize_params, serialize_params};
use pyo3::exceptions;
use pyo3::prelude::*;
use pyo3::types::PyDict;

/// __init__(self, lang='english')
///
@@ -14,40 +16,43 @@ use pyo3::prelude::*;
/// Wraps the rust-stemmers crate that uses an implementation generated
/// by the `Snowball compiler <https://github.com/snowballstem/snowball>`_
/// for Rust.
#[pyclass]
#[pyclass(module = "vtext.stem")]
pub struct SnowballStemmer {
pub lang: String,
inner: rust_stemmers::Stemmer,
}

fn get_algorithm(lang: &str) -> PyResult<rust_stemmers::Algorithm> {
match lang {
"arabic" => Ok(rust_stemmers::Algorithm::Arabic),
"danish" => Ok(rust_stemmers::Algorithm::Danish),
"dutch" => Ok(rust_stemmers::Algorithm::Dutch),
"english" => Ok(rust_stemmers::Algorithm::English),
"french" => Ok(rust_stemmers::Algorithm::French),
"german" => Ok(rust_stemmers::Algorithm::German),
"greek" => Ok(rust_stemmers::Algorithm::Greek),
"hungarian" => Ok(rust_stemmers::Algorithm::Hungarian),
"italian" => Ok(rust_stemmers::Algorithm::Italian),
"portuguese" => Ok(rust_stemmers::Algorithm::Portuguese),
"romanian" => Ok(rust_stemmers::Algorithm::Romanian),
"russian" => Ok(rust_stemmers::Algorithm::Russian),
"spanish" => Ok(rust_stemmers::Algorithm::Spanish),
"swedish" => Ok(rust_stemmers::Algorithm::Swedish),
"tamil" => Ok(rust_stemmers::Algorithm::Tamil),
"turkish" => Ok(rust_stemmers::Algorithm::Turkish),
_ => Err(exceptions::ValueError::py_err(format!(
"lang={} is unsupported!",
lang
))),
}
}

#[pymethods]
impl SnowballStemmer {
#[new]
#[args(lang = "\"english\"")]
fn new(lang: &str) -> PyResult<Self> {
let algorithm = match lang {
"arabic" => Ok(rust_stemmers::Algorithm::Arabic),
"danish" => Ok(rust_stemmers::Algorithm::Danish),
"dutch" => Ok(rust_stemmers::Algorithm::Dutch),
"english" => Ok(rust_stemmers::Algorithm::English),
"french" => Ok(rust_stemmers::Algorithm::French),
"german" => Ok(rust_stemmers::Algorithm::German),
"greek" => Ok(rust_stemmers::Algorithm::Greek),
"hungarian" => Ok(rust_stemmers::Algorithm::Hungarian),
"italian" => Ok(rust_stemmers::Algorithm::Italian),
"portuguese" => Ok(rust_stemmers::Algorithm::Portuguese),
"romanian" => Ok(rust_stemmers::Algorithm::Romanian),
"russian" => Ok(rust_stemmers::Algorithm::Russian),
"spanish" => Ok(rust_stemmers::Algorithm::Spanish),
"swedish" => Ok(rust_stemmers::Algorithm::Swedish),
"tamil" => Ok(rust_stemmers::Algorithm::Tamil),
"turkish" => Ok(rust_stemmers::Algorithm::Turkish),
_ => Err(exceptions::ValueError::py_err(format!(
"lang={} is unsupported!",
lang
))),
}?;

let algorithm = get_algorithm(lang)?;
let stemmer = rust_stemmers::Stemmer::create(algorithm);

Ok(SnowballStemmer {
@@ -73,4 +78,21 @@ impl SnowballStemmer {
let res = self.inner.stem(word).to_string();
Ok(res)
}

fn get_params<'p>(&self, py: Python<'p>) -> PyResult<&'p PyDict> {
let params = PyDict::new(py);
params.set_item("lang", self.lang.clone())?;
Ok(params)
}

pub fn __getstate__(&self, py: Python) -> PyResult<PyObject> {
serialize_params(&self.lang, py)
}

pub fn __setstate__(&mut self, py: Python, state: PyObject) -> PyResult<()> {
self.lang = deserialize_params(py, state)?;
let algorithm = get_algorithm(&self.lang)?;
self.inner = rust_stemmers::Stemmer::create(algorithm);
Ok(())
}
}
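For `SnowballStemmer`, the pickled state is just the `lang` string, so a pickled stemmer comes back with the same language, and unsupported languages are still rejected (now by the shared `get_algorithm` helper). A hedged usage sketch; the `get_params()` return value is read off the diff above, and "klingon" is simply an arbitrary unsupported value:

    import pickle

    from vtext.stem import SnowballStemmer

    stemmer = SnowballStemmer(lang="french")
    restored = pickle.loads(pickle.dumps(stemmer))

    # get_params() is defined above and returns a dict holding the lang parameter.
    assert restored.get_params() == {"lang": "french"}

    # Unsupported languages raise ValueError from get_algorithm().
    try:
        SnowballStemmer(lang="klingon")
    except ValueError as exc:
        print(exc)  # lang=klingon is unsupported!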
84 changes: 50 additions & 34 deletions python/src/tokenize.rs
@@ -7,9 +7,10 @@
use pyo3::prelude::*;
use pyo3::types::PyList;

use crate::utils::{deserialize_params, serialize_params};
use vtext::tokenize::*;

#[pyclass]
#[pyclass(module = "vtext.tokenize")]
pub struct BaseTokenizer {}

#[pymethods]
@@ -31,9 +32,8 @@ impl BaseTokenizer {
/// References
/// ----------
/// - `Unicode® Standard Annex #29 <http://www.unicode.org/reports/tr29/>`_
#[pyclass(extends=BaseTokenizer)]
#[pyclass(extends=BaseTokenizer, module="vtext.tokenize")]
pub struct UnicodeSegmentTokenizer {
pub word_bounds: bool,
inner: vtext::tokenize::UnicodeSegmentTokenizer,
}

@@ -48,10 +48,7 @@ impl UnicodeSegmentTokenizer {
.unwrap();

(
UnicodeSegmentTokenizer {
word_bounds: word_bounds,
inner: tokenizer,
},
UnicodeSegmentTokenizer { inner: tokenizer },
BaseTokenizer::new(),
)
}
@@ -86,6 +83,16 @@
fn get_params(&self) -> PyResult<UnicodeSegmentTokenizerParams> {
Ok(self.inner.params.clone())
}

pub fn __getstate__(&self, py: Python) -> PyResult<PyObject> {
serialize_params(&self.inner.params, py)
}

pub fn __setstate__(&mut self, py: Python, state: PyObject) -> PyResult<()> {
let mut params: UnicodeSegmentTokenizerParams = deserialize_params(py, state)?;
self.inner = params.build().unwrap();
Ok(())
}
}

/// __init__(self, lang="en")
@@ -104,9 +111,8 @@ impl UnicodeSegmentTokenizer {
/// ----------
///
/// - `Unicode® Standard Annex #29 <http://www.unicode.org/reports/tr29/>`_
#[pyclass(extends=BaseTokenizer)]
#[pyclass(extends=BaseTokenizer, module="vtext.tokenize")]
pub struct VTextTokenizer {
pub lang: String,
inner: vtext::tokenize::VTextTokenizer,
}

@@ -120,13 +126,7 @@ impl VTextTokenizer {
.build()
.unwrap();

(
VTextTokenizer {
lang: lang.to_string(),
inner: tokenizer,
},
BaseTokenizer::new(),
)
(VTextTokenizer { inner: tokenizer }, BaseTokenizer::new())
}

/// tokenize(self, x)
@@ -159,14 +159,23 @@ impl VTextTokenizer {
fn get_params(&self) -> PyResult<VTextTokenizerParams> {
Ok(self.inner.params.clone())
}

pub fn __getstate__(&self, py: Python) -> PyResult<PyObject> {
serialize_params(&self.inner.params, py)
}

pub fn __setstate__(&mut self, py: Python, state: PyObject) -> PyResult<()> {
let mut params: VTextTokenizerParams = deserialize_params(py, state)?;
self.inner = params.build().unwrap();
Ok(())
}
}

/// __init__(self, pattern=r'\\b\\w\\w+\\b')
///
/// Tokenize a document using regular expressions
#[pyclass(extends=BaseTokenizer)]
#[pyclass(extends=BaseTokenizer, module="vtext.tokenize")]
pub struct RegexpTokenizer {
pub pattern: String,
inner: vtext::tokenize::RegexpTokenizer,
}

@@ -180,13 +189,7 @@ impl RegexpTokenizer {
.build()
.unwrap();

(
RegexpTokenizer {
pattern: pattern.to_string(),
inner: inner,
},
BaseTokenizer::new(),
)
(RegexpTokenizer { inner: inner }, BaseTokenizer::new())
}

/// tokenize(self, x)
@@ -219,6 +222,16 @@ impl RegexpTokenizer {
fn get_params(&self) -> PyResult<RegexpTokenizerParams> {
Ok(self.inner.params.clone())
}

pub fn __getstate__(&self, py: Python) -> PyResult<PyObject> {
serialize_params(&self.inner.params, py)
}

pub fn __setstate__(&mut self, py: Python, state: PyObject) -> PyResult<()> {
let mut params: RegexpTokenizerParams = deserialize_params(py, state)?;
self.inner = params.build().unwrap();
Ok(())
}
}

/// __init__(self, window_size=4)
@@ -237,9 +250,8 @@ impl RegexpTokenizer {
/// >>> tokenizer.tokenize('fox can\'t')
/// ['fox ', 'ox c', 'x ca', ' can', 'can\'', 'an\'t']
///
#[pyclass(extends=BaseTokenizer)]
#[pyclass(extends=BaseTokenizer, module="vtext.tokenize")]
pub struct CharacterTokenizer {
pub window_size: usize,
inner: vtext::tokenize::CharacterTokenizer,
}

@@ -253,13 +265,7 @@ impl CharacterTokenizer {
.build()
.unwrap();

(
CharacterTokenizer {
window_size: window_size,
inner: inner,
},
BaseTokenizer::new(),
)
(CharacterTokenizer { inner: inner }, BaseTokenizer::new())
}

/// tokenize(self, x)
@@ -292,4 +298,14 @@ impl CharacterTokenizer {
fn get_params(&self) -> PyResult<CharacterTokenizerParams> {
Ok(self.inner.params.clone())
}

pub fn __getstate__(&self, py: Python) -> PyResult<PyObject> {
serialize_params(&self.inner.params, py)
}

pub fn __setstate__(&mut self, py: Python, state: PyObject) -> PyResult<()> {
let mut params: CharacterTokenizerParams = deserialize_params(py, state)?;
self.inner = params.build().unwrap();
Ok(())
}
}
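All four word-level tokenizers now delegate pickling to their params structs, so whatever configuration an instance was built with (pattern, window_size, lang, word_bounds) is what gets serialized and restored. A sketch assuming the constructor signatures shown in the docstrings above; the sample string and the equality check are illustrative:

    import pickle

    from vtext.tokenize import CharacterTokenizer, RegexpTokenizer

    for tok in (RegexpTokenizer(pattern=r"\b\w\w+\b"), CharacterTokenizer(window_size=4)):
        clone = pickle.loads(pickle.dumps(tok))
        # The configuration travels with the object, so outputs should match.
        text = "fox can't jump 32.3 feet"
        assert clone.tokenize(text) == tok.tokenize(text)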
25 changes: 23 additions & 2 deletions python/src/tokenize_sentence.rs
@@ -10,6 +10,7 @@ use pyo3::types::PyList;
use vtext::tokenize::Tokenizer;
use vtext::tokenize_sentence::*;

use crate::utils::{deserialize_params, serialize_params};
// macro located `vtext::tokenize_sentence::vecString`
use vtext::vecString;

@@ -24,7 +25,7 @@ use vtext::vecString;
/// References
/// ----------
/// - `Unicode® Standard Annex #29 <http://www.unicode.org/reports/tr29/>`_
#[pyclass(extends=BaseTokenizer)]
#[pyclass(extends=BaseTokenizer, module="vtext.tokenize_sentence")]
pub struct UnicodeSentenceTokenizer {
inner: vtext::tokenize_sentence::UnicodeSentenceTokenizer,
}
@@ -73,6 +74,16 @@
fn get_params(&self) -> PyResult<UnicodeSentenceTokenizerParams> {
Ok(self.inner.params.clone())
}

pub fn __getstate__(&self, py: Python) -> PyResult<PyObject> {
serialize_params(&self.inner.params, py)
}

pub fn __setstate__(&mut self, py: Python, state: PyObject) -> PyResult<()> {
let mut params: UnicodeSentenceTokenizerParams = deserialize_params(py, state)?;
self.inner = params.build().unwrap();
Ok(())
}
}

/// __init__(self, punctuation=[".", "?", "!"])
@@ -88,7 +99,7 @@ impl UnicodeSentenceTokenizer {
/// Punctuation tokens used to determine boundaries. Only the first unicode "character" is used.
///
///
#[pyclass(extends=BaseTokenizer)]
#[pyclass(extends=BaseTokenizer, module="vtext.tokenize_sentence")]
pub struct PunctuationTokenizer {
inner: vtext::tokenize_sentence::PunctuationTokenizer,
}
@@ -139,4 +150,14 @@ impl PunctuationTokenizer {
fn get_params(&self) -> PyResult<PunctuationTokenizerParams> {
Ok(self.inner.params.clone())
}

pub fn __getstate__(&self, py: Python) -> PyResult<PyObject> {
serialize_params(&self.inner.params, py)
}

pub fn __setstate__(&mut self, py: Python, state: PyObject) -> PyResult<()> {
let mut params: PunctuationTokenizerParams = deserialize_params(py, state)?;
self.inner = params.build().unwrap();
Ok(())
}
}
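The sentence tokenizers follow the same pattern. A sketch with `PunctuationTokenizer`; `copy.deepcopy` is expected to reuse the same `__getstate__`/`__setstate__` machinery via `__reduce_ex__`, though that is an assumption rather than something this diff shows:

    import copy
    import pickle

    from vtext.tokenize_sentence import PunctuationTokenizer

    splitter = PunctuationTokenizer(punctuation=[".", "?", "!"])
    restored = pickle.loads(pickle.dumps(splitter))
    copied = copy.deepcopy(splitter)  # assumed to go through the same state hooks

    text = "Here is one sentence. Here is another!"
    assert restored.tokenize(text) == splitter.tokenize(text)
    assert copied.tokenize(text) == splitter.tokenize(text)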