Skip to content

Commit

Permalink
Improve RotatE implementation
Browse files Browse the repository at this point in the history
Adds sp* and *po scoring, improves readability, sets default initialization
  • Loading branch information
rgemulla committed Mar 3, 2020
1 parent ac240c4 commit 76a0077
Show file tree
Hide file tree
Showing 2 changed files with 97 additions and 13 deletions.
96 changes: 83 additions & 13 deletions kge/model/rotate.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@
from kge.model.kge_model import RelationalScorer, KgeModel
from torch.nn import functional as F

# EXPERIMENTAL. This is a first cut, implementation still needs work.

# TODO sp* and *po scoring with RotatE leads to *large* intermediate results. It's
# unclear whether this can be fixed.
class RotatEScorer(RelationalScorer):
r"""Implementation of the RotatE KGE scorer."""

Expand All @@ -16,27 +18,44 @@ def score_emb(self, s_emb, p_emb, o_emb, combine: str):

# determine real and imaginary part
s_emb_re, s_emb_im = torch.chunk(s_emb, 2, dim=1)
o_emb_re, o_emb_im = torch.chunk(s_emb, 2, dim=1)
o_emb_re, o_emb_im = torch.chunk(o_emb, 2, dim=1)

# TODO Original RotatE code normalizes relation embeddings here
# convert from radians to points on the complex unit circle
p_emb_re, p_emb_im = torch.cos(p_emb), torch.sin(p_emb)

if combine == "spo":
# compute the difference vector (s*p-t), treating real and complex parts
# separately
sp_emb_re = s_emb_re * p_emb_re - s_emb_im * p_emb_im
sp_emb_im = s_emb_re * p_emb_im + s_emb_im * p_emb_re
diff_re = sp_emb_re - o_emb_re
diff_im = sp_emb_im - o_emb_im
# compute the difference vector (s*p-t)
sp_emb_re, sp_emb_im = hadamard_complex(
s_emb_re, s_emb_im, p_emb_re, p_emb_im
)
diff_re, diff_im = diff_complex(sp_emb_re, sp_emb_im, o_emb_re, o_emb_im)

# compute the absolute values for each (complex) element of the difference
# vector
diff = torch.stack((diff_re, diff_im,), dim=0) # dim0: real, imaginary
diff_abs = torch.norm(diff, dim=0) # sqrt(real^2+imaginary^2)
diff_abs = norm_complex(diff_re, diff_im)

# now take the norm of the absolute values
# now take the norm of the absolute values of the difference vector
out = torch.norm(diff_abs, dim=1, p=self._norm)
# TODO combine = "sp*" and combine = "*po"
elif combine == "sp*":
# as above, but pair each sp-pair with each object
sp_emb_re, sp_emb_im = hadamard_complex(
s_emb_re, s_emb_im, p_emb_re, p_emb_im
) # sp x dim
diff_re, diff_im = pairwise_diff_complex(
sp_emb_re, sp_emb_im, o_emb_re, o_emb_im
) # sp x o x dim
diff_abs = norm_complex(diff_re, diff_im) # sp x o x dim
out = torch.norm(diff_abs, dim=2, p=self._norm)
elif combine == "*po":
# as above, but pair each subject with each po-pair
sp_emb_re, sp_emb_im = pairwise_hadamard_complex(
s_emb_re, s_emb_im, p_emb_re, p_emb_im
) # s x p x dim
diff_re, diff_im = diff_complex(
sp_emb_re, sp_emb_im, o_emb_re, o_emb_im
) # s x po x dim
diff_abs = norm_complex(diff_re, diff_im) # s x po x dim
out = torch.norm(diff_abs, dim=2, p=self._norm).t()
else:
return super().score_emb(s_emb, p_emb, o_emb, combine)
return out.view(n, -1)
Expand All @@ -61,3 +80,54 @@ def __init__(self, config: Config, dataset: Dataset, configuration_key=None):
super().__init__(
config, dataset, RotatEScorer, configuration_key=self.configuration_key
)


def pairwise_sum(X, Y):
    """Add every row of X to every row of Y.

    A singleton axis is inserted into X at position 1 so that the addition
    broadcasts against Y. For 2-D inputs the result has shape
    len(X) x len(Y) x dim.
    """
    x_expanded = X.unsqueeze(1)
    return x_expanded + Y


def pairwise_diff(X, Y):
    """Subtract every row of Y from every row of X.

    A singleton axis is inserted into X at position 1 so that the subtraction
    broadcasts against Y. For 2-D inputs the result has shape
    len(X) x len(Y) x dim, with entry [i, j] equal to X[i] - Y[j].
    """
    x_expanded = X.unsqueeze(1)
    return x_expanded - Y


def pairwise_hadamard(X, Y):
    """Multiply every row of X element-wise with every row of Y.

    A singleton axis is inserted into X at position 1 so that the product
    broadcasts against Y. For 2-D inputs the result has shape
    len(X) x len(Y) x dim, with entry [i, j] equal to X[i] * Y[j].
    """
    x_expanded = X.unsqueeze(1)
    return x_expanded * Y


def hadamard_complex(x_re, x_im, y_re, y_im):
    """Element-wise product of two complex vectors given as (re, im) pairs.

    Uses (a + bi)(c + di) = (ac - bd) + (ad + bc)i and returns the real and
    imaginary parts of the product as a tuple.
    """
    real_part = x_re * y_re - x_im * y_im
    imag_part = x_re * y_im + x_im * y_re
    return real_part, imag_part


def pairwise_hadamard_complex(x_re, x_im, y_re, y_im):
    """Pairwise element-wise product of complex vectors given as (re, im) pairs.

    Same arithmetic as the complex product (ac - bd) + (ad + bc)i, but each
    partial product is formed with ``pairwise_hadamard`` so that every row of
    x is combined with every row of y. Returns (real, imaginary) as a tuple.
    """
    re_times_re = pairwise_hadamard(x_re, y_re)
    im_times_im = pairwise_hadamard(x_im, y_im)
    re_times_im = pairwise_hadamard(x_re, y_im)
    im_times_re = pairwise_hadamard(x_im, y_re)
    return re_times_re - im_times_im, re_times_im + im_times_re


def diff_complex(x_re, x_im, y_re, y_im):
    """Subtract complex vector y from complex vector x, component by component.

    Returns the real and imaginary parts of the difference as a tuple.
    """
    real_part = x_re - y_re
    imag_part = x_im - y_im
    return real_part, imag_part


def pairwise_diff_complex(x_re, x_im, y_re, y_im):
    """Pairwise difference of complex vectors given as (re, im) pairs.

    Real and imaginary parts are handled independently, each through
    ``pairwise_diff``. Returns (real, imaginary) as a tuple.
    """
    real_part = pairwise_diff(x_re, y_re)
    imag_part = pairwise_diff(x_im, y_im)
    return real_part, imag_part


def norm_complex(x_re, x_im):
    """Return the element-wise magnitude sqrt(re^2 + im^2) of complex numbers.

    The result has the same shape as the inputs.
    """
    # New leading axis of size 2: index 0 holds real parts, index 1 imaginary.
    parts = torch.stack([x_re, x_im])
    # Reducing that axis with the default norm yields the magnitudes.
    return torch.norm(parts, dim=0)
14 changes: 14 additions & 0 deletions kge/model/rotate.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,28 @@ import: [lookup_embedder]

rotate:
class_name: RotatE

entity_embedder:
type: lookup_embedder
# Note: dimensionality (key "dim") refers to the combined size of the
# real and imaginary parts of the embedding. Must be even.
+++: +++

relation_embedder:
type: lookup_embedder
dim: -1 # -1 means: pick as half the entity_embedder.dim

# The components of the relation embeddings in RotatE represent radians and
# are converted to values on the complex unit ball when being used. This
# initialization is used in the original RotatE implementation: uniform on
# the complex unit ball.
initialize: uniform_
initialize_args:
uniform_:
a: -3.14159265359
b: 3.14159265359

+++: +++

l_norm: 1.

0 comments on commit 76a0077

Please sign in to comment.