Skip to content

Commit

Permalink
use pyo3::intern macro inside pretokenizer
Browse files Browse the repository at this point in the history
  • Loading branch information
mh-northlander committed Jul 2, 2024
1 parent 73c8cd9 commit 4345772
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 5 deletions.
10 changes: 7 additions & 3 deletions python/src/dictionary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -461,20 +461,24 @@ fn read_config(config_opt: &Bound<PyAny>) -> PyResult<ConfigBuilder> {
}

pub(crate) fn read_default_config(py: Python) -> PyResult<ConfigBuilder> {
let path = PyModule::import_bound(py, "sudachipy")?.getattr("_DEFAULT_SETTINGFILE")?;
let path = py
.import_bound("sudachipy")?
.getattr("_DEFAULT_SETTINGFILE")?;
let path = path.downcast::<PyString>()?.to_str()?;
let path = PathBuf::from(path);
wrap_ctx(ConfigBuilder::from_opt_file(Some(&path)), &path)
}

pub(crate) fn get_default_resource_dir(py: Python) -> PyResult<PathBuf> {
let path = PyModule::import_bound(py, "sudachipy")?.getattr("_DEFAULT_RESOURCEDIR")?;
let path = py
.import_bound("sudachipy")?
.getattr("_DEFAULT_RESOURCEDIR")?;
let path = path.downcast::<PyString>()?.to_str()?;
Ok(PathBuf::from(path))
}

fn find_dict_path(py: Python, dict_type: &str) -> PyResult<PathBuf> {
let pyfunc = PyModule::import_bound(py, "sudachipy")?.getattr("_find_dict_path")?;
let pyfunc = py.import_bound("sudachipy")?.getattr("_find_dict_path")?;
let path = pyfunc.call1((dict_type,))?;
let path = path.downcast::<PyString>()?.to_str()?;
Ok(PathBuf::from(path))
Expand Down
4 changes: 2 additions & 2 deletions python/src/pretokenizer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ impl PyPretokenizer {
py: Python<'py>,
data: &Bound<'py, PyAny>,
) -> PyResult<Bound<'py, PyAny>> {
data.call_method1("split", PyTuple::new_bound(py, [self_]))
data.call_method1(intern!(py, "split"), PyTuple::new_bound(py, [self_]))
}
}

Expand All @@ -190,7 +190,7 @@ fn make_result_for_projection<'py>(
) -> PyResult<Bound<'py, PyList>> {
let result = PyList::empty_bound(py);
let nstring = {
static NORMALIZED_STRING: GILOnceCell<Py<PyType>> = pyo3::sync::GILOnceCell::new();
static NORMALIZED_STRING: GILOnceCell<Py<PyType>> = GILOnceCell::new();
NORMALIZED_STRING.get_or_try_init(py, || -> PyResult<Py<PyType>> {
let ns = py.import_bound("tokenizers")?.getattr("NormalizedString")?;
let tpe = ns.downcast::<PyType>()?;
Expand Down

0 comments on commit 4345772

Please sign in to comment.