-
Notifications
You must be signed in to change notification settings - Fork 232
/
euclidean.py
61 lines (49 loc) · 1.63 KB
/
euclidean.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
#!/usr/bin/env python3
# Copyright (c) 2018-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from .manifold import Manifold
import torch as th
import numpy as np
class EuclideanManifold(Manifold):
    """Flat (Euclidean) manifold implementing the ``Manifold`` interface.

    Riemannian operations reduce to ordinary vector arithmetic here:
    ``rgrad`` is the identity, ``expm`` is addition and ``distance`` is the
    *squared* Euclidean distance.

    Args:
        max_norm: if truthy, ``normalize`` rescales every vector along the
            last axis so its L2 norm is at most ``max_norm``.
        K: optional constant from which ``inner_radius`` (used by
            ``half_aperture``) is derived. When ``K`` is None,
            ``inner_radius`` is never set.
        **kwargs: accepted and ignored.
    """

    # Declare every attribute __init__ assigns, so the class does not rely on
    # a base class providing a per-instance __dict__.
    __slots__ = ["max_norm", "K", "inner_radius"]

    def __init__(self, max_norm=None, K=None, **kwargs):
        self.max_norm = max_norm
        self.K = K
        if K is not None:
            # inner_radius = 2K / (1 + sqrt(1 + 4K^2)); it is the numerator of
            # the arcsin in half_aperture.
            self.inner_radius = 2 * self.K / (
                1 + np.sqrt(1 + 4 * self.K * self.K)
            )

    def normalize(self, u):
        """Clip the L2 norm of each vector along the last axis of ``u``.

        Operates in place and returns ``u``. No-op when ``max_norm`` is falsy.
        """
        d = u.size(-1)
        if self.max_norm:
            # view() shares storage with u, so renorm_ rewrites u itself; each
            # row of the (-1, d) view is rescaled to norm <= max_norm.
            u.view(-1, d).renorm_(2, 0, self.max_norm)
        return u

    def distance(self, u, v):
        """Return the squared Euclidean distance ``||u - v||^2`` (last axis)."""
        return (u - v).pow(2).sum(dim=-1)

    def rgrad(self, p, d_p):
        """Return the Riemannian gradient, which on a flat manifold is ``d_p``."""
        return d_p

    def half_aperture(self, u):
        """Return ``asin(inner_radius / ||u||)`` for each vector in ``u``.

        Requires the manifold to have been built with ``K``; otherwise
        ``inner_radius`` is unset and an AttributeError is raised. Vectors
        with ``||u|| < inner_radius`` give NaN (asin argument above 1).
        """
        sqnu = u.pow(2).sum(dim=-1)
        return th.asin(self.inner_radius / sqnu.sqrt())

    def angle_at_u(self, u, v):
        """Return ``arccos((|u|^2 - |v|^2 - |u-v|^2) / (2 |v| |u-v|))``.

        By the law of cosines in the triangle (0, v, u) this equals pi minus
        the interior angle at ``v``. ``self.norm`` is inherited (not defined
        in this file) and is assumed to return the non-squared L2 norm along
        the last axis -- TODO confirm.
        """
        norm_u = self.norm(u)
        norm_v = self.norm(v)
        # self.distance is squared, so take the root to get |u - v|; squaring
        # it again would mix |u-v|^4 into the law of cosines.
        dist = self.distance(v, u).sqrt()
        num = norm_u.pow(2) - norm_v.pow(2) - dist.pow(2)
        denom = 2 * norm_v * dist
        # Floating-point error can push the ratio slightly outside [-1, 1],
        # where acos returns NaN.
        return (num / denom).clamp(min=-1, max=1).acos()

    def expm(self, p, d_p, normalize=False, lr=None, out=None):
        """Exponential map at ``p``: on a flat manifold, ``p + d_p``.

        Args:
            p: base point(s).
            d_p: update added to ``p``.
            normalize: if True, apply ``self.normalize`` to the result.
            lr: if given, ``d_p`` is first multiplied by ``-lr`` IN PLACE, so
                the caller's ``d_p`` tensor is modified.
            out: destination tensor; defaults to ``p`` (in-place update).

        Returns:
            ``out``, holding ``p + d_p``.
        """
        if lr is not None:
            d_p.mul_(-lr)
        if out is None:
            out = p
        elif out is not p:
            # Start from p so the result is p + d_p rather than
            # (previous contents of out) + d_p.
            out.copy_(p)
        out.add_(d_p)
        if normalize:
            self.normalize(out)
        return out

    def logm(self, p, d_p, out=None):
        """Return ``p - d_p``; ``out`` is accepted but unused.

        NOTE(review): since ``expm(p, v) == p + v``, the inverse map would be
        ``log_p(q) = q - p``; this returns its negation. Left as-is because
        callers may depend on the sign -- confirm against call sites.
        """
        return p - d_p

    def ptransp(self, p, x, y, v):
        """Scatter the values of sparse tensor ``v`` into rows of ``p``.

        Modifies ``p`` in place and returns it; ``x`` and ``y`` are unused.
        NOTE(review): parallel transport on a flat manifold is the identity
        on ``v``; this row scatter presumably matches a caller convention
        -- confirm. Assumes ``v`` is a sparse COO tensor with one sparse dim.
        """
        # _indices() has shape (1, nnz); squeeze(0) keeps a 1-D index even
        # when nnz == 1, where a bare squeeze() would yield a 0-d tensor.
        ix, v_ = v._indices().squeeze(0), v._values()
        return p.index_copy_(0, ix, v_)