Skip to content

Commit

Permalink
Add beta parameter
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewdalpino committed Aug 29, 2020
1 parent 3a09ee7 commit 05f8987
Show file tree
Hide file tree
Showing 3 changed files with 99 additions and 42 deletions.
37 changes: 37 additions & 0 deletions docs/transformers/bm25-transformer.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
<span style="float:right;"><a href="https://github.com/RubixML/RubixML/blob/master/src/Transformers/TfIdfTransformer.php">[source]</a></span>

# BM25 Transformer
BM25 is a term frequency weighting scheme that takes term frequency (TF) saturation and document length into account.

> **Note:** This transformer assumes that its input is made up of word frequency vectors such as those produced by [Word Count Vectorizer](word-count-vectorizer.md).
**Interfaces:** [Transformer](api.md#transformer), [Stateful](api.md#stateful), [Elastic](api.md#elastic)

**Data Type Compatibility:** Continuous only

## Parameters
| # | Param | Default | Type | Description |
|---|---|---|---|---|
| 1 | alpha | 1.2 | float | The term frequency (TF) normalization factor. |
| 2 | beta | 0.75 | float | The importance of document length in normalizing term frequency. |

## Example
```php
use Rubix\ML\Transformers\BM25Transformer;

$transformer = new BM25Transformer(1.2, 0.75);
```

## Additional Methods
Return the document frequencies calculated during fitting:
```php
public dfs() : ?array
```

Return the average number of tokens per document:
```php
public averageDocumentLength() : ?float
```

### References
>- S. Robertson et al. (2009). The Probabilistic Relevance Framework: BM25 and Beyond.
96 changes: 58 additions & 38 deletions src/Transformers/BM25Transformer.php
Original file line number Diff line number Diff line change
Expand Up @@ -5,25 +5,23 @@
use Rubix\ML\DataType;
use Rubix\ML\Datasets\Dataset;
use Rubix\ML\Specifications\SamplesAreCompatibleWithTransformer;
use InvalidArgumentException;
use RuntimeException;
use Stringable;

use function is_null;

/**
* TF-IDF Transformer
* BM25 Transformer
*
* Term Frequency - Inverse Document Frequency is a measure of how important
* a word is to a document. The TF-IDF value increases proportionally with
* the number of times a word appears in a document and is offset by the
* frequency of the word in the corpus.
* BM25 is a term frequency weighting scheme that takes term frequency (TF) saturation and
* document length into account.
*
* > **Note**: This transformer assumes that its input is made up of word
* frequency vectors such as those created by the Word Count Vectorizer.
* > **Note**: This transformer assumes that its input is made up of term frequency vectors
* such as those created by the Word Count Vectorizer.
*
* References:
* [1] S. Robertson. (2003). Understanding Inverse Document Frequency: On
* theoretical arguments for IDF.
* [1] S. Robertson et al. (2009). The Probabilistic Relevance Framework: BM25 and Beyond.
*
* @category Machine Learning
* @package Rubix/ML
Expand All @@ -32,11 +30,18 @@
class BM25Transformer implements Transformer, Stateful, Elastic, Stringable
{
/**
* The rate at which the TF values decay.
*
* @var float
* The term frequency (TF) normalization factor.
*
* @var float
*/
protected $alpha;

/**
* The importance of document length in normalizing term frequency.
*
* @var float
*/
protected $termFrequencyDecay;
protected $beta;

/**
* The document frequencies of each word i.e. the number of times a word
Expand All @@ -46,20 +51,20 @@ class BM25Transformer implements Transformer, Stateful, Elastic, Stringable
*/
protected $dfs;

/**
* The number of tokens fitted so far.
*
* @var int|null
*/
protected $tokenCount;

/**
* The inverse document frequency values for each feature column.
*
* @var float[]|null
*/
protected $idfs;

/**
* The number of tokens fitted so far.
*
* @var int|null
*/
protected $tokenCount;

/**
* The number of documents (samples) that have been fitted so far.
*
Expand All @@ -69,23 +74,30 @@ class BM25Transformer implements Transformer, Stateful, Elastic, Stringable

/**
* The average token count per document.
*
*
* @var float|null
*/
protected $averageDocumentLength;

/**
* @param int $termFrequencyDecay
* @param float $alpha
* @param float $beta
* @throws \InvalidArgumentException
*/
public function __construct(float $termFrequencyDecay = 0.0)
public function __construct(float $alpha = 1.2, float $beta = 0.75)
{
if ($termFrequencyDecay < 0.0) {
if ($alpha < 0.0) {
throw new InvalidArgumentException('Term frequency decay'
. " must be greater than 0, $termFrequencyDecay given.");
. " must be greater than 0, $alpha given.");
}

if ($beta < 0.0 or $beta > 1.0) {
throw new InvalidArgumentException('Beta must be between'
. " 0 and 1, $beta given.");
}

$this->termFrequencyDecay = $termFrequencyDecay;
$this->alpha = $alpha;
$this->beta = $beta;
}

/**
Expand All @@ -107,7 +119,7 @@ public function compatibility() : array
*/
public function fitted() : bool
{
return isset($this->idfs);
return $this->idfs and $this->averageDocumentLength;
}

/**
Expand All @@ -120,6 +132,16 @@ public function dfs() : ?array
return $this->dfs;
}

/**
* Return the average number of tokens per document.
*
* @return float|null
*/
public function averageDocumentLength() : ?float
{
return $this->averageDocumentLength;
}

/**
* Fit the transformer to a dataset.
*
Expand Down Expand Up @@ -162,15 +184,15 @@ public function update(Dataset $dataset) : void

$this->n += $dataset->numRows();

$this->averageDocumentLength = $this->tokenCount / $this->n;

$idfs = [];

foreach ($this->dfs as $df) {
$idfs[] = 1.0 + log($this->n / $df);
$idfs[] = log(1.0 + ($this->n - $df + 0.5) / ($df + 0.5));
}

$this->idfs = $idfs;

$this->averageDocumentLength = $this->tokenCount / $this->n;
}

/**
Expand All @@ -186,17 +208,15 @@ public function transform(array &$samples) : void
}

foreach ($samples as &$sample) {
if ($this->termFrequencyDecay > 0.0) {
$delta = array_sum($sample) / $this->averageDocumentLength;
$delta = array_sum($sample) / $this->averageDocumentLength;

$delta *= $this->termFrequencyDecay;
} else {
$delta = 0.0;
}
$delta = 1.0 - $this->beta + $this->beta * $delta;

$delta *= $this->alpha;

foreach ($sample as $column => &$tf) {
if ($tf > 0) {
$tf *= $tf / ($tf + $delta);
$tf /= $tf + $delta;
$tf *= $this->idfs[$column];
}
}
Expand All @@ -210,6 +230,6 @@ public function transform(array &$samples) : void
*/
public function __toString() : string
{
return 'BM25 TF-IDF Transformer';
return "BM25 Transformer (alpha: {$this->alpha}, beta: {$this->beta})";
}
}
8 changes: 4 additions & 4 deletions tests/Transformers/BM25TransformerTest.php
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ protected function setUp() : void
[0, 0, 0, 1, 2, 3, 0, 0, 4, 2, 0, 0, 1, 0, 2, 0, 1, 0, 0],
]);

$this->transformer = new BM25Transformer(0.0);
$this->transformer = new BM25Transformer(1.2, 0.75);
}

/**
Expand Down Expand Up @@ -69,9 +69,9 @@ public function fitTransform() : void
$this->dataset->apply($this->transformer);

$outcome = [
[1.6931471805599454, 3.8630462173553424, 0.0, 0.0, 1.2876820724517808, 0.0, 0.0, 0.0, 1.2876820724517808, 2.5753641449035616, 0.0, 2.5753641449035616, 0.0, 0.0, 0.0, 6.772588722239782, 1.2876820724517808, 0.0, 1.6931471805599454],
[0.0, 1.2876820724517808, 1.6931471805599454, 0.0, 0.0, 2.5753641449035616, 1.6931471805599454, 0.0, 0.0, 0.0, 0.0, 3.8630462173553424, 0.0, 1.6931471805599454, 0.0, 0.0, 0.0, 0.0, 0.0],
[0.0, 0.0, 0.0, 1.6931471805599454, 2.5753641449035616, 3.8630462173553424, 0.0, 0.0, 5.150728289807123, 2.5753641449035616, 0.0, 0.0, 1.6931471805599454, 0.0, 3.386294361119891, 0.0, 1.2876820724517808, 0.0, 0.0],
[0.2562582002070131, 0.22742881339794754, 0.0, 0.0, 0.13186359514416618, 0.0, 0.0, 0.0, 0.13186359514416618, 0.19254341937443092, 0.0, 0.19254341937443092, 0.0, 0.0, 0.0, 0.4860031535349766, 0.13186359514416618, 0.0, 0.2562582002070131],
[0.0, 0.17063795450977862, 0.3316106698128093, 0.0, 0.0, 0.23083934808978732, 0.3316106698128093, 0.0, 0.0, 0.0, 0.0, 0.26160416281731713, 0.0, 0.3316106698128093, 0.0, 0.0, 0.0, 0.0, 0.0],
[0.0, 0.0, 0.0, 0.2562582002070131, 0.19254341937443092, 0.22742881339794754, 0.0, 0.0, 0.25008418471976107, 0.19254341937443092, 0.0, 0.0, 0.2562582002070131, 0.0, 0.3741808347986538, 0.0, 0.13186359514416618, 0.0, 0.0],
];

$this->assertEquals($outcome, $this->dataset->samples());
Expand Down

0 comments on commit 05f8987

Please sign in to comment.