Skip to content

Commit

Permalink
Document landmark matchers
Browse files Browse the repository at this point in the history
  • Loading branch information
clbarnes committed Jun 13, 2023
1 parent 59c0e3a commit 39c3c71
Show file tree
Hide file tree
Showing 5 changed files with 125 additions and 11 deletions.
2 changes: 2 additions & 0 deletions docs/source/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,8 @@ Functions to fetch data about landmarks.

pymaid.get_landmarks
pymaid.get_landmark_groups
pymaid.LandmarkMatcher
pymaid.CrossProjectLandmarkMatcher

Reconstruction samplers
-----------------------
Expand Down
5 changes: 5 additions & 0 deletions docs/source/whats_new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,11 @@ What's new?
* - Version
- Date
-
* - In progress
- n/a
- :class:`pymaid.LandmarkMatcher` and :class:`pymaid.CrossProjectLandmarkMatcher`
for matching paired landmarks for use in transformations within a project
(e.g. left-right or segmental) or between projects (e.g. different animals)
* - 2.4.0
- 27/05/23
- - :func:`pymaid.get_annotation_graph` deprecated in favour of the new
Expand Down
6 changes: 3 additions & 3 deletions pymaid/fetch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@

from .. import core, utils, config, cache
from navis import in_volume
from .landmarks import get_landmarks, get_landmark_groups
from .landmarks import get_landmarks, get_landmark_groups, LandmarkMatcher, CrossProjectLandmarkMatcher
from .skeletons import get_skeleton_ids
from .annotations import get_annotation_graph, get_entity_graph, get_annotation_id

Expand Down Expand Up @@ -90,8 +90,8 @@
'get_origin', 'get_skids_by_origin',
'get_sampler', 'get_sampler_domains', 'get_sampler_counts',
'get_skeleton_change',
'get_landmarks',
'get_landmark_groups',
'get_landmarks', 'get_landmark_groups',
'LandmarkMatcher', 'CrossProjectLandmarkMatcher',
'get_skeleton_ids',
]

Expand Down
91 changes: 83 additions & 8 deletions pymaid/fetch/landmarks.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from pymaid.client import CatmaidInstance

from ..utils import _eval_remote_instance, DataFrameBuilder
from ..utils import _eval_remote_instance, DataFrameBuilder, clean_points


def get_landmarks(
Expand Down Expand Up @@ -180,7 +180,11 @@ def get_landmark_groups(


class LandmarkMatcher:
"""Class for finding matching pairs of landmark locations between two groups."""
"""Class for finding matching pairs of landmark locations between two groups.
For example, find control points for transforming neurons left to right
or between segments.
"""
def __init__(
self,
landmarks: pd.DataFrame,
Expand All @@ -189,6 +193,26 @@ def __init__(
group_locations: pd.DataFrame,
group_members: dict[int, tp.Iterable[int]],
):
"""Prefer constructing with ``.from_catmaid()`` where possible.
Parameters
----------
landmarks : pd.DataFrame
Landmarks dataframe:
see first output of ``get_landmarks`` for details.
landmark_locations : pd.DataFrame
Landmark locations dataframe:
see second (optional) output of ``get_landmarks``.
groups : pd.DataFrame
Groups dataframe:
see first output of ``get_landmark_groups`` for details.
group_locations : pd.DataFrame
Group locations dataframe:
see second (optional) output of ``get_landmark_groups`` for details.
group_members : dict[int, tp.Iterable[int]]
Group members:
see third (optional) output of ``get_landmark_groups`` for details.
"""
self.landmarks = landmarks
self.landmark_locations = landmark_locations
self.groups = groups
Expand All @@ -199,6 +223,15 @@ def __init__(

@classmethod
def from_catmaid(cls, remote_instance=None):
"""Instantiate from a CatmaidInstance.
Possibly the global one.
Parameters
----------
remote_instance : CatmaidInstance, optional
If None (default) use the global instance.
"""
cm = _eval_remote_instance(remote_instance)
landmarks, landmark_locations = get_landmarks(True, cm)
groups, group_locations, members = get_landmark_groups(True, True, remote_instance=cm)
Expand All @@ -213,7 +246,7 @@ def _locations_in_group(self, group_id: int):
return self.group_locations["location_id"][idx]

def _locations_in_landmark(self, landmark_id: int):
idx = self.landmark_locations["landmark_id"] = landmark_id
idx = self.landmark_locations["landmark_id"] == landmark_id
return self.landmark_locations["location_id"][idx]

def _unique_location(self, group_id: int, landmark_id: int) -> int:
Expand Down Expand Up @@ -293,7 +326,7 @@ def match(self, group1: tp.Union[str, int], group2: tp.Union[str, int]) -> pd.Da
if not len(g1_locs) == 1:
continue

row = [lm_id, lm_id_to_name[lm_id]]
row = [lm_id_to_name[lm_id], lm_id]
row.append(g1_locs.pop())
row.extend(locs[row[-1]])
row.append(g2_locs.pop())
Expand All @@ -305,13 +338,39 @@ def match(self, group1: tp.Union[str, int], group2: tp.Union[str, int]) -> pd.Da
return df.astype(dtypes)


class CrossInstanceLandmarkMatcher:
class CrossProjectLandmarkMatcher:
"""Class for finding matching pairs of landmark locations between two instances.
For example, find control points for transforming neurons
from one instance's space to another.
"""
def __init__(self, this_lms: LandmarkMatcher, other_lms: LandmarkMatcher):
"""Constructing using ``.from_catmaid()`` may be more convenient.
Parameters
----------
this_lms : LandmarkMatcher
``LandmarkMatcher`` for landmarks in "this" space.
other_lms : LandmarkMatcher
``LandmarkMatcher`` for landmarks in the "other" space.
"""
self.this_m: LandmarkMatcher = this_lms
self.other_m: LandmarkMatcher = other_lms

@classmethod
def from_catmaid(cls, other_remote_instance: CatmaidInstance, this_remote_instance=None):
def from_catmaid(
cls, other_remote_instance: CatmaidInstance, this_remote_instance=None
):
"""Instantiate from a pair of CatmaidInstances.
Parameters
----------
other_remote_instance : CatmaidInstance
Other catmaid instance
this_remote_instance : CatmaidInstance, optional
This CATMAID instance.
If None (default), use the global instance.
"""
this_remote_instance = _eval_remote_instance(this_remote_instance)
return cls(
LandmarkMatcher.from_catmaid(this_remote_instance),
Expand Down Expand Up @@ -342,7 +401,7 @@ def match(
Group name (str) or ID (int) on this instance.
other_group : tp.Optional[tp.Union[str, int]], optional
Group name (str) or ID (int) on the other instance.
If None (default) and ``this_group`` is a str name, use that name.
If None (default) and ``this_group`` is a str name, use the same name.
Returns
-------
Expand Down Expand Up @@ -411,4 +470,20 @@ def match_all(self) -> pd.DataFrame:
df = self.match(group_name)
df.insert(0, "group_name", [group_name] * len(df))
dfs.append(df)
return pd.concat(dfs)
return pd.concat(dfs, ignore_index=True)


def to_control_points(df: pd.DataFrame) -> tuple[np.ndarray, np.ndarray]:
"""Get control point arrays from a dataframe returned by a LandmarkMatcher.
Parameters
----------
df : pd.DataFrame
DataFrame returned by a LandmarkMatcher or CrossProjectLandmarkMatcher.
Returns
-------
tuple[np.ndarray, np.ndarray]
The coordinates of locations.
"""
return clean_points(df, "{}1").to_numpy(), clean_points(df, "{}2").to_numpy()
32 changes: 32 additions & 0 deletions pymaid/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -809,3 +809,35 @@ def build(self, index_col=None) -> pd.DataFrame:
df.index = index

return df


def clean_points(
df: pd.DataFrame, fmt: tp.Union[str, tp.Callable[[str], tp.Hashable]], dims="xyz"
) -> pd.DataFrame:
"""Extract points from a dataframe.
Parameters
----------
df : pd.DataFrame
Dataframe, some of whose columns represent points.
fmt : tp.Union[str, tp.Callable[[str], tp.Hashable]]
Either a format string (e.g. ``"point_{}_1"``),
or a callable which takes a string and returns a column name.
When a dimension name (like ``"x"``) is passed to the format string,
or the callable, the result should be the name of a column in ``df``.
dims : str, optional
Dimension name order, by default "xyz"
Returns
-------
pd.DataFrame
The column index will be the dimensions given in ``dims``.
Call ``.to_numpy()`` to convert into a numpy array.
"""
if isinstance(fmt, str):
fmt_c = lambda s: fmt.format(s)
else:
fmt_c = fmt

cols = {fmt_c(d): d for d in dims}
return df[list(cols)].rename(columns=cols)

0 comments on commit 39c3c71

Please sign in to comment.