Skip to content

Commit

Permalink
Implement BM25 TF-IDF Transformer
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewdalpino committed Aug 29, 2020
1 parent b0d8817 commit 3a09ee7
Show file tree
Hide file tree
Showing 4 changed files with 357 additions and 0 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
- Unreleased
- Implemented BM25 TF-IDF Transformer
- Added Lambda function Transformer
- All objects implement Stringable interface
50 changes: 50 additions & 0 deletions benchmarks/Transformers/BM25TransformerBench.php
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
<?php

namespace Rubix\ML\Benchmarks\Transformers;

use Tensor\Matrix;
use Rubix\ML\Datasets\Unlabeled;
use Rubix\ML\Transformers\BM25Transformer;

/**
 * Benchmark that times applying a BM25 Transformer to a randomly generated
 * dataset of NUM_SAMPLES rows by 100 columns.
 *
 * @Groups({"Transformers"})
 * @BeforeMethods({"setUp"})
 */
class BM25TransformerBench
{
    protected const NUM_SAMPLES = 10000;

    /**
     * The dataset the transformer is applied to in the benchmark subject.
     *
     * @var \Rubix\ML\Datasets\Unlabeled
     */
    protected $dataset;

    /**
     * The transformer under benchmark.
     *
     * @var \Rubix\ML\Transformers\BM25Transformer
     */
    protected $transformer;

    /**
     * Build the benchmark fixtures.
     *
     * A random mask (entries of a random matrix compared against 0.8) is
     * multiplied element-wise into a Gaussian matrix so that the masked-out
     * entries of the generated samples are zero.
     */
    public function setUp() : void
    {
        $mask = Matrix::rand(self::NUM_SAMPLES, 100)
            ->greater(0.8);

        $samples = Matrix::gaussian(self::NUM_SAMPLES, 100)
            ->multiply($mask)
            ->asArray();

        $this->dataset = Unlabeled::quick($samples);

        $this->transformer = new BM25Transformer();
    }

    /**
     * Apply the transformer to the generated dataset.
     *
     * @Subject
     * @Iterations(3)
     * @OutputTimeUnit("seconds", precision=3)
     */
    public function apply() : void
    {
        $this->dataset->apply($this->transformer);
    }
}
215 changes: 215 additions & 0 deletions src/Transformers/BM25Transformer.php
Original file line number Diff line number Diff line change
@@ -0,0 +1,215 @@
<?php

namespace Rubix\ML\Transformers;

use Rubix\ML\DataType;
use Rubix\ML\Datasets\Dataset;
use Rubix\ML\Specifications\SamplesAreCompatibleWithTransformer;
use RuntimeException;
use Stringable;

use function is_null;

/**
 * BM25 Transformer
 *
 * A TF-IDF style weighting of word frequency vectors in which the term
 * frequency can additionally be dampened relative to the length of the
 * document. Each positive term frequency `tf` is replaced with
 * `tf * (tf / (tf + delta)) * idf` where `delta` is the term frequency decay
 * multiplied by the ratio of the document's length to the average document
 * length seen during fitting. With a decay of 0 the result is `tf * idf`.
 *
 * > **Note**: This transformer assumes that its input is made up of word
 * frequency vectors such as those created by the Word Count Vectorizer.
 *
 * References:
 * [1] S. Robertson. (2003). Understanding Inverse Document Frequency: On
 * theoretical arguments for IDF.
 *
 * @category    Machine Learning
 * @package     Rubix/ML
 * @author      Andrew DalPino
 */
class BM25Transformer implements Transformer, Stateful, Elastic, Stringable
{
    /**
     * The rate at which the TF values decay.
     *
     * @var float
     */
    protected $termFrequencyDecay;

    /**
     * The document frequencies of each word i.e. the number of documents in
     * which the word appeared at least once. Every count starts at 1 when
     * the transformer is fitted.
     *
     * @var int[]|null
     */
    protected $dfs;

    /**
     * The sum of the positive term frequencies fitted so far.
     *
     * @var int|float|null
     */
    protected $tokenCount;

    /**
     * The inverse document frequency values for each feature column.
     *
     * @var float[]|null
     */
    protected $idfs;

    /**
     * The number of documents (samples) that have been fitted so far. The
     * count starts at 1 when the transformer is fitted.
     *
     * @var int|null
     */
    protected $n;

    /**
     * The average token count per document.
     *
     * @var int|float|null
     */
    protected $averageDocumentLength;

    /**
     * @param float $termFrequencyDecay a non-negative decay rate, 0 disables the decay
     * @throws \InvalidArgumentException if the decay rate is negative
     */
    public function __construct(float $termFrequencyDecay = 0.0)
    {
        if ($termFrequencyDecay < 0.0) {
            // Fully qualified because the class is not imported into this
            // namespace and class names do not fall back to the global scope.
            throw new \InvalidArgumentException('Term frequency decay'
                . " must be greater than or equal to 0, $termFrequencyDecay given.");
        }

        $this->termFrequencyDecay = $termFrequencyDecay;
    }

    /**
     * Return the data types that this transformer is compatible with.
     *
     * @return \Rubix\ML\DataType[]
     */
    public function compatibility() : array
    {
        return [
            DataType::continuous(),
        ];
    }

    /**
     * Is the transformer fitted?
     *
     * @return bool
     */
    public function fitted() : bool
    {
        return isset($this->idfs);
    }

    /**
     * Return the document frequencies calculated during fitting.
     *
     * @return int[]|null
     */
    public function dfs() : ?array
    {
        return $this->dfs;
    }

    /**
     * Fit the transformer to a dataset. Any previously fitted state is
     * discarded before the dataset is folded in with update().
     *
     * @param \Rubix\ML\Datasets\Dataset $dataset
     */
    public function fit(Dataset $dataset) : void
    {
        // Both the document frequencies and the document count start at 1
        // so that the n / df ratio used for the IDFs is always defined.
        $this->dfs = array_fill(0, $dataset->numColumns(), 1);
        $this->tokenCount = 0;
        $this->n = 1;

        $this->update($dataset);
    }

    /**
     * Update the fitting of the transformer. Falls back to fit() when the
     * transformer has not been fitted yet.
     *
     * @param \Rubix\ML\Datasets\Dataset $dataset
     * @throws \InvalidArgumentException
     */
    public function update(Dataset $dataset) : void
    {
        SamplesAreCompatibleWithTransformer::check($dataset, $this);

        if (is_null($this->dfs) || is_null($this->n)) {
            $this->fit($dataset);

            return;
        }

        foreach ($dataset->samples() as $sample) {
            foreach ($sample as $column => $tf) {
                if ($tf > 0) {
                    ++$this->dfs[$column];

                    $this->tokenCount += $tf;
                }
            }
        }

        $this->n += $dataset->numRows();

        $idfs = [];

        foreach ($this->dfs as $df) {
            $idfs[] = 1.0 + log($this->n / $df);
        }

        $this->idfs = $idfs;

        $this->averageDocumentLength = $this->tokenCount / $this->n;
    }

    /**
     * Transform the dataset in place.
     *
     * @param array[] $samples
     * @throws \RuntimeException if the transformer has not been fitted
     */
    public function transform(array &$samples) : void
    {
        if (is_null($this->idfs) || is_null($this->averageDocumentLength)) {
            throw new RuntimeException('Transformer has not been fitted.');
        }

        // Length normalisation is skipped when the decay is disabled, and
        // also when the fitted corpus contained no tokens, since dividing by
        // an average document length of 0 would be an error.
        $normalize = $this->termFrequencyDecay > 0.0
            && $this->averageDocumentLength > 0;

        foreach ($samples as &$sample) {
            if ($normalize) {
                $delta = array_sum($sample) / $this->averageDocumentLength;

                $delta *= $this->termFrequencyDecay;
            } else {
                $delta = 0.0;
            }

            foreach ($sample as $column => &$tf) {
                if ($tf > 0) {
                    $tf *= $tf / ($tf + $delta);
                    $tf *= $this->idfs[$column];
                }
            }

            // Break the reference so it cannot alias the last element.
            unset($tf);
        }

        unset($sample);
    }

    /**
     * Return the string representation of the object.
     *
     * @return string
     */
    public function __toString() : string
    {
        return 'BM25 TF-IDF Transformer';
    }
}
91 changes: 91 additions & 0 deletions tests/Transformers/BM25TransformerTest.php
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
<?php

namespace Rubix\ML\Tests\Transformers;

use Rubix\ML\Datasets\Unlabeled;
use Rubix\ML\Transformers\Elastic;
use Rubix\ML\Transformers\Stateful;
use Rubix\ML\Transformers\Transformer;
use Rubix\ML\Transformers\BM25Transformer;
use PHPUnit\Framework\TestCase;
use RuntimeException;

/**
 * Unit tests for the BM25 Transformer.
 *
 * @group Transformers
 * @covers \Rubix\ML\Transformers\BM25Transformer
 */
class BM25TransformerTest extends TestCase
{
    /**
     * A small corpus of three word count vectors with 19 columns each.
     *
     * @var \Rubix\ML\Datasets\Unlabeled
     */
    protected $dataset;

    /**
     * The transformer under test, built with a term frequency decay of 0.
     *
     * @var \Rubix\ML\Transformers\BM25Transformer
     */
    protected $transformer;

    /**
     * Create a fresh dataset and transformer for each test.
     *
     * @before
     */
    protected function setUp() : void
    {
        $samples = [
            [1, 3, 0, 0, 1, 0, 0, 0, 1, 2, 0, 2, 0, 0, 0, 4, 1, 0, 1],
            [0, 1, 1, 0, 0, 2, 1, 0, 0, 0, 0, 3, 0, 1, 0, 0, 0, 0, 0],
            [0, 0, 0, 1, 2, 3, 0, 0, 4, 2, 0, 0, 1, 0, 2, 0, 1, 0, 0],
        ];

        $this->dataset = new Unlabeled($samples);

        $this->transformer = new BM25Transformer(0.0);
    }

    /**
     * The transformer implements every interface it declares.
     *
     * @test
     */
    public function build() : void
    {
        self::assertInstanceOf(BM25Transformer::class, $this->transformer);
        self::assertInstanceOf(Transformer::class, $this->transformer);
        self::assertInstanceOf(Stateful::class, $this->transformer);
        self::assertInstanceOf(Elastic::class, $this->transformer);
    }

    /**
     * Fitting exposes one integer document frequency per column and applying
     * the fitted transformer yields the expected weights.
     *
     * @test
     */
    public function fitTransform() : void
    {
        $this->transformer->fit($this->dataset);

        self::assertTrue($this->transformer->fitted());

        $documentFrequencies = $this->transformer->dfs();

        self::assertIsArray($documentFrequencies);
        self::assertCount(19, $documentFrequencies);
        self::assertContainsOnly('int', $documentFrequencies);

        $this->dataset->apply($this->transformer);

        $expected = [
            [1.6931471805599454, 3.8630462173553424, 0.0, 0.0, 1.2876820724517808, 0.0, 0.0, 0.0, 1.2876820724517808, 2.5753641449035616, 0.0, 2.5753641449035616, 0.0, 0.0, 0.0, 6.772588722239782, 1.2876820724517808, 0.0, 1.6931471805599454],
            [0.0, 1.2876820724517808, 1.6931471805599454, 0.0, 0.0, 2.5753641449035616, 1.6931471805599454, 0.0, 0.0, 0.0, 0.0, 3.8630462173553424, 0.0, 1.6931471805599454, 0.0, 0.0, 0.0, 0.0, 0.0],
            [0.0, 0.0, 0.0, 1.6931471805599454, 2.5753641449035616, 3.8630462173553424, 0.0, 0.0, 5.150728289807123, 2.5753641449035616, 0.0, 0.0, 1.6931471805599454, 0.0, 3.386294361119891, 0.0, 1.2876820724517808, 0.0, 0.0],
        ];

        self::assertEquals($expected, $this->dataset->samples());
    }

    /**
     * Transforming before fitting raises a runtime exception.
     *
     * @test
     */
    public function transformUnfitted() : void
    {
        $this->expectException(RuntimeException::class);

        $unfittedSamples = $this->dataset->samples();

        $this->transformer->transform($unfittedSamples);
    }
}

0 comments on commit 3a09ee7

Please sign in to comment.