From 277e37bbc6f431820f574e7f44a898b971c13984 Mon Sep 17 00:00:00 2001 From: Andrew DalPino Date: Thu, 15 Feb 2024 15:39:15 -0600 Subject: [PATCH] Plus plus check (#317) * Initial commit * Allow deltas in units tests --- CHANGELOG.md | 3 +++ src/Clusterers/Seeders/PlusPlus.php | 7 +++++++ tests/Transformers/MaxAbsoluteScalerTest.php | 6 +++--- tests/Transformers/RobustStandardizerTest.php | 2 +- 4 files changed, 14 insertions(+), 4 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 0c16213c7..707a5742c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,9 @@ - 2.5.0 - Added Vantage Point Spatial tree - Blob Generator can now `simulate()` a Dataset object + - Added Wrapper interface + - Added Swoole Backend + - Plus Plus added check for min number of sample seeds - 2.4.1 - Sentence Tokenizer fix Arabic and Farsi language support diff --git a/src/Clusterers/Seeders/PlusPlus.php b/src/Clusterers/Seeders/PlusPlus.php index 4a59d98b4..aeb2eb812 100644 --- a/src/Clusterers/Seeders/PlusPlus.php +++ b/src/Clusterers/Seeders/PlusPlus.php @@ -6,6 +6,7 @@ use Rubix\ML\Kernels\Distance\Distance; use Rubix\ML\Kernels\Distance\Euclidean; use Rubix\ML\Specifications\DatasetIsNotEmpty; +use Rubix\ML\Exceptions\RuntimeException; use function count; @@ -49,12 +50,18 @@ public function __construct(?Distance $kernel = null) * * @param Dataset $dataset * @param int $k + * @throws RuntimeException * @return list> */ public function seed(Dataset $dataset, int $k) : array { DatasetIsNotEmpty::with($dataset)->check(); + if ($k > $dataset->numSamples()) { + throw new RuntimeException("Cannot seed $k clusters with only " + . $dataset->numSamples() . ' samples.'); + } + $centroids = $dataset->randomSubset(1)->samples(); while (count($centroids) < $k) { diff --git a/tests/Transformers/MaxAbsoluteScalerTest.php b/tests/Transformers/MaxAbsoluteScalerTest.php index 7be4b1c73..a9923ad53 100644 --- a/tests/Transformers/MaxAbsoluteScalerTest.php +++ b/tests/Transformers/MaxAbsoluteScalerTest.php @@ -77,9 +77,9 @@ public function fitUpdateTransformReverse() : void $this->assertCount(3, $sample); - $this->assertEqualsWithDelta(0, $sample[0], 1); - $this->assertEqualsWithDelta(0, $sample[1], 1); - $this->assertEqualsWithDelta(0, $sample[2], 1); + $this->assertEqualsWithDelta(0, $sample[0], 1 + 1e-8); + $this->assertEqualsWithDelta(0, $sample[1], 1 + 1e-8); + $this->assertEqualsWithDelta(0, $sample[2], 1 + 1e-8); $dataset->reverseApply($this->transformer); diff --git a/tests/Transformers/RobustStandardizerTest.php b/tests/Transformers/RobustStandardizerTest.php index f1759c691..c706d9bfc 100644 --- a/tests/Transformers/RobustStandardizerTest.php +++ b/tests/Transformers/RobustStandardizerTest.php @@ -86,7 +86,7 @@ public function fitUpdateTransformReverse() : void $dataset->reverseApply($this->transformer); - $this->assertEquals($original, $dataset->sample(0)); + $this->assertEqualsWithDelta($original, $dataset->sample(0), 1e-8); } /**