-
Notifications
You must be signed in to change notification settings - Fork 25
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: add support for text data & tokenization
- Loading branch information
Showing
16 changed files
with
256 additions
and
105 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
41 changes: 41 additions & 0 deletions
41
discojs/discojs-core/src/dataset/data/preprocessing/image_preprocessing.ts
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
import { tf } from '../../..' | ||
|
||
import { List } from 'immutable' | ||
|
||
export enum ImagePreprocessing { | ||
Resize, | ||
Normalize | ||
} | ||
|
||
interface ImageEntry extends tf.TensorContainerObject { | ||
xs: tf.Tensor3D | tf.Tensor4D | ||
ys: tf.Tensor1D | number | undefined | ||
} | ||
|
||
const resize = { | ||
type: ImagePreprocessing.Resize, | ||
apply: (entry: tf.TensorContainer): tf.TensorContainer => { | ||
const { xs, ys } = entry as ImageEntry | ||
return { | ||
xs, | ||
ys | ||
} | ||
} | ||
} | ||
|
||
const normalize = { | ||
type: ImagePreprocessing.Normalize, | ||
apply: (entry: tf.TensorContainer): tf.TensorContainer => { | ||
const { xs, ys } = entry as ImageEntry | ||
return { | ||
xs, | ||
ys | ||
} | ||
} | ||
} | ||
|
||
// Add your preprocessing here | ||
export const AVAILABLE_PREPROCESSING = List.of( | ||
resize, | ||
normalize | ||
).sortBy((e) => e.type) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
export { AVAILABLE_PREPROCESSING as IMAGE_PREPROCESSING, ImagePreprocessing } from './image_preprocessing' |
1 change: 1 addition & 0 deletions
1
discojs/discojs-core/src/dataset/data/preprocessing/text_preprocessing.ts
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
export enum TextPreprocessing {} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
import { Data } from './data' | ||
|
||
export class TextData extends Data { | ||
batch (): TextData { | ||
return new TextData(this.batchedDataset, this.task, this.size) | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
41 changes: 41 additions & 0 deletions
41
discojs/discojs-core/src/dataset/data_loader/text_loader.ts
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
import { List } from 'immutable' | ||
|
||
import { Task } from '../..' | ||
import { DataLoader, DataConfig } from './data_loader' | ||
import { Dataset } from '../dataset' | ||
import { DataSplit } from '../data/data_split' | ||
import { TextData } from '../data/text_data' | ||
|
||
const BUFFER_SIZE = 50 | ||
|
||
export abstract class TextLoader<Source> extends DataLoader<Source> { | ||
constructor ( | ||
task: Task, | ||
public readonly delimiter = ',' | ||
) { | ||
super(task) | ||
} | ||
|
||
abstract loadTextDatasetFrom (source: Source): Promise<Dataset> | ||
|
||
async load (source: Source, config?: DataConfig): Promise<Dataset> { | ||
const dataset = await this.loadTextDatasetFrom(source) | ||
return config?.shuffle ? dataset.shuffle(BUFFER_SIZE) : dataset | ||
} | ||
|
||
async loadAll (sources: Source[], config: DataConfig): Promise<DataSplit> { | ||
const datasets = await Promise.all(sources.map(async (source) => | ||
await this.load(source, { ...config, shuffle: false }))) | ||
let dataset = List(datasets).reduce((acc: Dataset, dataset) => | ||
acc.concatenate(dataset)) | ||
dataset = config?.shuffle ? dataset.shuffle(BUFFER_SIZE) : dataset | ||
const data = await TextData.init( | ||
dataset, | ||
this.task, | ||
undefined | ||
) | ||
return { | ||
train: data | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
30 changes: 30 additions & 0 deletions
30
discojs/discojs-node/src/dataset/data_loader/text_loader.ts
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
import fs from 'node:fs' | ||
|
||
import split2 from 'split2' | ||
|
||
import { tf } from '../..' | ||
import { TextLoader } from 'core/dataset/data_loader/text_loader' | ||
import { Dataset } from 'core/dataset' | ||
import { DataConfig } from 'core/dataset/data_loader' | ||
|
||
export class NodeTextLoader extends TextLoader<string> { | ||
async loadTextDatasetFrom (source: string, config?: DataConfig): Promise<Dataset> { | ||
const prefix = 'file://' | ||
if (source.slice(0, 7) !== prefix) { | ||
source = prefix + source | ||
} | ||
// create stream being read by generator | ||
const stream = fs.createReadStream(source, { encoding: 'utf-8' }) | ||
// eslint-disable-next-line @typescript-eslint/no-this-alias | ||
const self = this | ||
|
||
async function * dataGenerator (): AsyncGenerator<tf.TensorContainer> { | ||
// TODO @s314cy | ||
const withLabels = config?.labels !== undefined | ||
stream.pipe(split2()) | ||
stream.on('data', (data) => yield self.tokenize(data)) | ||
} | ||
|
||
return tf.data.generator(dataGenerator) | ||
} | ||
} |
28 changes: 28 additions & 0 deletions
28
discojs/discojs-web/src/dataset/data_loader/text_loader.ts
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
import { DataConfig } from '@/core/dataset/data_loader' | ||
import { tf } from '../..' | ||
import { Dataset } from 'core/dataset' | ||
import { TextLoader } from 'core/dataset/data_loader/text_loader' | ||
|
||
export class WebTextLoader extends TextLoader<File> { | ||
async loadTextDatasetFrom (source: File, config?: DataConfig): Promise<Dataset> { | ||
const labels = config?.labels | ||
if (labels !== undefined) { | ||
return new tf.data.CSVDataset(new tf.data.FileDataSource(source), { | ||
delimiter: this.delimiter, | ||
columnNames: ['sample', 'label'], | ||
columnConfigs: { | ||
sample: { | ||
required: true, | ||
isLabel: false | ||
}, | ||
label: { | ||
required: true, | ||
isLabel: true | ||
} | ||
} | ||
}) | ||
} else { | ||
return new tf.data.TextLineDataset(new tf.data.FileDataSource(source)) | ||
} | ||
} | ||
} |
Oops, something went wrong.