Skip to content

Commit

Permalink
bug fix to close #441, only boxes which match should be included in c…
Browse files Browse the repository at this point in the history
…lass recall
  • Loading branch information
bw4sz committed Jul 11, 2023
1 parent f85644b commit fc1e20e
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 2 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,5 @@ current_release.csv
*.pt
tests/__pycache__
*.wp*
lightning_logs/*
*.prof
4 changes: 3 additions & 1 deletion deepforest/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,9 @@ def evaluate(predictions, ground_df, root_dir, iou_threshold=0.4, savedir=None):
box_precision = np.mean(box_precisions)
box_recall = np.mean(box_recalls)

class_recall = compute_class_recall(results)
# Only matching boxes are considered in class recall
matched_results = results[results.match==True]
class_recall = compute_class_recall(matched_results)

return {
"results": results,
Expand Down
22 changes: 21 additions & 1 deletion tests/test_evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,4 +82,24 @@ def test_evaluate_empty():
assert results["box_recall"] == 0

df = pd.read_csv(csv_file)
assert results["results"].shape[0] == df.shape[0]
assert results["results"].shape[0] == df.shape[0]

@pytest.fixture
def sample_results():
# Create a sample DataFrame for testing
data = {
'true_label': [1, 1, 2],
'predicted_label': [1, 2, 1]
}
return pd.DataFrame(data)

def test_compute_class_recall(sample_results):
# Test case with sample data
expected_recall = pd.DataFrame({
'label': [1, 2],
'recall': [0.5, 0],
'precision': [0.5, 0],
'size': [2, 1]
}).reset_index(drop=True)

assert evaluate.compute_class_recall(sample_results).equals(expected_recall)

0 comments on commit fc1e20e

Please sign in to comment.