# Patch summary (free text ahead of the first "diff --git" header).
#
# NOTE(review): the patch text below has its line breaks collapsed into spaces
# (file headers, hunk headers and +/- lines all share a few physical lines).
# It is left byte-for-byte untouched here; the original line breaks and
# indentation would have to be restored before it could be applied.
#
# What the patch changes, per file:
#   docs/source/api.rst
#     - Lists pymaid.LandmarkMatcher and pymaid.CrossProjectLandmarkMatcher
#       after the existing landmark entries.
#   docs/source/whats_new.rst
#     - Adds an "In progress" row announcing the two matcher classes.
#   pymaid/fetch/__init__.py
#     - Imports LandmarkMatcher and CrossProjectLandmarkMatcher from .landmarks
#       and adds both names to the module's list of exported names
#       (presumably __all__; the assignment line is outside the hunk).
#   pymaid/fetch/landmarks.py
#     - Imports clean_points from ..utils.
#     - Adds docstrings to LandmarkMatcher, its __init__ and from_catmaid.
#     - Bug fix in LandmarkMatcher._locations_in_landmark: a chained assignment
#       (idx = column = landmark_id, which also overwrote the "landmark_id"
#       column) becomes an == comparison producing a mask.
#     - LandmarkMatcher.match: each row now starts with the landmark name
#       followed by the landmark ID (previously ID, then name).
#     - Renames CrossInstanceLandmarkMatcher to CrossProjectLandmarkMatcher and
#       adds docstrings to the class, __init__ and from_catmaid.
#     - match_all: pd.concat is now called with ignore_index=True.
#     - Adds to_control_points(df), returning two arrays built via clean_points
#       with the format strings "{}1" and "{}2".
#   pymaid/utils.py
#     - Adds clean_points(df, fmt, dims="xyz"), which selects one column per
#       dimension (named through a format string or a callable) and renames
#       those columns to the dimension names.
#
# Review notes (to confirm against the full files, which are not visible here):
#   - NOTE(review): to_control_points is annotated with np.ndarray, but no hunk
#     adds a numpy import to landmarks.py - confirm `np` is already imported.
#   - NOTE(review): to_control_points is not added to pymaid/fetch/__init__.py
#     or docs/source/api.rst in this patch - confirm whether it is meant to be
#     public.
#   - NOTE(review): no alias for the old name CrossInstanceLandmarkMatcher is
#     visible in these hunks - confirm nothing else still imports it.
#   - NOTE(review): the new CrossProjectLandmarkMatcher class docstring still
#     says "between two instances"; consider aligning it with the new name.
#   - NOTE(review): clean_points binds a lambda to a name
#     (fmt_c = lambda s: fmt.format(s)); fmt_c = fmt.format looks equivalent.
#
diff --git a/docs/source/api.rst b/docs/source/api.rst index f524ead..d61da50 100644 --- a/docs/source/api.rst +++ b/docs/source/api.rst @@ -147,6 +147,8 @@ Functions to fetch data about landmarks. pymaid.get_landmarks pymaid.get_landmark_groups + pymaid.LandmarkMatcher + pymaid.CrossProjectLandmarkMatcher Reconstruction samplers ----------------------- diff --git a/docs/source/whats_new.rst b/docs/source/whats_new.rst index 7f7ad76..9f5eace 100644 --- a/docs/source/whats_new.rst +++ b/docs/source/whats_new.rst @@ -10,6 +10,11 @@ What's new? * - Version - Date - + * - In progress + - n/a + - :class:`pymaid.LandmarkMatcher` and :class:`pymaid.CrossProjectLandmarkMatcher` + for matching paired landmarks for use in transformations within a project + (e.g. left-right or segmental) or between projects (e.g. different animals) * - 2.4.0 - 27/05/23 - - :func:`pymaid.get_annotation_graph` deprecated in favour of the new diff --git a/pymaid/fetch/__init__.py b/pymaid/fetch/__init__.py index e5b94e8..37b904d 100644 --- a/pymaid/fetch/__init__.py +++ b/pymaid/fetch/__init__.py @@ -54,7 +54,7 @@ from .. 
import core, utils, config, cache from navis import in_volume -from .landmarks import get_landmarks, get_landmark_groups +from .landmarks import get_landmarks, get_landmark_groups, LandmarkMatcher, CrossProjectLandmarkMatcher from .skeletons import get_skeleton_ids from .annotations import get_annotation_graph, get_entity_graph, get_annotation_id @@ -90,8 +90,8 @@ 'get_origin', 'get_skids_by_origin', 'get_sampler', 'get_sampler_domains', 'get_sampler_counts', 'get_skeleton_change', - 'get_landmarks', - 'get_landmark_groups', + 'get_landmarks', 'get_landmark_groups', + 'LandmarkMatcher', 'CrossProjectLandmarkMatcher', 'get_skeleton_ids', ] diff --git a/pymaid/fetch/landmarks.py b/pymaid/fetch/landmarks.py index 0df0b29..dd4d0e0 100644 --- a/pymaid/fetch/landmarks.py +++ b/pymaid/fetch/landmarks.py @@ -6,7 +6,7 @@ from pymaid.client import CatmaidInstance -from ..utils import _eval_remote_instance, DataFrameBuilder +from ..utils import _eval_remote_instance, DataFrameBuilder, clean_points def get_landmarks( @@ -180,7 +180,11 @@ def get_landmark_groups( class LandmarkMatcher: - """Class for finding matching pairs of landmark locations between two groups.""" + """Class for finding matching pairs of landmark locations between two groups. + + For example, find control points for transforming neurons left to right + or between segments. + """ def __init__( self, landmarks: pd.DataFrame, @@ -189,6 +193,26 @@ def __init__( group_locations: pd.DataFrame, group_members: dict[int, tp.Iterable[int]], ): + """Prefer constructing with ``.from_catmaid()`` where possible. + + Parameters + ---------- + landmarks : pd.DataFrame + Landmarks dataframe: + see first output of ``get_landmarks`` for details. + landmark_locations : pd.DataFrame + Landmark locations dataframe: + see second (optional) output of ``get_landmarks``. + groups : pd.DataFrame + Groups dataframe: + see first output of ``get_landmark_groups`` for details. 
+ group_locations : pd.DataFrame + Group locations dataframe: + see second (optional) output of ``get_landmark_groups`` for details. + group_members : dict[int, tp.Iterable[int]] + Group members: + see third (optional) output of ``get_landmark_groups`` for details. + """ self.landmarks = landmarks self.landmark_locations = landmark_locations self.groups = groups @@ -199,6 +223,15 @@ def __init__( @classmethod def from_catmaid(cls, remote_instance=None): + """Instantiate from a CatmaidInstance. + + Possibly the global one. + + Parameters + ---------- + remote_instance : CatmaidInstance, optional + If None (default) use the global instance. + """ cm = _eval_remote_instance(remote_instance) landmarks, landmark_locations = get_landmarks(True, cm) groups, group_locations, members = get_landmark_groups(True, True, remote_instance=cm) @@ -213,7 +246,7 @@ def _locations_in_group(self, group_id: int): return self.group_locations["location_id"][idx] def _locations_in_landmark(self, landmark_id: int): - idx = self.landmark_locations["landmark_id"] = landmark_id + idx = self.landmark_locations["landmark_id"] == landmark_id return self.landmark_locations["location_id"][idx] def _unique_location(self, group_id: int, landmark_id: int) -> int: @@ -293,7 +326,7 @@ def match(self, group1: tp.Union[str, int], group2: tp.Union[str, int]) -> pd.Da if not len(g1_locs) == 1: continue - row = [lm_id, lm_id_to_name[lm_id]] + row = [lm_id_to_name[lm_id], lm_id] row.append(g1_locs.pop()) row.extend(locs[row[-1]]) row.append(g2_locs.pop()) @@ -305,13 +338,39 @@ def match(self, group1: tp.Union[str, int], group2: tp.Union[str, int]) -> pd.Da return df.astype(dtypes) -class CrossInstanceLandmarkMatcher: +class CrossProjectLandmarkMatcher: + """Class for finding matching pairs of landmark locations between two instances. + + For example, find control points for transforming neurons + from one instance's space to another. 
+ """ def __init__(self, this_lms: LandmarkMatcher, other_lms: LandmarkMatcher): + """Constructing using ``.from_catmaid()`` may be more convenient. + + Parameters + ---------- + this_lms : LandmarkMatcher + ``LandmarkMatcher`` for landmarks in "this" space. + other_lms : LandmarkMatcher + ``LandmarkMatcher`` for landmarks in the "other" space. + """ self.this_m: LandmarkMatcher = this_lms self.other_m: LandmarkMatcher = other_lms @classmethod - def from_catmaid(cls, other_remote_instance: CatmaidInstance, this_remote_instance=None): + def from_catmaid( + cls, other_remote_instance: CatmaidInstance, this_remote_instance=None + ): + """Instantiate from a pair of CatmaidInstances. + + Parameters + ---------- + other_remote_instance : CatmaidInstance + Other catmaid instance + this_remote_instance : CatmaidInstance, optional + This CATMAID instance. + If None (default), use the global instance. + """ this_remote_instance = _eval_remote_instance(this_remote_instance) return cls( LandmarkMatcher.from_catmaid(this_remote_instance), @@ -342,7 +401,7 @@ def match( Group name (str) or ID (int) on this instance. other_group : tp.Optional[tp.Union[str, int]], optional Group name (str) or ID (int) on the other instance. - If None (default) and ``this_group`` is a str name, use that name. + If None (default) and ``this_group`` is a str name, use the same name. Returns ------- @@ -411,4 +470,20 @@ def match_all(self) -> pd.DataFrame: df = self.match(group_name) df.insert(0, "group_name", [group_name] * len(df)) dfs.append(df) - return pd.concat(dfs) + return pd.concat(dfs, ignore_index=True) + + +def to_control_points(df: pd.DataFrame) -> tuple[np.ndarray, np.ndarray]: + """Get control point arrays from a dataframe returned by a LandmarkMatcher. + + Parameters + ---------- + df : pd.DataFrame + DataFrame returned by a LandmarkMatcher or CrossProjectLandmarkMatcher. + + Returns + ------- + tuple[np.ndarray, np.ndarray] + The coordinates of locations. 
+ """ + return clean_points(df, "{}1").to_numpy(), clean_points(df, "{}2").to_numpy() diff --git a/pymaid/utils.py b/pymaid/utils.py index 1bcb7c9..a0218d5 100644 --- a/pymaid/utils.py +++ b/pymaid/utils.py @@ -809,3 +809,35 @@ def build(self, index_col=None) -> pd.DataFrame: df.index = index return df + + +def clean_points( + df: pd.DataFrame, fmt: tp.Union[str, tp.Callable[[str], tp.Hashable]], dims="xyz" +) -> pd.DataFrame: + """Extract points from a dataframe. + + Parameters + ---------- + df : pd.DataFrame + Dataframe, some of whose columns represent points. + fmt : tp.Union[str, tp.Callable[[str], tp.Hashable]] + Either a format string (e.g. ``"point_{}_1"``), + or a callable which takes a string and returns a column name. + When a dimension name (like ``"x"``) is passed to the format string, + or the callable, the result should be the name of a column in ``df``. + dims : str, optional + Dimension name order, by default "xyz" + + Returns + ------- + pd.DataFrame + The column index will be the dimensions given in ``dims``. + Call ``.to_numpy()`` to convert into a numpy array. + """ + if isinstance(fmt, str): + fmt_c = lambda s: fmt.format(s) + else: + fmt_c = fmt + + cols = {fmt_c(d): d for d in dims} + return df[list(cols)].rename(columns=cols)