Skip to content

Commit

Permalink
Merge pull request #94 from bobmyhill/visual_center
Browse files Browse the repository at this point in the history
Added option to plot at visual center of fields
  • Loading branch information
morganjwilliams committed Oct 16, 2023
2 parents 21bdc12 + 319bdc0 commit c3e822b
Show file tree
Hide file tree
Showing 4 changed files with 252 additions and 28 deletions.
23 changes: 21 additions & 2 deletions pyrolite/util/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
update_docstring_references,
)
from .plot.axes import init_axes
from .plot.helpers import get_centroid
from .plot.helpers import get_centroid, get_visual_center
from .plot.style import patchkwargs
from .plot.transform import tlr_to_xy

Expand Down Expand Up @@ -395,6 +395,7 @@ def add_to_axes(
axes_scale=100.0,
add_labels=False,
which_labels="ID",
label_at_centroid=True,
**kwargs
):
"""
Expand All @@ -413,6 +414,9 @@ def add_to_axes(
which_labels : :class:`str`
Which labels to add to the polygons (e.g. for TAS, 'volcanic', 'intrusive'
or the field 'ID').
label_at_centroid : :class:`bool`
Whether to label the fields at the centroid (True) or at the visual
center of the field (False).
Returns
--------
Expand All @@ -424,6 +428,18 @@ def add_to_axes(
ax = self._add_polygons_to_axes(
ax=ax, fill=fill, axes_scale=axes_scale, add_labels=False, **kwargs
)

if not label_at_centroid:
# Calculate the effective vertical exaggeration that
# produces nice positioning of labels. The true vertical
# exaggeration is increased by a scale_factor because
# the text labels are typically wider than they are long,
# so we want to promote the labels
# being placed at the widest part of the field.
scale_factor = 1.5
p = ax.transData.transform([[0., 0.], [1., 1.]])
yx_scaling = (p[1][1] - p[0][1])/(p[1][0] - p[0][0])*scale_factor

rescale_by = 1.0
if axes_scale is not None: # rescale polygons to fit ax
if not np.isclose(self.default_scale, axes_scale):
Expand All @@ -449,7 +465,10 @@ def add_to_axes(
)
verts = np.array(_read_poly(cfg["poly"])) * rescale_by
_poly = matplotlib.patches.Polygon(verts)
x, y = get_centroid(_poly)
if label_at_centroid:
x, y = get_centroid(_poly)
else:
x, y = get_visual_center(_poly, yx_scaling)
ax.annotate(
"\n".join(label.split()),
xy=(x, y),
Expand Down
175 changes: 175 additions & 0 deletions pyrolite/util/plot/center.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,175 @@
"""
Functions for calculating the visual center of a polygon.
Taken from https://github.com/Twista/python-polylabel,
Originally released under an MIT licence.
"""
from math import sqrt, inf
import time
from queue import PriorityQueue


def _point_to_polygon_distance(x, y, polygon):
inside = False
min_dist_sq = inf

for ring in polygon:
b = ring[-1]
for a in ring:

if ((a[1] > y) != (b[1] > y) and
(x < (b[0] - a[0]) * (y - a[1]) / (b[1] - a[1]) + a[0])):
inside = not inside

min_dist_sq = min(min_dist_sq, _get_seg_dist_sq(x, y, a, b))
b = a

result = sqrt(min_dist_sq)
if not inside:
return -result
return result


def _get_seg_dist_sq(px, py, a, b):
x = a[0]
y = a[1]
dx = b[0] - x
dy = b[1] - y

if dx != 0 or dy != 0:
t = ((px - x) * dx + (py - y) * dy) / (dx * dx + dy * dy)

if t > 1:
x = b[0]
y = b[1]

elif t > 0:
x += dx * t
y += dy * t

dx = px - x
dy = py - y

return dx * dx + dy * dy


class Cell(object):
def __init__(self, x, y, h, polygon):
self.h = h
self.y = y
self.x = x
self.d = _point_to_polygon_distance(x, y, polygon)
self.max = self.d + self.h * sqrt(2)

def __lt__(self, other):
return self.max < other.max

def __lte__(self, other):
return self.max <= other.max

def __gt__(self, other):
return self.max > other.max

def __gte__(self, other):
return self.max >= other.max

def __eq__(self, other):
return self.max == other.max


def _get_centroid_cell(polygon):
area = 0
x = 0
y = 0
points = polygon[0]
b = points[-1] # prev
for a in points:
f = a[0] * b[1] - b[0] * a[1]
x += (a[0] + b[0]) * f
y += (a[1] + b[1]) * f
area += f * 3
b = a
if area == 0:
return Cell(points[0][0], points[0][1], 0, polygon)
return Cell(x / area, y / area, 0, polygon)

pass


def visual_center(polygon, precision=1.0, debug=False, with_distance=False):
# find bounding box
first_item = polygon[0][0]
min_x = first_item[0]
min_y = first_item[1]
max_x = first_item[0]
max_y = first_item[1]
for p in polygon[0]:
if p[0] < min_x:
min_x = p[0]
if p[1] < min_y:
min_y = p[1]
if p[0] > max_x:
max_x = p[0]
if p[1] > max_y:
max_y = p[1]

width = max_x - min_x
height = max_y - min_y
cell_size = min(width, height)
h = cell_size / 2.0

cell_queue = PriorityQueue()

if cell_size == 0:
if with_distance:
return [min_x, min_y], None
else:
return [min_x, min_y]

# cover polygon with initial cells
x = min_x
while x < max_x:
y = min_y
while y < max_y:
c = Cell(x + h, y + h, h, polygon)
y += cell_size
cell_queue.put((-c.max, time.time(), c))
x += cell_size

best_cell = _get_centroid_cell(polygon)

bbox_cell = Cell(min_x + width / 2, min_y + height / 2, 0, polygon)
if bbox_cell.d > best_cell.d:
best_cell = bbox_cell

num_of_probes = cell_queue.qsize()
while not cell_queue.empty():
_, __, cell = cell_queue.get()

if cell.d > best_cell.d:
best_cell = cell

if debug:
print('found best {} after {} probes'.format(
round(1e4 * cell.d) / 1e4, num_of_probes))

if cell.max - best_cell.d <= precision:
continue

h = cell.h / 2
c = Cell(cell.x - h, cell.y - h, h, polygon)
cell_queue.put((-c.max, time.time(), c))
c = Cell(cell.x + h, cell.y - h, h, polygon)
cell_queue.put((-c.max, time.time(), c))
c = Cell(cell.x - h, cell.y + h, h, polygon)
cell_queue.put((-c.max, time.time(), c))
c = Cell(cell.x + h, cell.y + h, h, polygon)
cell_queue.put((-c.max, time.time(), c))
num_of_probes += 4

if debug:
print('num probes: {}'.format(num_of_probes))
print('best distance: {}'.format(best_cell.d))
if with_distance:
return [best_cell.x, best_cell.y], best_cell.d
else:
return [best_cell.x, best_cell.y]
25 changes: 25 additions & 0 deletions pyrolite/util/plot/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import matplotlib.pyplot as plt
import numpy as np
import scipy.spatial
from .center import visual_center

from ..log import Handle
from ..math import eigsorted, nancov
Expand Down Expand Up @@ -78,6 +79,30 @@ def get_centroid(poly):
return cx, cy


def get_visual_center(poly, vertical_exaggeration=1):
"""
Visual center of a closed polygon.
Parameters
----------
poly : :class:`matplotlib.patches.Polygon`
Polygon for which to obtain the visual center.
vertical_exaggeration : :class:`float`
Apparent vertical exaggeration of the plot
(pixels per unit in y direction divided by pixels
per unit in the x direction).
Returns
-------
cx, cy : :class:`tuple`
Centroid coordinates.
"""
poly_scaled = np.array([poly.get_xy() * [1., vertical_exaggeration]])
x, y = visual_center(poly_scaled)
return tuple([x, y/vertical_exaggeration])


def rect_from_centre(x, y, dx=0, dy=0, **kwargs):
"""
Takes an xy point, and creates a rectangular patch centred about it.
Expand Down
57 changes: 31 additions & 26 deletions test/util/util_classification.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
import unittest

import matplotlib.pyplot as plt
import numpy as np

from pyrolite.comp.codata import renormalise
from pyrolite.util.classification import *
import pandas as pd

from pyrolite.util.classification import (TAS,
USDASoilTexture,
QAP,
FeldsparTernary,
JensenPlot,
PeralkalinityClassifier)
from pyrolite.util.synthetic import normal_frame, random_cov_matrix


Expand All @@ -20,29 +23,31 @@ def setUp(self):
self.df.loc[:, "Na2O + K2O"] = self.df.Na2O + self.df.K2O

def test_classifer_build(self):
cm = TAS()
_ = TAS()

def test_classifer_add_to_axes(self):
cm = TAS()
fig, ax = plt.subplots(1)
_, ax = plt.subplots(1)
for a in [None, ax]:
with self.subTest(a=a):
cm.add_to_axes(
ax=a,
alpha=0.4,
color="k",
axes_scale=100,
linewidth=0.5,
which_labels="ID",
add_labels=True,
)
for label_at_centroid in [True, False]:
with self.subTest(a=a):
cm.add_to_axes(
ax=a,
alpha=0.4,
color="k",
axes_scale=100,
linewidth=0.5,
which_labels="ID",
add_labels=True,
label_at_centroid=label_at_centroid,
)

def test_classifer_predict(self):
df = self.df
cm = TAS()
classes = cm.predict(df, data_scale=1.0)
# precitions will be ID's
rocknames = classes.apply(lambda x: cm.fields.get(x, {"name": None})["name"])
# rocknames will be IDs
_ = classes.apply(lambda x: cm.fields.get(x, {"name": None})["name"])
self.assertFalse(pd.isnull(classes).all())


Expand All @@ -55,7 +60,7 @@ def setUp(self):
)

def test_classifer_build(self):
cm = USDASoilTexture()
_ = USDASoilTexture()

def test_classifer_add_to_axes(self):
cm = USDASoilTexture()
Expand All @@ -82,7 +87,7 @@ def setUp(self):
)

def test_classifer_build(self):
cm = QAP()
_ = QAP()

def test_classifer_add_to_axes(self):
cm = QAP()
Expand All @@ -109,11 +114,11 @@ def setUp(self):
)

def test_classifer_build(self):
cm = FeldsparTernary()
_ = FeldsparTernary()

def test_classifer_add_to_axes(self):
cm = FeldsparTernary()
fig, ax = plt.subplots(1)
_, ax = plt.subplots(1)
for a in [None, ax]:
with self.subTest(a=a):
cm.add_to_axes(
Expand All @@ -137,11 +142,11 @@ def setUp(self):
)

def test_classifer_build(self):
cm = JensenPlot()
_ = JensenPlot()

def test_classifer_add_to_axes(self):
cm = JensenPlot()
fig, ax = plt.subplots(1)
_, ax = plt.subplots(1)
for a in [None, ax]:
with self.subTest(a=a):
cm.add_to_axes(
Expand All @@ -159,7 +164,7 @@ class TestPeralkalinity(unittest.TestCase):
"""Test the peralkalinity classifier."""

def setUp(self):
self.df = df = normal_frame(
self.df = normal_frame(
columns=["SiO2", "Na2O", "K2O", "Al2O3", "CaO"],
mean=[0.5, 0.04, 0.05, 0.2, 0.3],
size=100,
Expand Down

0 comments on commit c3e822b

Please sign in to comment.