From 189586575a22557dc95f4f3c5337879685b9ec6b Mon Sep 17 00:00:00 2001 From: Chris Simpson Date: Thu, 1 Oct 2020 13:33:23 +0100 Subject: [PATCH 01/20] [exceptions] update exception references --- src/Kernels/Distance/Gower.php | 4 ++-- src/ModelOrchestra.php | 12 ++++++------ src/NeuralNet/ActivationFunctions/ISRLU.php | 4 ++-- src/NeuralNet/ActivationFunctions/ISRU.php | 4 ++-- src/NeuralNet/Layers/AlphaDropout.php | 10 +++++----- src/Transformers/BM25Transformer.php | 10 +++++----- src/Transformers/DeltaTfIdfTransformer.php | 8 ++++---- src/Transformers/TokenHashingVectorizer.php | 4 ++-- tests/ModelOrchestraTest.php | 4 ++-- tests/Transformers/BM25TransformerTest.php | 2 +- tests/Transformers/DeltaTfIdfTransformerTest.php | 2 +- 11 files changed, 32 insertions(+), 32 deletions(-) diff --git a/src/Kernels/Distance/Gower.php b/src/Kernels/Distance/Gower.php index 89bdb2a..e6ea6cb 100644 --- a/src/Kernels/Distance/Gower.php +++ b/src/Kernels/Distance/Gower.php @@ -3,7 +3,7 @@ namespace Rubix\ML\Kernels\Distance; use Rubix\ML\DataType; -use InvalidArgumentException; +use Rubix\ML\Exceptions\InvalidArgumentException; use function count; @@ -37,7 +37,7 @@ class Gower implements Distance, NaNSafe /** * @param float $range - * @throws \InvalidArgumentException + * @throws \Rubix\ML\Exceptions\InvalidArgumentException */ public function __construct(float $range = 1.0) { diff --git a/src/ModelOrchestra.php b/src/ModelOrchestra.php index f69ce80..8cf1313 100644 --- a/src/ModelOrchestra.php +++ b/src/ModelOrchestra.php @@ -18,8 +18,8 @@ use Rubix\ML\Classifiers\SoftmaxClassifier; use Rubix\ML\Specifications\DatasetIsNotEmpty; use Rubix\ML\Specifications\SamplesAreCompatibleWithEstimator; -use InvalidArgumentException; -use RuntimeException; +use Rubix\ML\Exceptions\InvalidArgumentException; +use Rubix\ML\Exceptions\RuntimeException; use function count; use function in_array; @@ -71,7 +71,7 @@ class ModelOrchestra implements Learner, Parallel, Persistable, Verbose * @param \Rubix\ML\Learner[] $members * @param \Rubix\ML\Learner|null $conductor * @param float $ratio - * @throws \InvalidArgumentException + * @throws \Rubix\ML\Exceptions\InvalidArgumentException */ public function __construct(array $members, ?Learner $conductor = null, float $ratio = 0.8) { @@ -234,7 +234,7 @@ public function conductor() : Learner * training set. * * @param \Rubix\ML\Datasets\Dataset $dataset - * @throws \InvalidArgumentException + * @throws \Rubix\ML\Exceptions\InvalidArgumentException */ public function train(Dataset $dataset) : void { @@ -284,7 +284,7 @@ public function train(Dataset $dataset) : void * Make predictions from a dataset. * * @param \Rubix\ML\Datasets\Dataset $dataset - * @throws \RuntimeException + * @throws \Rubix\ML\Exceptions\RuntimeException * @return mixed[] */ public function predict(Dataset $dataset) : array @@ -302,7 +302,7 @@ public function predict(Dataset $dataset) : array * The callback that executes after the training task. * * @param \Rubix\ML\Learner $estimator - * @throws \RuntimeException + * @throws \Rubix\ML\Exceptions\RuntimeException */ public function afterTrain(Learner $estimator) : void { diff --git a/src/NeuralNet/ActivationFunctions/ISRLU.php b/src/NeuralNet/ActivationFunctions/ISRLU.php index 348ab42..1567349 100644 --- a/src/NeuralNet/ActivationFunctions/ISRLU.php +++ b/src/NeuralNet/ActivationFunctions/ISRLU.php @@ -3,7 +3,7 @@ namespace Rubix\ML\NeuralNet\ActivationFunctions; use Tensor\Matrix; -use InvalidArgumentException; +use Rubix\ML\Exceptions\InvalidArgumentException; /** * ISRLU @@ -30,7 +30,7 @@ class ISRLU implements ActivationFunction /** * @param float $alpha - * @throws \InvalidArgumentException + * @throws \Rubix\ML\Exceptions\InvalidArgumentException */ public function __construct(float $alpha = 1.0) { diff --git a/src/NeuralNet/ActivationFunctions/ISRU.php b/src/NeuralNet/ActivationFunctions/ISRU.php index bd5c7e0..ed8a6f7 100644 --- a/src/NeuralNet/ActivationFunctions/ISRU.php +++ b/src/NeuralNet/ActivationFunctions/ISRU.php @@ -3,7 +3,7 @@ namespace Rubix\ML\NeuralNet\ActivationFunctions; use Tensor\Matrix; -use InvalidArgumentException; +use Rubix\ML\Exceptions\InvalidArgumentException; /** * ISRU @@ -30,7 +30,7 @@ class ISRU implements ActivationFunction /** * @param float $alpha - * @throws \InvalidArgumentException + * @throws \Rubix\ML\Exceptions\InvalidArgumentException */ public function __construct(float $alpha = 1.0) { diff --git a/src/NeuralNet/Layers/AlphaDropout.php b/src/NeuralNet/Layers/AlphaDropout.php index 77d9bc9..cc450df 100644 --- a/src/NeuralNet/Layers/AlphaDropout.php +++ b/src/NeuralNet/Layers/AlphaDropout.php @@ -6,8 +6,8 @@ use Rubix\ML\Deferred; use Rubix\ML\NeuralNet\Optimizers\Optimizer; use Rubix\ML\NeuralNet\ActivationFunctions\SELU; -use InvalidArgumentException; -use RuntimeException; +use Rubix\ML\Exceptions\InvalidArgumentException; +use Rubix\ML\Exceptions\RuntimeException; /** * Alpha Dropout @@ -71,7 +71,7 @@ class AlphaDropout implements Hidden /** * @param float $ratio - * @throws \InvalidArgumentException + * @throws \Rubix\ML\Exceptions\InvalidArgumentException */ public function __construct(float $ratio = 0.1) { @@ -88,7 +88,7 @@ public function __construct(float $ratio = 0.1) /** * Return the width of the layer. * - * @throws \RuntimeException + * @throws \Rubix\ML\Exceptions\RuntimeException * @return int */ public function width() : int @@ -153,7 +153,7 @@ public function infer(Matrix $input) : Matrix * * @param \Rubix\ML\Deferred $prevGradient * @param \Rubix\ML\NeuralNet\Optimizers\Optimizer $optimizer - * @throws \RuntimeException + * @throws \Rubix\ML\Exceptions\RuntimeException * @return \Rubix\ML\Deferred */ public function back(Deferred $prevGradient, Optimizer $optimizer) : Deferred diff --git a/src/Transformers/BM25Transformer.php b/src/Transformers/BM25Transformer.php index b1c2f9f..cc8d036 100644 --- a/src/Transformers/BM25Transformer.php +++ b/src/Transformers/BM25Transformer.php @@ -5,8 +5,8 @@ use Rubix\ML\DataType; use Rubix\ML\Datasets\Dataset; use Rubix\ML\Specifications\SamplesAreCompatibleWithTransformer; -use InvalidArgumentException; -use RuntimeException; +use Rubix\ML\Exceptions\InvalidArgumentException; +use Rubix\ML\Exceptions\RuntimeException; use Stringable; use function is_null; @@ -82,7 +82,7 @@ class BM25Transformer implements Transformer, Stateful, Elastic, Stringable /** * @param float $alpha * @param float $beta - * @throws \InvalidArgumentException + * @throws \Rubix\ML\Exceptions\InvalidArgumentException */ public function __construct(float $alpha = 1.2, float $beta = 0.75) { @@ -160,7 +160,7 @@ public function fit(Dataset $dataset) : void * Update the fitting of the transformer. * * @param \Rubix\ML\Datasets\Dataset $dataset - * @throws \InvalidArgumentException + * @throws \Rubix\ML\Exceptions\InvalidArgumentException */ public function update(Dataset $dataset) : void { @@ -199,7 +199,7 @@ public function update(Dataset $dataset) : void * Transform the dataset in place. * * @param array[] $samples - * @throws \RuntimeException + * @throws \Rubix\ML\Exceptions\RuntimeException */ public function transform(array &$samples) : void { diff --git a/src/Transformers/DeltaTfIdfTransformer.php b/src/Transformers/DeltaTfIdfTransformer.php index 77353e0..e2c0792 100644 --- a/src/Transformers/DeltaTfIdfTransformer.php +++ b/src/Transformers/DeltaTfIdfTransformer.php @@ -6,8 +6,8 @@ use Rubix\ML\Datasets\Dataset; use Rubix\ML\Datasets\Labeled; use Rubix\ML\Specifications\SamplesAreCompatibleWithTransformer; -use InvalidArgumentException; -use RuntimeException; +use Rubix\ML\Exceptions\InvalidArgumentException; +use Rubix\ML\Exceptions\RuntimeException; use function is_null; @@ -137,7 +137,7 @@ public function dfs() : ?array * Fit the transformer to the dataset. * * @param \Rubix\ML\Datasets\Dataset $dataset - * @throws \InvalidArgumentException + * @throws \Rubix\ML\Exceptions\InvalidArgumentException */ public function fit(Dataset $dataset) : void { @@ -223,7 +223,7 @@ public function update(Dataset $dataset) : void * Transform the dataset in place. * * @param array[] $samples - * @throws \RuntimeException + * @throws \Rubix\ML\Exceptions\RuntimeException */ public function transform(array &$samples) : void { diff --git a/src/Transformers/TokenHashingVectorizer.php b/src/Transformers/TokenHashingVectorizer.php index a9d62e8..0b53f9f 100644 --- a/src/Transformers/TokenHashingVectorizer.php +++ b/src/Transformers/TokenHashingVectorizer.php @@ -6,7 +6,7 @@ use Rubix\ML\Datasets\Dataset; use Rubix\ML\Other\Tokenizers\Word; use Rubix\ML\Other\Tokenizers\Tokenizer; -use InvalidArgumentException; +use Rubix\ML\Exceptions\InvalidArgumentException; use Stringable; use function count; @@ -52,7 +52,7 @@ class TokenHashingVectorizer implements Transformer, Stringable /** * @param int $dimensions * @param \Rubix\ML\Other\Tokenizers\Tokenizer|null $tokenizer - * @throws \InvalidArgumentException + * @throws \Rubix\ML\Exceptions\InvalidArgumentException */ public function __construct(int $dimensions, ?Tokenizer $tokenizer = null) { diff --git a/tests/ModelOrchestraTest.php b/tests/ModelOrchestraTest.php index c9c252f..867455a 100644 --- a/tests/ModelOrchestraTest.php +++ b/tests/ModelOrchestraTest.php @@ -21,8 +21,8 @@ use Rubix\ML\Datasets\Generators\Agglomerate; use Rubix\ML\CrossValidation\Metrics\Accuracy; use PHPUnit\Framework\TestCase; -use InvalidArgumentException; -use RuntimeException; +use Rubix\ML\Exceptions\InvalidArgumentException; +use Rubix\ML\Exceptions\RuntimeException; class ModelOrchestraTest extends TestCase { diff --git a/tests/Transformers/BM25TransformerTest.php b/tests/Transformers/BM25TransformerTest.php index ccab0fd..3205edb 100644 --- a/tests/Transformers/BM25TransformerTest.php +++ b/tests/Transformers/BM25TransformerTest.php @@ -8,7 +8,7 @@ use Rubix\ML\Transformers\Transformer; use Rubix\ML\Transformers\BM25Transformer; use PHPUnit\Framework\TestCase; -use RuntimeException; +use Rubix\ML\Exceptions\RuntimeException; /** * @group Transformers diff --git a/tests/Transformers/DeltaTfIdfTransformerTest.php b/tests/Transformers/DeltaTfIdfTransformerTest.php index d75d028..4b06480 100644 --- a/tests/Transformers/DeltaTfIdfTransformerTest.php +++ b/tests/Transformers/DeltaTfIdfTransformerTest.php @@ -8,7 +8,7 @@ use Rubix\ML\Transformers\Transformer; use Rubix\ML\Transformers\DeltaTfIdfTransformer; use PHPUnit\Framework\TestCase; -use RuntimeException; +use Rubix\ML\Exceptions\RuntimeException; class DeltaTfIdfTransformerTest extends TestCase { From 5445eeb2c975916d497b190fa9dbd9d94457e90f Mon Sep 17 00:00:00 2001 From: Chris Simpson Date: Thu, 1 Oct 2020 13:36:18 +0100 Subject: [PATCH 02/20] [exceptions] Update composer.json --- composer.json | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/composer.json b/composer.json index 55af9f1..7c33f4c 100644 --- a/composer.json +++ b/composer.json @@ -17,7 +17,7 @@ ], "require": { "php": ">=7.2", - "rubix/ml": "^0.2.0", + "rubix/ml": "^0.3.0", "rubix/tensor": "^2.0.4", "wamania/php-stemmer": "^2.0", "league/flysystem": "2.0.0-beta.3" From 125b475d007e69e242ac54e3299ed7b314c8ea4a Mon Sep 17 00:00:00 2001 From: Andrew DalPino Date: Fri, 16 Oct 2020 01:56:12 -0500 Subject: [PATCH 03/20] Use new specification chain instead of helper --- composer.json | 3 ++- phpstan.neon | 4 ---- src/ModelOrchestra.php | 10 +++++----- 3 files changed, 7 insertions(+), 10 deletions(-) diff --git a/composer.json b/composer.json index 55af9f1..b3ec90d 100644 --- a/composer.json +++ b/composer.json @@ -17,7 +17,7 @@ ], "require": { "php": ">=7.2", - "rubix/ml": "^0.2.0", + "rubix/ml": "^0.3.0", "rubix/tensor": "^2.0.4", "wamania/php-stemmer": "^2.0", "league/flysystem": "2.0.0-beta.3" @@ -27,6 +27,7 @@ "league/flysystem-memory": "2.0.0-beta.3", "phpbench/phpbench": "0.17.*", "phpstan/phpstan": "0.12.*", + "phpstan/extension-installer": "^1.0", "phpstan/phpstan-phpunit": "0.12.*", "phpunit/phpunit": "8.5.*" }, diff --git a/phpstan.neon b/phpstan.neon index 3d6f41b..62b99f8 100644 --- a/phpstan.neon +++ b/phpstan.neon @@ -1,7 +1,3 @@ -includes: - - vendor/phpstan/phpstan-phpunit/extension.neon - - vendor/phpstan/phpstan-phpunit/rules.neon - parameters: level: 8 paths: diff --git a/src/ModelOrchestra.php b/src/ModelOrchestra.php index f69ce80..9ffb7ea 100644 --- a/src/ModelOrchestra.php +++ b/src/ModelOrchestra.php @@ -9,7 +9,6 @@ use Rubix\ML\Datasets\Unlabeled; use Rubix\ML\Other\Helpers\Params; use Rubix\ML\Backends\Tasks\Proba; -use Rubix\ML\Other\Helpers\Verifier; use Rubix\ML\Backends\Tasks\Predict; use Rubix\ML\Other\Traits\LoggerAware; use Rubix\ML\Other\Traits\PredictsSingle; @@ -17,6 +16,7 @@ use Rubix\ML\Other\Traits\Multiprocessing; use Rubix\ML\Classifiers\SoftmaxClassifier; use Rubix\ML\Specifications\DatasetIsNotEmpty; +use Rubix\ML\Specifications\SpecificationChain; use Rubix\ML\Specifications\SamplesAreCompatibleWithEstimator; use InvalidArgumentException; use RuntimeException; @@ -243,10 +243,10 @@ public function train(Dataset $dataset) : void . ' Labeled training set.'); } - Verifier::check([ - DatasetIsNotEmpty::with($dataset), - SamplesAreCompatibleWithEstimator::with($dataset, $this), - ]); + SpecificationChain::with([ + new DatasetIsNotEmpty($dataset), + new SamplesAreCompatibleWithEstimator($dataset, $this), + ])->check(); if ($this->logger) { $this->logger->info("$this initialized"); From f0f53e76d72b2fd514b792f2329124be865893ca Mon Sep 17 00:00:00 2001 From: Andrew DalPino Date: Wed, 21 Oct 2020 01:44:37 -0500 Subject: [PATCH 04/20] Implicit Stringable through 0.3.0 Transformer interface --- composer.json | 3 +-- src/Transformers/BM25Transformer.php | 3 +-- src/Transformers/TokenHashingVectorizer.php | 3 +-- 3 files changed, 3 insertions(+), 6 deletions(-) diff --git a/composer.json b/composer.json index cfd5653..631cf4b 100644 --- a/composer.json +++ b/composer.json @@ -17,7 +17,7 @@ ], "require": { "php": ">=7.2", - "rubix/ml": "^0.3.0", + "rubix/ml": "0.3.0.x-dev", "rubix/tensor": "^2.0.4", "wamania/php-stemmer": "^2.0", "league/flysystem": "2.0.0-beta.3" @@ -28,7 +28,6 @@ "phpbench/phpbench": "0.17.*", "phpstan/extension-installer": "^1.0", "phpstan/phpstan": "0.12.*", - "phpstan/extension-installer": "^1.0", "phpstan/phpstan-phpunit": "0.12.*", "phpunit/phpunit": "8.5.*" }, diff --git a/src/Transformers/BM25Transformer.php b/src/Transformers/BM25Transformer.php index cc8d036..2994a7a 100644 --- a/src/Transformers/BM25Transformer.php +++ b/src/Transformers/BM25Transformer.php @@ -7,7 +7,6 @@ use Rubix\ML\Specifications\SamplesAreCompatibleWithTransformer; use Rubix\ML\Exceptions\InvalidArgumentException; use Rubix\ML\Exceptions\RuntimeException; -use Stringable; use function is_null; @@ -27,7 +26,7 @@ * @package Rubix/ML * @author Andrew DalPino */ -class BM25Transformer implements Transformer, Stateful, Elastic, Stringable +class BM25Transformer implements Transformer, Stateful, Elastic { /** * The term frequency (TF) saturation factor. Lower values will cause TF to saturate quicker. diff --git a/src/Transformers/TokenHashingVectorizer.php b/src/Transformers/TokenHashingVectorizer.php index 0b53f9f..43c118b 100644 --- a/src/Transformers/TokenHashingVectorizer.php +++ b/src/Transformers/TokenHashingVectorizer.php @@ -7,7 +7,6 @@ use Rubix\ML\Other\Tokenizers\Word; use Rubix\ML\Other\Tokenizers\Tokenizer; use Rubix\ML\Exceptions\InvalidArgumentException; -use Stringable; use function count; use function is_string; @@ -26,7 +25,7 @@ * @package Rubix/ML * @author Andrew DalPino */ -class TokenHashingVectorizer implements Transformer, Stringable +class TokenHashingVectorizer implements Transformer { /** * The maximum number of dimensions supported. From b70e24669b84baff92248faf0a8c20a6c26344ed Mon Sep 17 00:00:00 2001 From: Andrew DalPino Date: Wed, 21 Oct 2020 01:51:39 -0500 Subject: [PATCH 05/20] Fix CI pipeline --- src/Persisters/Flysystem.php | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/Persisters/Flysystem.php b/src/Persisters/Flysystem.php index df37f8f..5f725e2 100644 --- a/src/Persisters/Flysystem.php +++ b/src/Persisters/Flysystem.php @@ -3,15 +3,14 @@ namespace Rubix\ML\Persisters; use Rubix\ML\Encoding; -use Rubix\ML\Other\Helpers\Reflection; use Rubix\ML\Persistable; use Rubix\ML\Other\Helpers\Params; +use Rubix\ML\Other\Helpers\Reflection; use Rubix\ML\Persisters\Serializers\Native; use Rubix\ML\Persisters\Serializers\Serializer; use League\Flysystem\FilesystemOperator; use League\Flysystem\FilesystemException; use RuntimeException; -use Stringable; /** * Flysystem @@ -28,7 +27,7 @@ * @author Chris Simpson * @author Andrew DalPino */ -class Flysystem implements Persister, Stringable +class Flysystem implements Persister { /** * The extension to give files created as part of a persistable's save history. From c2c5e34c301e37b652bb4c13c267223a2e09c3fd Mon Sep 17 00:00:00 2001 From: Andrew DalPino Date: Wed, 21 Oct 2020 02:01:21 -0500 Subject: [PATCH 06/20] Move K Best selector to main repo --- CHANGELOG.md | 3 + docs/transformers/k-best-selector.md | 30 ----- src/Transformers/KBestSelector.php | 168 --------------------------- 3 files changed, 3 insertions(+), 198 deletions(-) delete mode 100644 docs/transformers/k-best-selector.md delete mode 100644 src/Transformers/KBestSelector.php diff --git a/CHANGELOG.md b/CHANGELOG.md index 32bd551..f73e045 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,6 @@ +- 0.3.0-beta + - Move K Best Selector to main repository + - 0.2.1-beta - Implemented K Best feature selector diff --git a/docs/transformers/k-best-selector.md b/docs/transformers/k-best-selector.md deleted file mode 100644 index 9529dad..0000000 --- a/docs/transformers/k-best-selector.md +++ /dev/null @@ -1,30 +0,0 @@ -[source] - -# K Best Selector -A supervised feature selector that picks the top K ranked features returned by a learner implementing the [RanksFeatures](../ranks-features.md) interface. - -> **Note:** The default feature ranking base learner is a fully-grown decision tree. - -**Interfaces:** [Transformer](api.md#transformer), [Stateful](api.md#stateful) - -**Data Type Compatibility:** Depends on the base learner - -## Parameters -| # | Param | Default | Type | Description | -|---|---|---|---|---| -| 1 | k | | int | The maximum number of features to select from the dataset. | -| 2 | estimator | Auto | RanksFeatures | The base feature ranking learner instance. | - -## Additional Methods -Return the final importances of the selected feature columns: -``` php -public importances() : ?array -``` - -## Example -```php -use Rubix\ML\Transformers\KBestSelector; -use Rubix\ML\Classifiers\GradientBoost; - -$transformer = new KBestSelector(10, new GradientBoost()); -``` diff --git a/src/Transformers/KBestSelector.php b/src/Transformers/KBestSelector.php deleted file mode 100644 index a748e65..0000000 --- a/src/Transformers/KBestSelector.php +++ /dev/null @@ -1,168 +0,0 @@ -k = $k; - $this->estimator = $estimator; - $this->fitBase = is_null($estimator); - } - - /** - * Return the data types that this transformer is compatible with. - * - * @return list<\Rubix\ML\DataType> - */ - public function compatibility() : array - { - return DataType::all(); - } - - /** - * Is the transformer fitted? - * - * @return bool - */ - public function fitted() : bool - { - return isset($this->importances); - } - - /** - * Return the final importances of the selected feature columns. - * - * @return float[]|null - */ - public function importances() : ?array - { - return $this->importances; - } - - /** - * Fit the transformer to the dataset. - * - * @param \Rubix\ML\Datasets\Dataset $dataset - * @throws \InvalidArgumentException - */ - public function fit(Dataset $dataset) : void - { - if (!$dataset instanceof Labeled) { - throw new InvalidArgumentException('Transformer requires a' - . ' Labeled training set.'); - } - - if ($this->fitBase or is_null($this->estimator)) { - switch ($dataset->labelType()) { - case DataType::categorical(): - $this->estimator = new ClassificationTree(); - - break 1; - - case DataType::continuous(): - $this->estimator = new RegressionTree(); - - break 1; - - default: - throw new InvalidArgumentException('No compatible base' - . " learner for {$dataset->labelType()} label type."); - } - } - - $this->estimator->train($dataset); - - $importances = $this->estimator->featureImportances(); - - asort($importances); - - $this->importances = array_slice($importances, 0, $this->k, true); - } - - /** - * Transform the dataset in place. - * - * @param array[] $samples - * @throws \RuntimeException - */ - public function transform(array &$samples) : void - { - if (is_null($this->importances)) { - throw new RuntimeException('Transformer has not been fitted.'); - } - - foreach ($samples as &$sample) { - $sample = array_values(array_intersect_key($sample, $this->importances)); - } - } - - /** - * Return the string representation of the object. - * - * @return string - */ - public function __toString() : string - { - return "K Best Selector (k: {$this->k}, estimator: {$this->estimator})"; - } -} From bcacf5b69d0ea7f2a57fe499f04e531d1b1a635c Mon Sep 17 00:00:00 2001 From: Andrew DalPino Date: Fri, 23 Oct 2020 03:55:46 -0500 Subject: [PATCH 07/20] Added Gzip serializer --- CHANGELOG.md | 1 + docs/persisters/Serializers/gzip.md | 18 +++++ src/Persisters/Serializers/Gzip.php | 96 +++++++++++++++++++++++ tests/Persisters/Serializers/GzipTest.php | 61 ++++++++++++++ 4 files changed, 176 insertions(+) create mode 100644 docs/persisters/Serializers/gzip.md create mode 100644 src/Persisters/Serializers/Gzip.php create mode 100644 tests/Persisters/Serializers/GzipTest.php diff --git a/CHANGELOG.md b/CHANGELOG.md index f73e045..bdf886c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,4 +1,5 @@ - 0.3.0-beta + - Added Gzip serializer - Move K Best Selector to main repository - 0.2.1-beta diff --git a/docs/persisters/Serializers/gzip.md b/docs/persisters/Serializers/gzip.md new file mode 100644 index 0000000..c1ac875 --- /dev/null +++ b/docs/persisters/Serializers/gzip.md @@ -0,0 +1,18 @@ +[source] + +# Native +A compression format based on the DEFLATE algorithm with a header and CRC32 checksum. + +## Parameters +| # | Param | Default | Type | Description | +|---|---|---|---|---| +| 1 | level | 1 | int | The compression level between 0 and 9, 0 meaning no compression. | +| 2 | serializer | Native | Serializer | The base serializer | + +## Example +```php +use Rubix\ML\Persisters\Serializers\Gzip; +use Rubix\ML\Persisters\Serializers\Native; + +$serializer = new Gzip(1, new Native()); +``` \ No newline at end of file diff --git a/src/Persisters/Serializers/Gzip.php b/src/Persisters/Serializers/Gzip.php new file mode 100644 index 0000000..5d62b0d --- /dev/null +++ b/src/Persisters/Serializers/Gzip.php @@ -0,0 +1,96 @@ + 9) { + throw new InvalidArgumentException('Level must be' + . " between 0 and 9, $level given."); + } + + $this->level = $level; + $this->serializer = $serializer ?? new Native(); + } + + /** + * Serialize a persistable object and return the data. + * + * @param \Rubix\ML\Persistable $persistable + * @return \Rubix\ML\Encoding + */ + public function serialize(Persistable $persistable) : Encoding + { + $encoding = $this->serializer->serialize($persistable); + + $data = gzencode((string) $encoding, $this->level); + + if ($data === false) { + throw new RuntimeException('Failed to compress data.'); + } + + return new Encoding($data); + } + + /** + * Unserialize a persistable object and return it. + * + * @param \Rubix\ML\Encoding $encoding + * @throws \Rubix\ML\Exceptions\RuntimeException + * @return \Rubix\ML\Persistable + */ + public function unserialize(Encoding $encoding) : Persistable + { + $data = gzdecode((string) $encoding); + + if ($data === false) { + throw new RuntimeException('Failed to decompress data.'); + } + + return $this->serializer->unserialize(new Encoding($data)); + } + + /** + * Return the string representation of the object. + * + * @return string + */ + public function __toString() : string + { + return "Gzip (level: {$this->level}, serializer: {$this->serializer})"; + } +} diff --git a/tests/Persisters/Serializers/GzipTest.php b/tests/Persisters/Serializers/GzipTest.php new file mode 100644 index 0000000..647832d --- /dev/null +++ b/tests/Persisters/Serializers/GzipTest.php @@ -0,0 +1,61 @@ +serializer = new Gzip(1); + + $this->persistable = new DummyClassifier(); + } + + /** + * @test + */ + public function build() : void + { + $this->assertInstanceOf(Gzip::class, $this->serializer); + $this->assertInstanceOf(Serializer::class, $this->serializer); + } + + /** + * @test + */ + public function serializeUnserialize() : void + { + $data = $this->serializer->serialize($this->persistable); + + $this->assertInstanceOf(Encoding::class, $data); + + $persistable = $this->serializer->unserialize($data); + + $this->assertInstanceOf(DummyClassifier::class, $persistable); + $this->assertInstanceOf(Persistable::class, $persistable); + } +} From f494b5224634338661fabd7b2c9657ffe3e4728a Mon Sep 17 00:00:00 2001 From: Andrew DalPino Date: Fri, 23 Oct 2020 04:09:27 -0500 Subject: [PATCH 08/20] Import runtime exception class --- src/Persisters/Serializers/Gzip.php | 1 + 1 file changed, 1 insertion(+) diff --git a/src/Persisters/Serializers/Gzip.php b/src/Persisters/Serializers/Gzip.php index 5d62b0d..0679c43 100644 --- a/src/Persisters/Serializers/Gzip.php +++ b/src/Persisters/Serializers/Gzip.php @@ -4,6 +4,7 @@ use Rubix\ML\Encoding; use Rubix\ML\Persistable; +use Rubix\ML\Exceptions\RuntimeException; use Rubix\ML\Exceptions\InvalidArgumentException; /** From 3bb6f6706b5b6911933fea77b0d101d0a3fe2488 Mon Sep 17 00:00:00 2001 From: Andrew DalPino Date: Fri, 23 Oct 2020 04:10:45 -0500 Subject: [PATCH 09/20] Fix typo --- docs/persisters/Serializers/gzip.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/persisters/Serializers/gzip.md b/docs/persisters/Serializers/gzip.md index c1ac875..ca38f65 100644 --- a/docs/persisters/Serializers/gzip.md +++ b/docs/persisters/Serializers/gzip.md @@ -1,6 +1,6 @@ [source] -# Native +# Gzip A compression format based on the DEFLATE algorithm with a header and CRC32 checksum. ## Parameters From bc2d359e52bc82ee38d2d857130a1a48e3e21c59 Mon Sep 17 00:00:00 2001 From: Andrew DalPino Date: Fri, 23 Oct 2020 19:57:27 -0500 Subject: [PATCH 10/20] Implemented Bzip2 serializer --- .github/workflows/ci.yml | 2 + CHANGELOG.md | 2 +- docs/persisters/Serializers/bzip2.md | 24 ++++ docs/persisters/Serializers/gzip.md | 5 +- src/Persisters/Serializers/Bzip2.php | 123 +++++++++++++++++++++ src/Persisters/Serializers/Gzip.php | 9 ++ tests/Persisters/Serializers/Bzip2Test.php | 61 ++++++++++ 7 files changed, 224 insertions(+), 2 deletions(-) create mode 100644 docs/persisters/Serializers/bzip2.md create mode 100644 src/Persisters/Serializers/Bzip2.php create mode 100644 tests/Persisters/Serializers/Bzip2Test.php diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index ba13201..3684f23 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -19,6 +19,8 @@ jobs: uses: shivammathur/setup-php@v2 with: php-version: ${{ matrix.php-versions }} + tools: pecl + extensions: bz2 ini-values: memory_limit=-1 - name: Validate composer.json diff --git a/CHANGELOG.md b/CHANGELOG.md index bdf886c..db31a63 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,5 @@ - 0.3.0-beta - - Added Gzip serializer + - Added Gzip and Bzip2 serializers - Move K Best Selector to main repository - 0.2.1-beta diff --git a/docs/persisters/Serializers/bzip2.md b/docs/persisters/Serializers/bzip2.md new file mode 100644 index 0000000..47e9a96 --- /dev/null +++ b/docs/persisters/Serializers/bzip2.md @@ -0,0 +1,24 @@ +[source] + +# Bzip2 +A compression format based on the Burrows–Wheeler transform. Bzip2 is slightly smaller than Gzip format but is slower and requires more memory. + +> **Note:** This serializer requires the Bzip2 PHP extension. + +## Parameters +| # | Param | Default | Type | Description | +|---|---|---|---|---| +| 1 | block size | 4 | int | The size of each block between 1 and 9 where 9 gives the best compression. | +| 2 | work factor | 0 | int | Controls how the compression phase behaves when the input is highly repetitive. | +| 3 | serializer | Native | Serializer | The base serializer | + +## Example +```php +use Rubix\ML\Persisters\Serializers\Bzip2; +use Rubix\ML\Persisters\Serializers\Native; + +$serializer = new Bzip2(4, 125, new Native()); +``` + +### References +>- J. Tsai. (2006). Bzip2: Format Specification. \ No newline at end of file diff --git a/docs/persisters/Serializers/gzip.md b/docs/persisters/Serializers/gzip.md index ca38f65..5e25e72 100644 --- a/docs/persisters/Serializers/gzip.md +++ b/docs/persisters/Serializers/gzip.md @@ -15,4 +15,7 @@ use Rubix\ML\Persisters\Serializers\Gzip; use Rubix\ML\Persisters\Serializers\Native; $serializer = new Gzip(1, new Native()); -``` \ No newline at end of file +``` + +### References +>- P. Deutsch. (1996). RFC 1951 - DEFLATE Compressed Data Format Specification version. \ No newline at end of file diff --git a/src/Persisters/Serializers/Bzip2.php b/src/Persisters/Serializers/Bzip2.php new file mode 100644 index 0000000..c63f622 --- /dev/null +++ b/src/Persisters/Serializers/Bzip2.php @@ -0,0 +1,123 @@ + 9) { + throw new InvalidArgumentException('Block size must' + . " be between 0 and 9, $blockSize given."); + } + + if ($serializer instanceof self) { + throw new InvalidArgumentException('Base serializer' + . ' must not be an instance of itself.'); + } + + $this->blockSize = $blockSize; + $this->workFactor = $workFactor; + $this->serializer = $serializer ?? new Native(); + } + + /** + * Serialize a persistable object and return the data. + * + * @param \Rubix\ML\Persistable $persistable + * @return \Rubix\ML\Encoding + */ + public function serialize(Persistable $persistable) : Encoding + { + $encoding = $this->serializer->serialize($persistable); + + $data = bzcompress((string) $encoding, $this->blockSize, $this->workFactor); + + if (!is_string($data)) { + throw new RuntimeException('Failed to compress data.'); + } + + return new Encoding($data); + } + + /** + * Unserialize a persistable object and return it. + * + * @param \Rubix\ML\Encoding $encoding + * @throws \Rubix\ML\Exceptions\RuntimeException + * @return \Rubix\ML\Persistable + */ + public function unserialize(Encoding $encoding) : Persistable + { + $data = bzdecompress((string) $encoding); + + if (!is_string($data)) { + throw new RuntimeException('Failed to decompress data.'); + } + + return $this->serializer->unserialize(new Encoding($data)); + } + + /** + * Return the string representation of the object. + * + * @return string + */ + public function __toString() : string + { + return "Bzip2 (block size: {$this->blockSize}, work factor: {$this->workFactor}," + . " serializer: {$this->serializer})"; + } +} diff --git a/src/Persisters/Serializers/Gzip.php b/src/Persisters/Serializers/Gzip.php index 0679c43..adb56b1 100644 --- a/src/Persisters/Serializers/Gzip.php +++ b/src/Persisters/Serializers/Gzip.php @@ -12,6 +12,10 @@ * * A compression format based on the DEFLATE algorithm with a header and checksum. * + * References: + * [1] P. Deutsch. (1996). RFC 1951 - DEFLATE Compressed Data Format Specification + * version. + * * @category Machine Learning * @package Rubix/ML * @author Andrew DalPino @@ -44,6 +48,11 @@ public function __construct(int $level = 1, ?Serializer $serializer = null) . " between 0 and 9, $level given."); } + if ($serializer instanceof self) { + throw new InvalidArgumentException('Base serializer' + . ' must not be an instance of itself.'); + } + $this->level = $level; $this->serializer = $serializer ?? new Native(); } diff --git a/tests/Persisters/Serializers/Bzip2Test.php b/tests/Persisters/Serializers/Bzip2Test.php new file mode 100644 index 0000000..5d7214c --- /dev/null +++ b/tests/Persisters/Serializers/Bzip2Test.php @@ -0,0 +1,61 @@ +serializer = new Bzip2(4, 0); + + $this->persistable = new DummyClassifier(); + } + + /** + * @test + */ + public function build() : void + { + $this->assertInstanceOf(Bzip2::class, $this->serializer); + $this->assertInstanceOf(Serializer::class, $this->serializer); + } + + /** + * @test + */ + public function serializeUnserialize() : void + { + $data = $this->serializer->serialize($this->persistable); + + $this->assertInstanceOf(Encoding::class, $data); + + $persistable = $this->serializer->unserialize($data); + + $this->assertInstanceOf(DummyClassifier::class, $persistable); + $this->assertInstanceOf(Persistable::class, $persistable); + } +} From 47f52cb9a3f9e8c455c2d49a0b603713c551f648 Mon Sep 17 00:00:00 2001 From: Andrew DalPino Date: Fri, 23 Oct 2020 20:04:22 -0500 Subject: [PATCH 11/20] Add Bz2 extension to optional --- README.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/README.md b/README.md index 457024e..706baa7 100644 --- a/README.md +++ b/README.md @@ -10,5 +10,8 @@ $ composer require rubix/extras ### Requirements - [PHP](https://php.net/manual/en/install.php) 7.2 or above +##### Optional +- [Bzip2 extension](https://www.php.net/manual/en/book.bzip2.php) for Bzip2 compression + ## License [MIT](https://github.com/RubixML/Extras/blob/master/LICENSE.md) From 925895936f1d6d6dd591d94b99eb34449902f133 Mon Sep 17 00:00:00 2001 From: Andrew DalPino Date: Tue, 27 Oct 2020 17:20:58 -0500 Subject: [PATCH 12/20] Add bz2 extension requirement to test --- tests/Persisters/Serializers/Bzip2Test.php | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/Persisters/Serializers/Bzip2Test.php b/tests/Persisters/Serializers/Bzip2Test.php index 5d7214c..d8519da 100644 --- a/tests/Persisters/Serializers/Bzip2Test.php +++ b/tests/Persisters/Serializers/Bzip2Test.php @@ -11,6 +11,7 @@ /** * @group Serializers + * @requires extension bz2 * @covers \Rubix\ML\Persisters\Serializers\Bzip2 */ class Bzip2Test extends TestCase From 4c9862617340185c195981969035f3470db4a7e1 Mon Sep 17 00:00:00 2001 From: Andrew DalPino Date: Sun, 27 Dec 2020 01:45:28 -0600 Subject: [PATCH 13/20] Move Gzip serializer over to main repo --- CHANGELOG.md | 3 +- docs/persisters/Serializers/gzip.md | 21 ----- src/Kernels/Distance/Gower.php | 6 +- src/ModelOrchestra.php | 4 +- src/Persisters/Serializers/Gzip.php | 106 ---------------------- tests/Persisters/Serializers/GzipTest.php | 61 ------------- 6 files changed, 7 insertions(+), 194 deletions(-) delete mode 100644 docs/persisters/Serializers/gzip.md delete mode 100644 src/Persisters/Serializers/Gzip.php delete mode 100644 tests/Persisters/Serializers/GzipTest.php diff --git a/CHANGELOG.md b/CHANGELOG.md index db31a63..815119c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,7 @@ - 0.3.0-beta - - Added Gzip and Bzip2 serializers + - Added Bzip2 serializers - Move K Best Selector to main repository + - Added custom exceptions from the main repo - 0.2.1-beta - Implemented K Best feature selector diff --git a/docs/persisters/Serializers/gzip.md b/docs/persisters/Serializers/gzip.md deleted file mode 100644 index 5e25e72..0000000 --- a/docs/persisters/Serializers/gzip.md +++ /dev/null @@ -1,21 +0,0 @@ -[source] - -# Gzip -A compression format based on the DEFLATE algorithm with a header and CRC32 checksum. - -## Parameters -| # | Param | Default | Type | Description | -|---|---|---|---|---| -| 1 | level | 1 | int | The compression level between 0 and 9, 0 meaning no compression. | -| 2 | serializer | Native | Serializer | The base serializer | - -## Example -```php -use Rubix\ML\Persisters\Serializers\Gzip; -use Rubix\ML\Persisters\Serializers\Native; - -$serializer = new Gzip(1, new Native()); -``` - -### References ->- P. Deutsch. (1996). RFC 1951 - DEFLATE Compressed Data Format Specification version. \ No newline at end of file diff --git a/src/Kernels/Distance/Gower.php b/src/Kernels/Distance/Gower.php index e6ea6cb..1b7c995 100644 --- a/src/Kernels/Distance/Gower.php +++ b/src/Kernels/Distance/Gower.php @@ -81,18 +81,18 @@ public function compute(array $a, array $b) : float case is_float($valueA) and is_nan($valueA): ++$nn; - break 1; + break; case is_float($valueB) and is_nan($valueB): ++$nn; - break 1; + break; case !is_string($valueA) and !is_string($valueB): $distance += abs($valueA - $valueB) / $this->range; - break 1; + break; default: if ($valueA !== $valueB) { diff --git a/src/ModelOrchestra.php b/src/ModelOrchestra.php index d8abe93..714d3c5 100644 --- a/src/ModelOrchestra.php +++ b/src/ModelOrchestra.php @@ -135,12 +135,12 @@ public function __construct(array $members, ?Learner $conductor = null, float $r case EstimatorType::classifier(): $conductor = new SoftmaxClassifier(); - break 1; + break; case EstimatorType::regressor(): $conductor = new Ridge(); - break 1; + break; default: $conductor = new Ridge(); diff --git a/src/Persisters/Serializers/Gzip.php b/src/Persisters/Serializers/Gzip.php deleted file mode 100644 index adb56b1..0000000 --- a/src/Persisters/Serializers/Gzip.php +++ /dev/null @@ -1,106 +0,0 @@ - 9) { - throw new InvalidArgumentException('Level must be' - . " between 0 and 9, $level given."); - } - - if ($serializer instanceof self) { - throw new InvalidArgumentException('Base serializer' - . ' must not be an instance of itself.'); - } - - $this->level = $level; - $this->serializer = $serializer ?? new Native(); - } - - /** - * Serialize a persistable object and return the data. - * - * @param \Rubix\ML\Persistable $persistable - * @return \Rubix\ML\Encoding - */ - public function serialize(Persistable $persistable) : Encoding - { - $encoding = $this->serializer->serialize($persistable); - - $data = gzencode((string) $encoding, $this->level); - - if ($data === false) { - throw new RuntimeException('Failed to compress data.'); - } - - return new Encoding($data); - } - - /** - * Unserialize a persistable object and return it. - * - * @param \Rubix\ML\Encoding $encoding - * @throws \Rubix\ML\Exceptions\RuntimeException - * @return \Rubix\ML\Persistable - */ - public function unserialize(Encoding $encoding) : Persistable - { - $data = gzdecode((string) $encoding); - - if ($data === false) { - throw new RuntimeException('Failed to decompress data.'); - } - - return $this->serializer->unserialize(new Encoding($data)); - } - - /** - * Return the string representation of the object. - * - * @return string - */ - public function __toString() : string - { - return "Gzip (level: {$this->level}, serializer: {$this->serializer})"; - } -} diff --git a/tests/Persisters/Serializers/GzipTest.php b/tests/Persisters/Serializers/GzipTest.php deleted file mode 100644 index 647832d..0000000 --- a/tests/Persisters/Serializers/GzipTest.php +++ /dev/null @@ -1,61 +0,0 @@ -serializer = new Gzip(1); - - $this->persistable = new DummyClassifier(); - } - - /** - * @test - */ - public function build() : void - { - $this->assertInstanceOf(Gzip::class, $this->serializer); - $this->assertInstanceOf(Serializer::class, $this->serializer); - } - - /** - * @test - */ - public function serializeUnserialize() : void - { - $data = $this->serializer->serialize($this->persistable); - - $this->assertInstanceOf(Encoding::class, $data); - - $persistable = $this->serializer->unserialize($data); - - $this->assertInstanceOf(DummyClassifier::class, $persistable); - $this->assertInstanceOf(Persistable::class, $persistable); - } -} From 5402695b76acb202812a94166c32f170a73ddd31 Mon Sep 17 00:00:00 2001 From: Andrew DalPino Date: Sun, 27 Dec 2020 03:17:28 -0600 Subject: [PATCH 14/20] Added Levenshtein distance kernel --- CHANGELOG.md | 1 + LICENSE | 2 +- docs/kernels/distance/levenshtein.md | 16 +++++ src/Kernels/Distance/Gower.php | 4 +- src/Kernels/Distance/Levenshtein.php | 52 +++++++++++++++ tests/Kernels/Distance/GowerTest.php | 10 +-- tests/Kernels/Distance/LevenshteinTest.php | 67 ++++++++++++++++++++ tests/Transformers/KBestSelectorTest.php | 73 ---------------------- 8 files changed, 144 insertions(+), 81 deletions(-) create mode 100644 docs/kernels/distance/levenshtein.md create mode 100644 src/Kernels/Distance/Levenshtein.php create mode 100644 tests/Kernels/Distance/LevenshteinTest.php delete mode 100644 tests/Transformers/KBestSelectorTest.php diff --git a/CHANGELOG.md b/CHANGELOG.md index 815119c..eb3f635 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,6 @@ - 0.3.0-beta - Added Bzip2 serializers + - Added Levenshtein distance kernel - Move K Best Selector to main repository - Added custom exceptions from the main repo diff --git a/LICENSE b/LICENSE index 38420b1..60c1e93 100644 --- a/LICENSE +++ b/LICENSE @@ -1,6 +1,6 @@ MIT License -Copyright (c) 2020 The Rubix ML Community +Copyright (c) 2020 Rubix ML Copyright (c) 2020 Andrew DalPino Permission is hereby granted, free of charge, to any person obtaining a copy diff --git a/docs/kernels/distance/levenshtein.md b/docs/kernels/distance/levenshtein.md new file mode 100644 index 0000000..1324a6a --- /dev/null +++ b/docs/kernels/distance/levenshtein.md @@ -0,0 +1,16 @@ +[source] + +# Levenshtein +Levenshtein distance is defined as the number of single-character edits (such as insert, delete, or replace) needed to change one word to another. + +**Data Type Compatibility:** Categorical + +## Parameters +This kernel does not have any parameters. + +## Example +```php +use Rubix\ML\Kernels\Distance\Levenshtein; + +$kernel = new Levenshtein(); +``` diff --git a/src/Kernels/Distance/Gower.php b/src/Kernels/Distance/Gower.php index 1b7c995..30debd7 100644 --- a/src/Kernels/Distance/Gower.php +++ b/src/Kernels/Distance/Gower.php @@ -65,8 +65,8 @@ public function compatibility() : array /** * Compute the distance between two vectors. * - * @param (string|int|float)[] $a - * @param (string|int|float)[] $b + * @param list $a + * @param list $b * @return float */ public function compute(array $a, array $b) : float diff --git a/src/Kernels/Distance/Levenshtein.php b/src/Kernels/Distance/Levenshtein.php new file mode 100644 index 0000000..9b76779 --- /dev/null +++ b/src/Kernels/Distance/Levenshtein.php @@ -0,0 +1,52 @@ + $a + * @param list $b + * @return float + */ + public function compute(array $a, array $b) : float + { + return (float) array_sum(array_map('levenshtein', $a, $b)); + } + + /** + * Return the string representation of the object. + * + * @return string + */ + public function __toString() : string + { + return 'Levenshtein'; + } +} diff --git a/tests/Kernels/Distance/GowerTest.php b/tests/Kernels/Distance/GowerTest.php index 35922d7..3f8bf18 100644 --- a/tests/Kernels/Distance/GowerTest.php +++ b/tests/Kernels/Distance/GowerTest.php @@ -41,15 +41,15 @@ public function build() : void * @test * @dataProvider computeProvider * - * @param (string|int|float)[] $a - * @param (string|int|float)[] $b + * @param list $a + * @param list $b * @param float $expected */ public function compute(array $a, array $b, $expected) : void { $distance = $this->kernel->compute($a, $b); - $this->assertGreaterThanOrEqual(0., $distance); + $this->assertGreaterThanOrEqual(0.0, $distance); $this->assertEquals($expected, $distance); } @@ -58,9 +58,9 @@ public function compute(array $a, array $b, $expected) : void */ public function computeProvider() : Generator { - yield [['toast', 1., 0.5, NAN], ['pretzels', 1., 0.2, 0.1], 0.43333333333333335]; + yield [['toast', 1.0, 0.5, NAN], ['pretzels', 1.0, 0.2, 0.1], 0.43333333333333335]; - yield [[0., 1., 0.5, 'ham'], [0.1, 0.9, 0.4, 'ham'], 0.07499999999999998]; + yield [[0.0, 1.0, 0.5, 'ham'], [0.1, 0.9, 0.4, 'ham'], 0.07499999999999998]; yield [[1, NAN, 1], [1, NAN, 1], 0.0]; } diff --git a/tests/Kernels/Distance/LevenshteinTest.php b/tests/Kernels/Distance/LevenshteinTest.php new file mode 100644 index 0000000..29a7793 --- /dev/null +++ b/tests/Kernels/Distance/LevenshteinTest.php @@ -0,0 +1,67 @@ +kernel = new Levenshtein(); + } + + /** + * @test + */ + public function build() : void + { + $this->assertInstanceOf(Levenshtein::class, $this->kernel); + $this->assertInstanceOf(Distance::class, $this->kernel); + } + + /** + * @test + * @dataProvider computeProvider + * + * @param list $a + * @param list $b + * @param float $expected + */ + public function compute(array $a, array $b, $expected) : void + { + $distance = $this->kernel->compute($a, $b); + + $this->assertGreaterThanOrEqual(0.0, $distance); + $this->assertEquals($expected, $distance); + } + + /** + * @return \Generator + */ + public function computeProvider() : Generator + { + yield [['aaa'], ['aaaaaa'], 3.0]; + + yield [['toast', 'naan'], ['pretzels', 'pizza'], 12.0]; + + yield [['Beef'], ['feeB'], 2.0]; + + yield [['Levenshtein'], ['Levanshtein'], 1.0]; + } +} diff --git a/tests/Transformers/KBestSelectorTest.php b/tests/Transformers/KBestSelectorTest.php deleted file mode 100644 index 69a24af..0000000 --- a/tests/Transformers/KBestSelectorTest.php +++ /dev/null @@ -1,73 +0,0 @@ -generator = new Agglomerate([ - 'male' => new Blob([69.2, 195.7, 40.0], [1.0, 3.0, 0.3]), - 'female' => new Blob([63.7, 168.5, 38.1], [0.8, 2.5, 0.4]), - ], [0.45, 0.55]); - - $this->transformer = new KBestSelector(1); - } - - /** - * @test - */ - public function build() : void - { - $this->assertInstanceOf(KBestSelector::class, $this->transformer); - $this->assertInstanceOf(Transformer::class, $this->transformer); - $this->assertInstanceOf(Stateful::class, $this->transformer); - } - - /** - * @test - */ - public function fitTransform() : void - { - $dataset = $this->generator->generate(100); - - $this->assertEquals(3, $dataset->numColumns()); - - $dataset->apply($this->transformer); - - $this->assertEquals(1, $dataset->numColumns()); - } - - /** - * @test - */ - public function transformUnfitted() : void - { - $this->expectException(RuntimeException::class); - - $samples = $this->generator->generate(1)->samples(); - - $this->transformer->transform($samples); - } -} From 3ddab5d8eeac8055907bca76aaf0f7f84d8b2cc3 Mon Sep 17 00:00:00 2001 From: Andrew DalPino Date: Mon, 28 Dec 2020 21:54:38 -0600 Subject: [PATCH 15/20] Added Vantage Point Tree --- .github/FUNDING.yml | 1 + CHANGELOG.md | 1 + .../Trees/VPTreeBench.php} | 24 +- composer.json | 5 + docs/graph/trees/vp-tree.md | 28 ++ src/Graph/Nodes/VantagePoint.php | 70 ++++ src/Graph/Trees/VPTree.php | 346 ++++++++++++++++++ tests/Graph/Trees/VPTreeTest.php | 108 ++++++ 8 files changed, 571 insertions(+), 12 deletions(-) create mode 100644 .github/FUNDING.yml rename benchmarks/{Transformers/RecursiveFeatureEliminatorBench.php => Graph/Trees/VPTreeBench.php} (60%) create mode 100644 docs/graph/trees/vp-tree.md create mode 100644 src/Graph/Nodes/VantagePoint.php create mode 100644 src/Graph/Trees/VPTree.php create mode 100644 tests/Graph/Trees/VPTreeTest.php diff --git a/.github/FUNDING.yml b/.github/FUNDING.yml new file mode 100644 index 0000000..4ea1485 --- /dev/null +++ b/.github/FUNDING.yml @@ -0,0 +1 @@ +github: [RubixML, andrewdalpino] diff --git a/CHANGELOG.md b/CHANGELOG.md index eb3f635..9330dbb 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,4 +1,5 @@ - 0.3.0-beta + - Added Vantage Point Tree for spatial queries - Added Bzip2 serializers - Added Levenshtein distance kernel - Move K Best Selector to main repository diff --git a/benchmarks/Transformers/RecursiveFeatureEliminatorBench.php b/benchmarks/Graph/Trees/VPTreeBench.php similarity index 60% rename from benchmarks/Transformers/RecursiveFeatureEliminatorBench.php rename to benchmarks/Graph/Trees/VPTreeBench.php index 53f04af..8369078 100644 --- a/benchmarks/Transformers/RecursiveFeatureEliminatorBench.php +++ b/benchmarks/Graph/Trees/VPTreeBench.php @@ -1,28 +1,28 @@ dataset = $generator->generate(self::DATASET_SIZE); - $this->transformer = new KBestSelector(2); + $this->tree = new VPTree(30); } /** * @Subject * @Iterations(3) - * @OutputTimeUnit("milliseconds", precision=3) + * @OutputTimeUnit("seconds", precision=3) */ - public function apply() : void + public function grow() : void { - $this->dataset->apply($this->transformer); + $this->tree->grow($this->dataset); } } diff --git a/composer.json b/composer.json index 631cf4b..89d966b 100644 --- a/composer.json +++ b/composer.json @@ -10,6 +10,11 @@ "ai", "rubixml", "rubix ml" ], "authors": [ + { + "name": "Andrew DalPino", + "homepage": "https://github.com/andrewdalpino", + "role": "Lead Engineer" + }, { "name": "Contributors", "homepage": "https://github.com/RubixML/Extras/graphs/contributors" diff --git a/docs/graph/trees/vp-tree.md b/docs/graph/trees/vp-tree.md new file mode 100644 index 0000000..498d52a --- /dev/null +++ b/docs/graph/trees/vp-tree.md @@ -0,0 +1,28 @@ +[source] + +# VP Tree +A Vantage Point Tree is a binary spatial tree that divides samples by their distance from the center of a cluster called the *vantage point*. Samples that are closer to the vantage point will be put into one branch of the tree while samples that are farther away will be put into the other branch. + +**Interfaces:** Binary Tree, Spatial + +**Data Type Compatibility:** Depends on distance kernel + +## Parameters +| # | Param | Default | Type | Description | +|---|---|---|---|---| +| 1 | max leaf size | 30 | int | The maximum number of samples that each leaf node can contain. | +| 2 | kernel | Euclidean | Distance | The distance kernel used to compute the distance between sample points. | + +## Example +```php +use Rubix\ML\Graph\Trees\VPTree; +use Rubix\ML\Kernels\Distance\Euclidean; + +$tree = new VPTree(30, new Euclidean()); +``` + +## Additional Methods +This tree does not have any additional methods. + +### References +>- P. N. Yianilos. (1993). Data Structures and Algorithms for Nearest Neighbor Search in General Metric Spaces. \ No newline at end of file diff --git a/src/Graph/Nodes/VantagePoint.php b/src/Graph/Nodes/VantagePoint.php new file mode 100644 index 0000000..625797f --- /dev/null +++ b/src/Graph/Nodes/VantagePoint.php @@ -0,0 +1,70 @@ +columns() as $column => $values) { + if ($dataset->columnType($column)->isContinuous()) { + $center[] = Stats::mean($values); + } else { + $center[] = argmax(array_count_values($values)); + } + } + + $distances = []; + + foreach ($dataset->samples() as $sample) { + $distances[] = $kernel->compute($sample, $center); + } + + $threshold = Stats::median($distances); + + $samples = $dataset->samples(); + $labels = $dataset->labels(); + + $leftSamples = $leftLabels = $rightSamples = $rightLabels = []; + + foreach ($distances as $i => $distance) { + if ($distance <= $threshold) { + $leftSamples[] = $samples[$i]; + $leftLabels[] = $labels[$i]; + } else { + $rightSamples[] = $samples[$i]; + $rightLabels[] = $labels[$i]; + } + } + + $radius = max($distances) ?: 0.0; + + return new self($center, $radius, [ + Labeled::quick($leftSamples, $leftLabels), + Labeled::quick($rightSamples, $rightLabels), + ]); + } +} diff --git a/src/Graph/Trees/VPTree.php b/src/Graph/Trees/VPTree.php new file mode 100644 index 0000000..79dc12d --- /dev/null +++ b/src/Graph/Trees/VPTree.php @@ -0,0 +1,346 @@ +maxLeafSize = $maxLeafSize; + $this->kernel = $kernel ?? new Euclidean(); + } + + /** + * Return the height of the tree i.e. the number of levels. + * + * @return int + */ + public function height() : int + { + return $this->root ? $this->root->height() : 0; + } + + /** + * Return the balance factor of the tree. A balanced tree will have + * a factor of 0 whereas an imbalanced tree will either be positive + * or negative indicating the direction and degree of the imbalance. + * + * @return int + */ + public function balance() : int + { + return $this->root ? $this->root->balance() : 0; + } + + /** + * Is the tree bare? + * + * @return bool + */ + public function bare() : bool + { + return !$this->root; + } + + /** + * Return the distance kernel used to compute distances. + * + * @return \Rubix\ML\Kernels\Distance\Distance + */ + public function kernel() : Distance + { + return $this->kernel; + } + + /** + * Insert a root node and recursively split the dataset until a terminating + * condition is met. + * + * @internal + * + * @param \Rubix\ML\Datasets\Labeled $dataset + * @throws \Rubix\ML\Exceptions\InvalidArgumentException + */ + public function grow(Labeled $dataset) : void + { + if (!$dataset instanceof Labeled) { + throw new InvalidArgumentException('Tree requires a labeled dataset.'); + } + + $this->root = VantagePoint::split($dataset, $this->kernel); + + $stack = [$this->root]; + + while ($current = array_pop($stack)) { + [$left, $right] = $current->groups(); + + $current->cleanup(); + + if ($left->numRows() > $this->maxLeafSize) { + $node = VantagePoint::split($left, $this->kernel); + + if ($node->isPoint()) { + $current->attachLeft(Clique::terminate($left, $this->kernel)); + } else { + $current->attachLeft($node); + + $stack[] = $node; + } + } elseif (!$left->empty()) { + $current->attachLeft(Clique::terminate($left, $this->kernel)); + } + + if ($right->numRows() > $this->maxLeafSize) { + $node = VantagePoint::split($right, $this->kernel); + + $current->attachRight($node); + + $stack[] = $node; + } elseif (!$right->empty()) { + $current->attachRight(Clique::terminate($right, $this->kernel)); + } + } + } + + /** + * Run a k nearest neighbors search and return the samples, labels, and + * distances in a 3-tuple. + * + * @param (string|int|float)[] $sample + * @param int $k + * @throws \InvalidArgumentException + * @return array[] + */ + public function nearest(array $sample, int $k = 1) : array + { + if ($k < 1) { + throw new InvalidArgumentException('K must be' + . " greater than 0, $k given."); + } + + $visited = new SplObjectStorage(); + + $stack = $this->path($sample); + + $samples = $labels = $distances = []; + + while ($current = array_pop($stack)) { + if ($current instanceof VantagePoint) { + $radius = $distances[$k - 1] ?? INF; + + foreach ($current->children() as $child) { + if (!$visited->contains($child)) { + if ($child instanceof Hypersphere) { + $distance = $this->kernel->compute($sample, $child->center()); + + if ($distance - $child->radius() < $radius) { + $stack[] = $child; + + continue; + } + } + + $visited->attach($child); + } + } + + $visited->attach($current); + + continue; + } + + if ($current instanceof Clique) { + $dataset = $current->dataset(); + + foreach ($dataset->samples() as $neighbor) { + $distances[] = $this->kernel->compute($sample, $neighbor); + } + + $samples = array_merge($samples, $dataset->samples()); + $labels = array_merge($labels, $dataset->labels()); + + array_multisort($distances, $samples, $labels); + + if (count($samples) > $k) { + $samples = array_slice($samples, 0, $k); + $labels = array_slice($labels, 0, $k); + $distances = array_slice($distances, 0, $k); + } + + $visited->attach($current); + } + } + + return [$samples, $labels, $distances]; + } + + /** + * Return all samples, labels, and distances within a given radius of a + * sample. + * + * @param (string|int|float)[] $sample + * @param float $radius + * @throws \InvalidArgumentException + * @throws \RuntimeException + * @return array[] + */ + public function range(array $sample, float $radius) : array + { + if ($radius <= 0.0) { + throw new InvalidArgumentException('Radius must be' + . " greater than 0, $radius given."); + } + + $samples = $labels = $distances = []; + + $stack = [$this->root]; + + while ($current = array_pop($stack)) { + if ($current instanceof VantagePoint) { + foreach ($current->children() as $child) { + if ($child instanceof Hypersphere) { + $distance = $this->kernel->compute($sample, $child->center()); + + if ($distance - $child->radius() < $radius) { + $stack[] = $child; + } + } + } + + continue; + } + + if ($current instanceof Clique) { + $dataset = $current->dataset(); + + foreach ($dataset->samples() as $i => $neighbor) { + $distance = $this->kernel->compute($sample, $neighbor); + + if ($distance <= $radius) { + $samples[] = $neighbor; + $labels[] = $dataset->label($i); + $distances[] = $distance; + } + } + } + } + + return [$samples, $labels, $distances]; + } + + /** + * Destroy the tree. + */ + public function destroy() : void + { + unset($this->root); + } + + /** + * Return the path of a sample taken from the root node to a leaf node + * in an array. + * + * @param (string|int|float)[] $sample + * @return mixed[] + */ + protected function path(array $sample) : array + { + $current = $this->root; + + $path = []; + + while ($current) { + $path[] = $current; + + if ($current instanceof VantagePoint) { + $left = $current->left(); + $right = $current->right(); + + if ($left instanceof Hypersphere) { + $distance = $this->kernel->compute($sample, $left->center()); + + if ($distance <= $left->radius()) { + $current = $left; + } else { + $current = $right; + } + } + + continue; + } + + break; + } + + return $path; + } + + /** + * Return the string representation of the object. + * + * @return string + */ + public function __toString() : string + { + return "VP Tree (max_leaf_size: {$this->maxLeafSize}, kernel: {$this->kernel})"; + } +} diff --git a/tests/Graph/Trees/VPTreeTest.php b/tests/Graph/Trees/VPTreeTest.php new file mode 100644 index 0000000..49c96f7 --- /dev/null +++ b/tests/Graph/Trees/VPTreeTest.php @@ -0,0 +1,108 @@ +generator = new Agglomerate([ + 'east' => new Blob([5, -2, -2]), + 'west' => new Blob([0, 5, -3]), + ], [0.5, 0.5]); + + $this->tree = new VPTree(20, new Euclidean()); + + srand(self::RANDOM_SEED); + } + + /** + * @test + */ + public function build() : void + { + $this->assertInstanceOf(VPTree::class, $this->tree); + $this->assertInstanceOf(Spatial::class, $this->tree); + $this->assertInstanceOf(BinaryTree::class, $this->tree); + $this->assertInstanceOf(Tree::class, $this->tree); + } + + /** + * @test + */ + public function growNeighborsRange() : void + { + $this->tree->grow($this->generator->generate(self::DATASET_SIZE)); + + $this->assertGreaterThan(2, $this->tree->height()); + + $sample = $this->generator->generate(1)->sample(0); + + [$samples, $labels, $distances] = $this->tree->nearest($sample, 5); + + $this->assertCount(5, $samples); + $this->assertCount(5, $labels); + $this->assertCount(5, $distances); + + $this->assertCount(1, array_unique($labels)); + + [$samples, $labels, $distances] = $this->tree->range($sample, 4.3); + + $this->assertCount(50, $samples); + $this->assertCount(50, $labels); + $this->assertCount(50, $distances); + + $this->assertCount(1, array_unique($labels)); + } + + /** + * @test + */ + public function growWithSameSamples() : void + { + $generator = new Agglomerate([ + 'east' => new Blob([5, -2, 10], 0.0), + ]); + + $dataset = $generator->generate(self::DATASET_SIZE); + + $this->tree->grow($dataset); + + $this->assertEquals(2, $this->tree->height()); + } + + protected function assertPreConditions() : void + { + $this->assertEquals(0, $this->tree->height()); + } +} From 4980a8eceab6ec18e3700a9b27d13afa0c10d24a Mon Sep 17 00:00:00 2001 From: Andrew DalPino Date: Mon, 28 Dec 2020 22:08:15 -0600 Subject: [PATCH 16/20] Add Vantage Point node test --- phpunit.xml | 3 + src/Graph/Nodes/VantagePoint.php | 89 +++++++++++++++++++++- tests/Graph/Nodes/VantagePointTest.php | 100 +++++++++++++++++++++++++ 3 files changed, 189 insertions(+), 3 deletions(-) create mode 100644 tests/Graph/Nodes/VantagePointTest.php diff --git a/phpunit.xml b/phpunit.xml index 74a40ec..65cfb4e 100644 --- a/phpunit.xml +++ b/phpunit.xml @@ -13,6 +13,9 @@ tests + + tests/Graph + tests/NeuralNet diff --git a/src/Graph/Nodes/VantagePoint.php b/src/Graph/Nodes/VantagePoint.php index 625797f..f4dd813 100644 --- a/src/Graph/Nodes/VantagePoint.php +++ b/src/Graph/Nodes/VantagePoint.php @@ -5,6 +5,7 @@ use Rubix\ML\Datasets\Labeled; use Rubix\ML\Other\Helpers\Stats; use Rubix\ML\Kernels\Distance\Distance; +use Rubix\ML\Graph\Nodes\Traits\HasBinaryChildren; use function Rubix\ML\argmax; @@ -15,11 +16,33 @@ * @package Rubix/ML * @author Andrew DalPino */ -class VantagePoint extends Ball +class VantagePoint implements BinaryNode, Hypersphere { + use HasBinaryChildren; + + /** + * The center or multivariate mean of the centroid. + * + * @var list + */ + protected $center; + + /** + * The radius of the centroid. + * + * @var float + */ + protected $radius; + + /** + * The left and right splits of the training data. + * + * @var array{Labeled,Labeled} + */ + protected $groups; + /** - * Factory method to build a hypersphere by splitting the dataset into - * left and right clusters. + * Factory method to build a hypersphere by splitting the dataset into left and right clusters. * * @param \Rubix\ML\Datasets\Labeled $dataset * @param \Rubix\ML\Kernels\Distance\Distance $kernel @@ -67,4 +90,64 @@ public static function split(Labeled $dataset, Distance $kernel) : self Labeled::quick($rightSamples, $rightLabels), ]); } + + /** + * @param list $center + * @param float $radius + * @param array{Labeled,Labeled} $groups + */ + public function __construct(array $center, float $radius, array $groups) + { + $this->center = $center; + $this->radius = $radius; + $this->groups = $groups; + } + + /** + * Return the center vector. + * + * @return list + */ + public function center() : array + { + return $this->center; + } + + /** + * Return the radius of the centroid. + * + * @return float + */ + public function radius() : float + { + return $this->radius; + } + + /** + * Return the left and right splits of the training data. + * + * @return array{Labeled,Labeled} + */ + public function groups() : array + { + return $this->groups; + } + + /** + * Does the hypersphere reduce to a single point? + * + * @return bool + */ + public function isPoint() : bool + { + return $this->radius === 0.0; + } + + /** + * Remove the left and right splits of the training data. + */ + public function cleanup() : void + { + unset($this->groups); + } } diff --git a/tests/Graph/Nodes/VantagePointTest.php b/tests/Graph/Nodes/VantagePointTest.php new file mode 100644 index 0000000..baf503f --- /dev/null +++ b/tests/Graph/Nodes/VantagePointTest.php @@ -0,0 +1,100 @@ +node = new VantagePoint(self::CENTER, self::RADIUS, $groups); + } + + /** + * @test + */ + public function build() : void + { + $this->assertInstanceOf(VantagePoint::class, $this->node); + $this->assertInstanceOf(Hypersphere::class, $this->node); + $this->assertInstanceOf(BinaryNode::class, $this->node); + $this->assertInstanceOf(Node::class, $this->node); + } + + /** + * @test + */ + public function split() : void + { + $dataset = Labeled::quick(self::SAMPLES, self::LABELS); + + $node = VantagePoint::split($dataset, new Euclidean()); + + $this->assertEquals(self::CENTER, $node->center()); + $this->assertEquals(self::RADIUS, $node->radius()); + } + + /** + * @test + */ + public function center() : void + { + $this->assertSame(self::CENTER, $this->node->center()); + } + + /** + * @test + */ + public function radius() : void + { + $this->assertSame(self::RADIUS, $this->node->radius()); + } + + /** + * @test + */ + public function groups() : void + { + $expected = [ + Labeled::quick([self::SAMPLES[0]], [self::LABELS[0]]), + Labeled::quick([self::SAMPLES[1]], [self::LABELS[1]]), + ]; + + $this->assertEquals($expected, $this->node->groups()); + } +} From e43b5043c7dc5a4378a194eb11cbb4c4d23575b4 Mon Sep 17 00:00:00 2001 From: Andrew DalPino Date: Thu, 31 Dec 2020 17:26:54 -0600 Subject: [PATCH 17/20] Move Flysystem Persister to main repo --- CHANGELOG.md | 1 + README.md | 2 +- docs/persisters/flysystem.md | 29 ----- src/Other/Helpers/Reflection.php | 43 ------- src/Persisters/Flysystem.php | 163 ------------------------- tests/Persisters/FlysystemTest.php | 187 ----------------------------- 6 files changed, 2 insertions(+), 423 deletions(-) delete mode 100644 docs/persisters/flysystem.md delete mode 100644 src/Other/Helpers/Reflection.php delete mode 100644 src/Persisters/Flysystem.php delete mode 100644 tests/Persisters/FlysystemTest.php diff --git a/CHANGELOG.md b/CHANGELOG.md index 9330dbb..fd32178 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,7 @@ - Added Levenshtein distance kernel - Move K Best Selector to main repository - Added custom exceptions from the main repo + - Moved Flysystem Persister over to main repo - 0.2.1-beta - Implemented K Best feature selector diff --git a/README.md b/README.md index 706baa7..c4add68 100644 --- a/README.md +++ b/README.md @@ -14,4 +14,4 @@ $ composer require rubix/extras - [Bzip2 extension](https://www.php.net/manual/en/book.bzip2.php) for Bzip2 compression ## License -[MIT](https://github.com/RubixML/Extras/blob/master/LICENSE.md) +The code is licensed [MIT](LICENSE) and the documentation is licensed [CC BY-NC 4.0](https://creativecommons.org/licenses/by-nc/4.0/). diff --git a/docs/persisters/flysystem.md b/docs/persisters/flysystem.md deleted file mode 100644 index 3f6c59b..0000000 --- a/docs/persisters/flysystem.md +++ /dev/null @@ -1,29 +0,0 @@ -[source] - -# Flysystem -[Flysystem](https://flysystem.thephpleague.com) is a filesystem library providing a unified storage interface and abstraction layer. It enables access to many different storage backends such as Local, Amazon S3, FTP, and more. - -> **Note:** The Flysystem persister is designed to work with Flysystem version 2.0. - -## Parameters -| # | Param | Default | Type | Description | -|---|---|---|---|---| -| 1 | path | | string | The path to the persistable object file on the filesystem. | -| 2 | filesystem | | FilesystemOperator | The Flysystem filesystem operator responsible for read and write operations. | -| 3 | history | false | bool | Should we keep a history of past saves? | -| 4 | serializer | Native | Serializer | The serializer used to convert to and from storage format. | - -## Example -```php -use League\Flysystem\Filesystem; -use League\Flysystem\Local\LocalFilesystemAdapter; -use Rubix\ML\Persisters\Flysystem; -use Rubix\ML\Persisters\Serializers\Native; - -$filesystem = new Filesystem(new LocalFilesystemAdapter('/path/to/')); - -$persister = new Flysystem('example.model', $filesystem, true, new Native()); -``` - -## Additional Methods -This persister does not have any additional methods. diff --git a/src/Other/Helpers/Reflection.php b/src/Other/Helpers/Reflection.php deleted file mode 100644 index c474c4c..0000000 --- a/src/Other/Helpers/Reflection.php +++ /dev/null @@ -1,43 +0,0 @@ -getProperty($name); - $property->setAccessible(true); - - return $property->getValue($subject); - } catch (ReflectionException $e) { - return $fallback; - } - } -} diff --git a/src/Persisters/Flysystem.php b/src/Persisters/Flysystem.php deleted file mode 100644 index 5f725e2..0000000 --- a/src/Persisters/Flysystem.php +++ /dev/null @@ -1,163 +0,0 @@ - **Note:** The Flysystem persister is designed to work with Flysystem version 2.0. - * - * @see https://flysystem.thephpleague.com - * - * @category Machine Learning - * @package Rubix/ML - * @author Chris Simpson - * @author Andrew DalPino - */ -class Flysystem implements Persister -{ - /** - * The extension to give files created as part of a persistable's save history. - * - * @var string - */ - public const HISTORY_EXT = 'old'; - - /** - * The path to the model file on the filesystem. - * - * @var string - */ - protected $path; - - /** - * The filesystem implementation providing access to your backend storage. - * - * @var \League\Flysystem\FilesystemOperator - */ - protected $filesystem; - - /** - * Should we keep a history of past saves? - * - * @var bool - */ - protected $history; - - /** - * The serializer used to convert to and from serial format. - * - * @var \Rubix\ML\Persisters\Serializers\Serializer - */ - protected $serializer; - - /** - * @param string $path - * @param \League\Flysystem\FilesystemOperator $filesystem - * @param bool $history - * @param \Rubix\ML\Persisters\Serializers\Serializer|null $serializer - */ - public function __construct( - string $path, - FilesystemOperator $filesystem, - bool $history = false, - ?Serializer $serializer = null - ) { - $this->path = $path; - $this->filesystem = $filesystem; - $this->history = $history; - $this->serializer = $serializer ?? new Native(); - } - - /** - * Save the persistable object. - * - * @param \Rubix\ML\Persistable $persistable - * @throws \RuntimeException - */ - public function save(Persistable $persistable) : void - { - try { - if ($this->history and $this->filesystem->fileExists($this->path)) { - $timestamp = (string) time(); - - $filename = "{$this->path}-$timestamp." . self::HISTORY_EXT; - - $num = 0; - - while ($this->filesystem->fileExists($filename)) { - $filename = "{$this->path}-$timestamp-" . ++$num . '.' . self::HISTORY_EXT; - } - - $this->filesystem->move($this->path, $filename); - } - } catch (FilesystemException $e) { - throw new RuntimeException('Failed to create history file.'); - } - - $encoding = $this->serializer->serialize($persistable); - - if ($encoding->bytes() === 0) { - throw new RuntimeException("Cannot save empty encoding to {$this->path}"); - } - - try { - $this->filesystem->write($this->path, $encoding->data()); - } catch (FilesystemException $e) { - throw new RuntimeException('Could not write to filesystem.'); - } - } - - /** - * Load the last model that was saved. - * - * @throws \RuntimeException - * @return \Rubix\ML\Persistable - */ - public function load() : Persistable - { - try { - $data = $this->filesystem->read($this->path); - } catch (FilesystemException $e) { - throw new RuntimeException("Error reading data from {$this->path}."); - } - - $encoding = new Encoding($data); - - if ($encoding->bytes() === 0) { - throw new RuntimeException("File at {$this->path} does not contain any data."); - } - - return $this->serializer->unserialize($encoding); - } - - /** - * Return the string representation of the object. - * - * @return string - */ - public function __toString() : string - { - $params = [ - 'path' => $this->path, - 'adapter' => Reflection::property($this->filesystem, 'adapter'), - 'history' => $this->history, - 'serializer' => $this->serializer, - ]; - - return 'Flysystem (' . Params::stringify($params) . ')'; - } -} diff --git a/tests/Persisters/FlysystemTest.php b/tests/Persisters/FlysystemTest.php deleted file mode 100644 index 948c43d..0000000 --- a/tests/Persisters/FlysystemTest.php +++ /dev/null @@ -1,187 +0,0 @@ -filesystem = new Filesystem(new InMemoryFilesystemAdapter()); - - $this->persistable = new DummyClassifier(); - - $this->persister = new Flysystem(self::PATH, $this->filesystem); - } - - /** - * @test - */ - public function build() : void - { - $this->assertInstanceOf(Flysystem::class, $this->persister); - $this->assertInstanceOf(Persister::class, $this->persister); - } - - /** - * @test - */ - public function saveLoad() : void - { - $this->persister->save($this->persistable); - - $this->assertTrue($this->filesystem->fileExists(self::PATH)); - - $persistable = $this->persister->load(); - - $this->assertInstanceOf(DummyClassifier::class, $persistable); - $this->assertInstanceOf(Persistable::class, $persistable); - } - - /** - * @test - */ - public function saveMethodWhenFilesystemWriteFails() : void - { - $filesystem = $this->createMock(FilesystemOperator::class); - - $filesystem->method('write') - ->with(self::PATH) - ->willThrowException(new UnableToWriteFile()); - - $this->persister = new Flysystem(self::PATH, $filesystem); - - $this->expectException(RuntimeException::class); - - $this->persister->save($this->persistable); - } - - /** - * @test - */ - public function saveMethodWithHistoryDisabled() : void - { - $directory = dirname(self::PATH); - - $this->persister = new Flysystem(self::PATH, $this->filesystem, false); - - $this->persister->save($this->persistable); - - $this->assertCount(1, $this->filesystem->listContents($directory)); - $this->assertTrue($this->filesystem->fileExists(self::PATH)); - - $this->persister->save($this->persistable); - - $this->assertCount(1, $this->filesystem->listContents($directory)); - $this->assertTrue($this->filesystem->fileExists(self::PATH)); - } - - /** - * @test - */ - public function saveMethodWithHistoryEnabled() : void - { - $directory = dirname(self::PATH); - - $this->persister = new Flysystem(self::PATH, $this->filesystem, true); - - $this->persister->save($this->persistable); - - $this->assertTrue($this->filesystem->fileExists(self::PATH)); - - $this->persister->save($this->persistable); - - $files = $this->filesystem->listContents($directory); - - $this->assertCount(2, $files); - } - - /** - * @test - */ - public function saveMethodWhenHistoryCreationFails() : void - { - $mock = $this->createMock(FilesystemOperator::class); - - $mock->expects($this->any()) - ->method('fileExists') - ->will($this->onConsecutiveCalls(true, true, false)); - - $mock->expects($this->any()) - ->method('move') - ->willThrowException(new UnableToMoveFile()); - - $this->persister = new Flysystem(self::PATH, $mock, true); - - $this->expectException(RuntimeException::class); - - $this->persister->save($this->persistable); - } - - /** - * @test - */ - public function loadMethodWhenTargetNotExists() : void - { - $this->expectException(RuntimeException::class); - - $this->persister->load(); - } - - /** - * @test - */ - public function loadMethodWhenTargetIsEmpty() : void - { - $this->filesystem->write(self::PATH, ''); - - $this->expectException(RuntimeException::class); - - $this->persister->load(); - } - - protected function assertPreConditions() : void - { - $this->assertFalse($this->filesystem->fileExists(self::PATH)); - } -} From 39f89be891b4efc424cfb92f0e9b58fece65d38c Mon Sep 17 00:00:00 2001 From: Andrew DalPino Date: Thu, 31 Dec 2020 17:33:20 -0600 Subject: [PATCH 18/20] Remove Flysystem project dependencies --- composer.json | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/composer.json b/composer.json index 89d966b..c5b5267 100644 --- a/composer.json +++ b/composer.json @@ -24,12 +24,10 @@ "php": ">=7.2", "rubix/ml": "0.3.0.x-dev", "rubix/tensor": "^2.0.4", - "wamania/php-stemmer": "^2.0", - "league/flysystem": "2.0.0-beta.3" + "wamania/php-stemmer": "^2.0" }, "require-dev": { "friendsofphp/php-cs-fixer": "2.16.*", - "league/flysystem-memory": "2.0.0-beta.3", "phpbench/phpbench": "0.17.*", "phpstan/extension-installer": "^1.0", "phpstan/phpstan": "0.12.*", From 039532c23feb2b2d122f139dc8b0d1b47d7f7e97 Mon Sep 17 00:00:00 2001 From: Andrew DalPino Date: Thu, 31 Dec 2020 17:36:25 -0600 Subject: [PATCH 19/20] Add Persisters to PHP unit config --- phpunit.xml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/phpunit.xml b/phpunit.xml index 65cfb4e..82227ea 100644 --- a/phpunit.xml +++ b/phpunit.xml @@ -22,6 +22,9 @@ tests/Other + + tests/Persisters + tests/Transformers From 72f3c73b304babc2190a2ca90850bc073b59f904 Mon Sep 17 00:00:00 2001 From: Andrew DalPino Date: Thu, 31 Dec 2020 20:34:02 -0600 Subject: [PATCH 20/20] Version bump --- composer.json | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/composer.json b/composer.json index c5b5267..2a10ca8 100644 --- a/composer.json +++ b/composer.json @@ -22,7 +22,7 @@ ], "require": { "php": ">=7.2", - "rubix/ml": "0.3.0.x-dev", + "rubix/ml": "0.3.0", "rubix/tensor": "^2.0.4", "wamania/php-stemmer": "^2.0" },