-
-
Notifications
You must be signed in to change notification settings - Fork 6
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
b0d8817
commit 3a09ee7
Showing
4 changed files
with
357 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,4 @@ | ||
- Unreleased | ||
- Implement BM25 TF-IDF Transformer | ||
- Added Lambda function Transformer | ||
- All objects implement Stringable interface |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
<?php | ||
|
||
namespace Rubix\ML\Benchmarks\Transformers; | ||
|
||
use Tensor\Matrix; | ||
use Rubix\ML\Datasets\Unlabeled; | ||
use Rubix\ML\Transformers\BM25Transformer; | ||
|
||
/** | ||
* @Groups({"Transformers"}) | ||
* @BeforeMethods({"setUp"}) | ||
*/ | ||
class BM25TransformerBench | ||
{ | ||
protected const NUM_SAMPLES = 10000; | ||
|
||
/** | ||
* @var array[] | ||
*/ | ||
protected $aSamples; | ||
|
||
/** | ||
* @var array[] | ||
*/ | ||
protected $bSamples; | ||
|
||
public function setUp() : void | ||
{ | ||
$mask = Matrix::rand(self::NUM_SAMPLES, 100) | ||
->greater(0.8); | ||
|
||
$samples = Matrix::gaussian(self::NUM_SAMPLES, 100) | ||
->multiply($mask) | ||
->asArray(); | ||
|
||
$this->dataset = Unlabeled::quick($samples); | ||
|
||
$this->transformer = new BM25Transformer(); | ||
} | ||
|
||
/** | ||
* @Subject | ||
* @Iterations(3) | ||
* @OutputTimeUnit("seconds", precision=3) | ||
*/ | ||
public function apply() : void | ||
{ | ||
$this->dataset->apply($this->transformer); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,215 @@ | ||
<?php | ||
|
||
namespace Rubix\ML\Transformers; | ||
|
||
use Rubix\ML\DataType; | ||
use Rubix\ML\Datasets\Dataset; | ||
use Rubix\ML\Specifications\SamplesAreCompatibleWithTransformer; | ||
use RuntimeException; | ||
use Stringable; | ||
|
||
use function is_null; | ||
|
||
/** | ||
* TF-IDF Transformer | ||
* | ||
* Term Frequency - Inverse Document Frequency is a measure of how important | ||
* a word is to a document. The TF-IDF value increases proportionally with | ||
* the number of times a word appears in a document and is offset by the | ||
* frequency of the word in the corpus. | ||
* | ||
* > **Note**: This transformer assumes that its input is made up of word | ||
* frequency vectors such as those created by the Word Count Vectorizer. | ||
* | ||
* References: | ||
* [1] S. Robertson. (2003). Understanding Inverse Document Frequency: On | ||
* theoretical arguments for IDF. | ||
* | ||
* @category Machine Learning | ||
* @package Rubix/ML | ||
* @author Andrew DalPino | ||
*/ | ||
class BM25Transformer implements Transformer, Stateful, Elastic, Stringable | ||
{ | ||
/** | ||
* The rate at which the TF values decay. | ||
* | ||
* @var float | ||
*/ | ||
protected $termFrequencyDecay; | ||
|
||
/** | ||
* The document frequencies of each word i.e. the number of times a word | ||
* appeared in a document given the entire corpus. | ||
* | ||
* @var int[]|null | ||
*/ | ||
protected $dfs; | ||
|
||
/** | ||
* The number of tokens fitted so far. | ||
* | ||
* @var int|null | ||
*/ | ||
protected $tokenCount; | ||
|
||
/** | ||
* The inverse document frequency values for each feature column. | ||
* | ||
* @var float[]|null | ||
*/ | ||
protected $idfs; | ||
|
||
/** | ||
* The number of documents (samples) that have been fitted so far. | ||
* | ||
* @var int|null | ||
*/ | ||
protected $n; | ||
|
||
/** | ||
* The average token count per document. | ||
* | ||
* @var float|null | ||
*/ | ||
protected $averageDocumentLength; | ||
|
||
/** | ||
* @param int $termFrequencyDecay | ||
* @throws \InvalidArgumentException | ||
*/ | ||
public function __construct(float $termFrequencyDecay = 0.0) | ||
{ | ||
if ($termFrequencyDecay < 0.0) { | ||
throw new InvalidArgumentException('Term frequency decay' | ||
. " must be greater than 0, $termFrequencyDecay given."); | ||
} | ||
|
||
$this->termFrequencyDecay = $termFrequencyDecay; | ||
} | ||
|
||
/** | ||
* Return the data types that this transformer is compatible with. | ||
* | ||
* @return \Rubix\ML\DataType[] | ||
*/ | ||
public function compatibility() : array | ||
{ | ||
return [ | ||
DataType::continuous(), | ||
]; | ||
} | ||
|
||
/** | ||
* Is the transformer fitted? | ||
* | ||
* @return bool | ||
*/ | ||
public function fitted() : bool | ||
{ | ||
return isset($this->idfs); | ||
} | ||
|
||
/** | ||
* Return the document frequencies calculated during fitting. | ||
* | ||
* @return int[]|null | ||
*/ | ||
public function dfs() : ?array | ||
{ | ||
return $this->dfs; | ||
} | ||
|
||
/** | ||
* Fit the transformer to a dataset. | ||
* | ||
* @param \Rubix\ML\Datasets\Dataset $dataset | ||
*/ | ||
public function fit(Dataset $dataset) : void | ||
{ | ||
$this->dfs = array_fill(0, $dataset->numColumns(), 1); | ||
$this->tokenCount = 0; | ||
$this->n = 1; | ||
|
||
$this->update($dataset); | ||
} | ||
|
||
/** | ||
* Update the fitting of the transformer. | ||
* | ||
* @param \Rubix\ML\Datasets\Dataset $dataset | ||
* @throws \InvalidArgumentException | ||
*/ | ||
public function update(Dataset $dataset) : void | ||
{ | ||
SamplesAreCompatibleWithTransformer::check($dataset, $this); | ||
|
||
if (is_null($this->dfs) or is_null($this->n)) { | ||
$this->fit($dataset); | ||
|
||
return; | ||
} | ||
|
||
foreach ($dataset->samples() as $sample) { | ||
foreach ($sample as $column => $tf) { | ||
if ($tf > 0) { | ||
++$this->dfs[$column]; | ||
|
||
$this->tokenCount += $tf; | ||
} | ||
} | ||
} | ||
|
||
$this->n += $dataset->numRows(); | ||
|
||
$idfs = []; | ||
|
||
foreach ($this->dfs as $df) { | ||
$idfs[] = 1.0 + log($this->n / $df); | ||
} | ||
|
||
$this->idfs = $idfs; | ||
|
||
$this->averageDocumentLength = $this->tokenCount / $this->n; | ||
} | ||
|
||
/** | ||
* Transform the dataset in place. | ||
* | ||
* @param array[] $samples | ||
* @throws \RuntimeException | ||
*/ | ||
public function transform(array &$samples) : void | ||
{ | ||
if (is_null($this->idfs) or is_null($this->averageDocumentLength)) { | ||
throw new RuntimeException('Transformer has not been fitted.'); | ||
} | ||
|
||
foreach ($samples as &$sample) { | ||
if ($this->termFrequencyDecay > 0.0) { | ||
$delta = array_sum($sample) / $this->averageDocumentLength; | ||
|
||
$delta *= $this->termFrequencyDecay; | ||
} else { | ||
$delta = 0.0; | ||
} | ||
|
||
foreach ($sample as $column => &$tf) { | ||
if ($tf > 0) { | ||
$tf *= $tf / ($tf + $delta); | ||
$tf *= $this->idfs[$column]; | ||
} | ||
} | ||
} | ||
} | ||
|
||
/** | ||
* Return the string representation of the object. | ||
* | ||
* @return string | ||
*/ | ||
public function __toString() : string | ||
{ | ||
return 'BM25 TF-IDF Transformer'; | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,91 @@ | ||
<?php | ||
|
||
namespace Rubix\ML\Tests\Transformers; | ||
|
||
use Rubix\ML\Datasets\Unlabeled; | ||
use Rubix\ML\Transformers\Elastic; | ||
use Rubix\ML\Transformers\Stateful; | ||
use Rubix\ML\Transformers\Transformer; | ||
use Rubix\ML\Transformers\BM25Transformer; | ||
use PHPUnit\Framework\TestCase; | ||
use RuntimeException; | ||
|
||
/** | ||
* @group Transformers | ||
* @covers \Rubix\ML\Transformers\BM25Transformer | ||
*/ | ||
class BM25TransformerTest extends TestCase | ||
{ | ||
/** | ||
* @var \Rubix\ML\Datasets\Unlabeled | ||
*/ | ||
protected $dataset; | ||
|
||
/** | ||
* @var \Rubix\ML\Transformers\BM25Transformer | ||
*/ | ||
protected $transformer; | ||
|
||
/** | ||
* @before | ||
*/ | ||
protected function setUp() : void | ||
{ | ||
$this->dataset = new Unlabeled([ | ||
[1, 3, 0, 0, 1, 0, 0, 0, 1, 2, 0, 2, 0, 0, 0, 4, 1, 0, 1], | ||
[0, 1, 1, 0, 0, 2, 1, 0, 0, 0, 0, 3, 0, 1, 0, 0, 0, 0, 0], | ||
[0, 0, 0, 1, 2, 3, 0, 0, 4, 2, 0, 0, 1, 0, 2, 0, 1, 0, 0], | ||
]); | ||
|
||
$this->transformer = new BM25Transformer(0.0); | ||
} | ||
|
||
/** | ||
* @test | ||
*/ | ||
public function build() : void | ||
{ | ||
$this->assertInstanceOf(BM25Transformer::class, $this->transformer); | ||
$this->assertInstanceOf(Transformer::class, $this->transformer); | ||
$this->assertInstanceOf(Stateful::class, $this->transformer); | ||
$this->assertInstanceOf(Elastic::class, $this->transformer); | ||
} | ||
|
||
/** | ||
* @test | ||
*/ | ||
public function fitTransform() : void | ||
{ | ||
$this->transformer->fit($this->dataset); | ||
|
||
$this->assertTrue($this->transformer->fitted()); | ||
|
||
$dfs = $this->transformer->dfs(); | ||
|
||
$this->assertIsArray($dfs); | ||
$this->assertCount(19, $dfs); | ||
$this->assertContainsOnly('int', $dfs); | ||
|
||
$this->dataset->apply($this->transformer); | ||
|
||
$outcome = [ | ||
[1.6931471805599454, 3.8630462173553424, 0.0, 0.0, 1.2876820724517808, 0.0, 0.0, 0.0, 1.2876820724517808, 2.5753641449035616, 0.0, 2.5753641449035616, 0.0, 0.0, 0.0, 6.772588722239782, 1.2876820724517808, 0.0, 1.6931471805599454], | ||
[0.0, 1.2876820724517808, 1.6931471805599454, 0.0, 0.0, 2.5753641449035616, 1.6931471805599454, 0.0, 0.0, 0.0, 0.0, 3.8630462173553424, 0.0, 1.6931471805599454, 0.0, 0.0, 0.0, 0.0, 0.0], | ||
[0.0, 0.0, 0.0, 1.6931471805599454, 2.5753641449035616, 3.8630462173553424, 0.0, 0.0, 5.150728289807123, 2.5753641449035616, 0.0, 0.0, 1.6931471805599454, 0.0, 3.386294361119891, 0.0, 1.2876820724517808, 0.0, 0.0], | ||
]; | ||
|
||
$this->assertEquals($outcome, $this->dataset->samples()); | ||
} | ||
|
||
/** | ||
* @test | ||
*/ | ||
public function transformUnfitted() : void | ||
{ | ||
$this->expectException(RuntimeException::class); | ||
|
||
$samples = $this->dataset->samples(); | ||
|
||
$this->transformer->transform($samples); | ||
} | ||
} |