From 05f898792c0008aff3689bb87f834576fe8fa8a4 Mon Sep 17 00:00:00 2001 From: Andrew DalPino Date: Sat, 29 Aug 2020 04:05:08 -0500 Subject: [PATCH] Add beta parameter --- docs/transformers/bm25-transformer.md | 37 +++++++++ src/Transformers/BM25Transformer.php | 96 +++++++++++++--------- tests/Transformers/BM25TransformerTest.php | 8 +- 3 files changed, 99 insertions(+), 42 deletions(-) create mode 100644 docs/transformers/bm25-transformer.md diff --git a/docs/transformers/bm25-transformer.md b/docs/transformers/bm25-transformer.md new file mode 100644 index 0000000..285f1ab --- /dev/null +++ b/docs/transformers/bm25-transformer.md @@ -0,0 +1,37 @@ +[source] + +# BM25 Transformer +BM25 is a term frequency weighting scheme that takes term frequency (TF) saturation and document length into account. + +> **Note:** This transformer assumes that its input is made up of word frequency vectors such as those produced by [Word Count Vectorizer](word-count-vectorizer.md). + +**Interfaces:** [Transformer](api.md#transformer), [Stateful](api.md#stateful), [Elastic](api.md#elastic) + +**Data Type Compatibility:** Continuous only + +## Parameters +| # | Param | Default | Type | Description | +|---|---|---|---|---| +| 1 | alpha | 1.2 | float | The term frequency (TF) normalization factor. | +| 2 | beta | 0.75 | float | The importance of document length in normalizing term frequency. | + +## Example +```php +use Rubix\ML\Transformers\BM25Transformer; + +$transformer = new BM25Transformer(1.2, 0.75); +``` + +## Additional Methods +Return the document frequencies calculated during fitting: +```php +public dfs() : ?array +``` + +Return the average number of tokens per document: +```php +public averageDocumentLength() : ?float +``` + +### References +>- S. Robertson et al. (2009). The Probabilistic Relevance Framework: BM25 and Beyond. \ No newline at end of file diff --git a/src/Transformers/BM25Transformer.php b/src/Transformers/BM25Transformer.php index 5ff7c13..3c4ede7 100644 --- a/src/Transformers/BM25Transformer.php +++ b/src/Transformers/BM25Transformer.php @@ -5,25 +5,23 @@ use Rubix\ML\DataType; use Rubix\ML\Datasets\Dataset; use Rubix\ML\Specifications\SamplesAreCompatibleWithTransformer; +use InvalidArgumentException; use RuntimeException; use Stringable; use function is_null; /** - * TF-IDF Transformer + * BM25 Transformer * - * Term Frequency - Inverse Document Frequency is a measure of how important - * a word is to a document. The TF-IDF value increases proportionally with - * the number of times a word appears in a document and is offset by the - * frequency of the word in the corpus. + * BM25 is a term frequency weighting scheme that takes term frequency (TF) saturation and + * document length into account. * - * > **Note**: This transformer assumes that its input is made up of word - * frequency vectors such as those created by the Word Count Vectorizer. + * > **Note**: This transformer assumes that its input is made up of term frequency vectors + * such as those created by the Word Count Vectorizer. * * References: - * [1] S. Robertson. (2003). Understanding Inverse Document Frequency: On - * theoretical arguments for IDF. + * [1] S. Robertson et al. (2009). The Probabilistic Relevance Framework: BM25 and Beyond. * * @category Machine Learning * @package Rubix/ML @@ -32,11 +30,18 @@ class BM25Transformer implements Transformer, Stateful, Elastic, Stringable { /** - * The rate at which the TF values decay. - * - * @var float + * The term frequency (TF) normalization factor. + * + * @var float + */ + protected $alpha; + + /** + * The importance of document length in normalizing term frequency. + * + * @var float */ - protected $termFrequencyDecay; + protected $beta; /** * The document frequencies of each word i.e. the number of times a word @@ -46,13 +51,6 @@ class BM25Transformer implements Transformer, Stateful, Elastic, Stringable */ protected $dfs; - /** - * The number of tokens fitted so far. - * - * @var int|null - */ - protected $tokenCount; - /** * The inverse document frequency values for each feature column. * @@ -60,6 +58,13 @@ class BM25Transformer implements Transformer, Stateful, Elastic, Stringable */ protected $idfs; + /** + * The number of tokens fitted so far. + * + * @var int|null + */ + protected $tokenCount; + /** * The number of documents (samples) that have been fitted so far. * @@ -69,23 +74,30 @@ class BM25Transformer implements Transformer, Stateful, Elastic, Stringable /** * The average token count per document. - * + * * @var float|null */ protected $averageDocumentLength; /** - * @param int $termFrequencyDecay + * @param float $alpha + * @param float $beta * @throws \InvalidArgumentException */ - public function __construct(float $termFrequencyDecay = 0.0) + public function __construct(float $alpha = 1.2, float $beta = 0.75) { - if ($termFrequencyDecay < 0.0) { + if ($alpha < 0.0) { throw new InvalidArgumentException('Term frequency decay' - . " must be greater than 0, $termFrequencyDecay given."); + . " must be greater than 0, $alpha given."); + } + + if ($beta < 0.0 or $beta > 1.0) { + throw new InvalidArgumentException('Beta must be between' + . " 0 and 1, $beta given."); } - $this->termFrequencyDecay = $termFrequencyDecay; + $this->alpha = $alpha; + $this->beta = $beta; } /** @@ -107,7 +119,7 @@ public function compatibility() : array */ public function fitted() : bool { - return isset($this->idfs); + return $this->idfs and $this->averageDocumentLength; } /** @@ -120,6 +132,16 @@ public function dfs() : ?array return $this->dfs; } + /** + * Return the average number of tokens per document. + * + * @return float|null + */ + public function averageDocumentLength() : ?float + { + return $this->averageDocumentLength; + } + /** * Fit the transformer to a dataset. * @@ -162,15 +184,15 @@ public function update(Dataset $dataset) : void $this->n += $dataset->numRows(); + $this->averageDocumentLength = $this->tokenCount / $this->n; + $idfs = []; foreach ($this->dfs as $df) { - $idfs[] = 1.0 + log($this->n / $df); + $idfs[] = log(1.0 + ($this->n - $df + 0.5) / ($df + 0.5)); } $this->idfs = $idfs; - - $this->averageDocumentLength = $this->tokenCount / $this->n; } /** @@ -186,17 +208,15 @@ public function transform(array &$samples) : void } foreach ($samples as &$sample) { - if ($this->termFrequencyDecay > 0.0) { - $delta = array_sum($sample) / $this->averageDocumentLength; + $delta = array_sum($sample) / $this->averageDocumentLength; - $delta *= $this->termFrequencyDecay; - } else { - $delta = 0.0; - } + $delta = 1.0 - $this->beta + $this->beta * $delta; + + $delta *= $this->alpha; foreach ($sample as $column => &$tf) { if ($tf > 0) { - $tf *= $tf / ($tf + $delta); + $tf /= $tf + $delta; $tf *= $this->idfs[$column]; } } @@ -210,6 +230,6 @@ public function transform(array &$samples) : void */ public function __toString() : string { - return 'BM25 TF-IDF Transformer'; + return "BM25 Transformer (alpha: {$this->alpha}, beta: {$this->beta})"; } } diff --git a/tests/Transformers/BM25TransformerTest.php b/tests/Transformers/BM25TransformerTest.php index 2e519b6..ccab0fd 100644 --- a/tests/Transformers/BM25TransformerTest.php +++ b/tests/Transformers/BM25TransformerTest.php @@ -37,7 +37,7 @@ protected function setUp() : void [0, 0, 0, 1, 2, 3, 0, 0, 4, 2, 0, 0, 1, 0, 2, 0, 1, 0, 0], ]); - $this->transformer = new BM25Transformer(0.0); + $this->transformer = new BM25Transformer(1.2, 0.75); } /** @@ -69,9 +69,9 @@ public function fitTransform() : void $this->dataset->apply($this->transformer); $outcome = [ - [1.6931471805599454, 3.8630462173553424, 0.0, 0.0, 1.2876820724517808, 0.0, 0.0, 0.0, 1.2876820724517808, 2.5753641449035616, 0.0, 2.5753641449035616, 0.0, 0.0, 0.0, 6.772588722239782, 1.2876820724517808, 0.0, 1.6931471805599454], - [0.0, 1.2876820724517808, 1.6931471805599454, 0.0, 0.0, 2.5753641449035616, 1.6931471805599454, 0.0, 0.0, 0.0, 0.0, 3.8630462173553424, 0.0, 1.6931471805599454, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 1.6931471805599454, 2.5753641449035616, 3.8630462173553424, 0.0, 0.0, 5.150728289807123, 2.5753641449035616, 0.0, 0.0, 1.6931471805599454, 0.0, 3.386294361119891, 0.0, 1.2876820724517808, 0.0, 0.0], + [0.2562582002070131, 0.22742881339794754, 0.0, 0.0, 0.13186359514416618, 0.0, 0.0, 0.0, 0.13186359514416618, 0.19254341937443092, 0.0, 0.19254341937443092, 0.0, 0.0, 0.0, 0.4860031535349766, 0.13186359514416618, 0.0, 0.2562582002070131], + [0.0, 0.17063795450977862, 0.3316106698128093, 0.0, 0.0, 0.23083934808978732, 0.3316106698128093, 0.0, 0.0, 0.0, 0.0, 0.26160416281731713, 0.0, 0.3316106698128093, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.2562582002070131, 0.19254341937443092, 0.22742881339794754, 0.0, 0.0, 0.25008418471976107, 0.19254341937443092, 0.0, 0.0, 0.2562582002070131, 0.0, 0.3741808347986538, 0.0, 0.13186359514416618, 0.0, 0.0], ]; $this->assertEquals($outcome, $this->dataset->samples());