Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

(WIP) Tie-breaking extensibility #110

Merged
merged 6 commits into from
May 28, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 16 additions & 1 deletion kge/config-default.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -330,6 +330,22 @@ eval:
# Type of evaluation (entity_ranking only at the moment)
type: entity_ranking

# How to handle ties between the correct answer and other answers, e.g.,
# Query: (s, p, ?).
# Answers and scores: a:10, b:10, c:10, d:11, e:9
# Correct: 'a'.
#
# Possible options are:
# - worst_rank: Use the highest rank of all answers that have the same
# score as the correct answer. In the example: 4.
# - best_rank: Use the lowest rank of all answers that have the same
# score as the correct answer (competition scoring). In the
# example: 2. DO NOT USE THIS OPTION, it leads to
# misleading evaluation results.
# - rounded_mean_rank: Average between worst and best rank, rounded up
# (rounded fractional ranking). In the example: 3.
tie_handling: rounded_mean_rank
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about this?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Cool!


# Compute Hits@K for these choices of K
hits_at_k_s: [1, 3, 10, 50, 100, 200, 300, 400, 500, 1000]

Expand Down Expand Up @@ -359,7 +375,6 @@ eval:
# Other options
pin_memory: False


# Configuration options for model validation/selection during training. Applied
# in addition to the options set under "eval" above.
valid:
Expand Down
20 changes: 15 additions & 5 deletions kge/job/entity_ranking.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,11 @@ class EntityRankingJob(EvaluationJob):

def __init__(self, config: Config, dataset: Dataset, parent_job, model):
super().__init__(config, dataset, parent_job, model)
self.config.check(
"eval.tie_handling",
["rounded_mean_rank", "best_rank", "worst_rank"],
)
self.tie_handling = self.config.get("eval.tie_handling")
self.is_prepared = False

if self.__class__ == EntityRankingJob:
Expand Down Expand Up @@ -43,7 +48,6 @@ def _prepare(self):
num_workers=self.config.get("eval.num_workers"),
pin_memory=self.config.get("eval.pin_memory"),
)

# let the model add some hooks, if it wants to do so
self.model.prepare_job(self)
self.is_prepared = True
Expand Down Expand Up @@ -527,8 +531,7 @@ def _get_ranks_and_num_ties(
num_ties = torch.sum(scores == true_scores.view(-1, 1), dim=1, dtype=torch.long)
return rank, num_ties

@staticmethod
def _get_ranks(rank: torch.Tensor, num_ties: torch.Tensor) -> torch.Tensor:
def _get_ranks(self, rank: torch.Tensor, num_ties: torch.Tensor) -> torch.Tensor:
"""Calculates the final rank from (minimum) rank and number of ties.

:param rank: batch_size x 1 tensor with number of scores greater than the one of
Expand All @@ -540,8 +543,15 @@ def _get_ranks(rank: torch.Tensor, num_ties: torch.Tensor) -> torch.Tensor:
:return: batch_size x 1 tensor of ranks

"""
ranks = rank + num_ties // 2
return ranks

if self.tie_handling == "rounded_mean_rank":
return rank + num_ties // 2
elif self.tie_handling == "best_rank":
return rank
elif self.tie_handling == "worst_rank":
return rank + num_ties - 1
else:
raise NotImplementedError

def _compute_metrics(self, rank_hist, suffix=""):
"""Computes desired matrix from rank histogram"""
Expand Down