feat: add support for text data & tokenization
s314cy committed May 4, 2023
1 parent 80d20a7 commit e8c307f
Showing 16 changed files with 256 additions and 105 deletions.
21 changes: 19 additions & 2 deletions discojs/discojs-core/src/dataset/data/data.ts
@@ -1,4 +1,4 @@
-import { Task } from '../..'
+import { tf, Task } from '../..'
import { Dataset } from '../dataset'

export abstract class Data {
@@ -17,5 +17,22 @@ export abstract class Data {

abstract batch (): Data

-  abstract preprocess (): Promise<Data>
+  get batchedDataset (): Dataset {
+    const batchSize = this.task.trainingInformation.batchSize
+    return batchSize === undefined
+      ? this.dataset
+      : this.dataset.batch(batchSize)
+  }
+
+  async preprocess (): Promise<Data> {
+    return this
+  }
+
+  get preprocessing (): (entry: tf.TensorContainer) => tf.TensorContainer {
+    return (x: tf.TensorContainer) => x
+  }
+
+  get preprocessedDataset (): Dataset {
+    return this.dataset.map(this.preprocessing)
+  }
}
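
For readers skimming the diff, the two getters added above are thin wrappers around tf.data primitives. A minimal standalone sketch in plain tfjs (none of the Disco types, names chosen purely for illustration) of what preprocessedDataset and batchedDataset amount to:

import * as tf from '@tensorflow/tfjs'

// Map a per-entry preprocessing function over the dataset, then batch it,
// mirroring the preprocessedDataset and batchedDataset getters above.
const dataset = tf.data.array([1, 2, 3, 4])
const preprocessing = (x: tf.TensorContainer): tf.TensorContainer => (x as number) * 2
const preprocessed = dataset.map(preprocessing)
const batched = preprocessed.batch(2)
void batched.forEachAsync((batch) => console.log(batch)) // two batches of doubled values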
26 changes: 18 additions & 8 deletions discojs/discojs-core/src/dataset/data/image_data.ts
@@ -1,7 +1,7 @@
import { tf, Task } from '../..'
-import { getPreprocessImage, ImagePreprocessing } from './preprocessing'
import { Dataset } from '../dataset'
import { Data } from './data'
+import { ImagePreprocessing, IMAGE_PREPROCESSING } from './preprocessing'

export class ImageData extends Data {
static async init (
@@ -43,16 +43,26 @@ export class ImageData extends Data {
}

batch (): Data {
-    const batchSize = this.task.trainingInformation.batchSize
-    const newDataset = batchSize === undefined ? this.dataset : this.dataset.batch(batchSize)
+    return new ImageData(this.batchedDataset, this.task, this.size)
}

+  get preprocessing (): (entry: tf.TensorContainer) => tf.TensorContainer {
+    const params = this.task.trainingInformation
+    const taskFunctions = params.preprocessingFunctions
+
+    if (taskFunctions === undefined) {
+      return (x) => x
+    }
+
+    const applyPreprocessing = IMAGE_PREPROCESSING
+      .filter((e) => e.type in taskFunctions)
+      .map((e) => e.apply)
+
-    return new ImageData(newDataset, this.task, this.size)
+    return applyPreprocessing.reduce((acc: (x: tf.TensorContainer) => tf.TensorContainer, fn) =>
+      (x: tf.TensorContainer) => fn(acc(x)))
}

async preprocess (): Promise<Data> {
-    let newDataset = this.dataset
-    const preprocessImage = getPreprocessImage(this.task)
-    newDataset = newDataset.map((x: tf.TensorContainer) => preprocessImage(x))
-    return new ImageData(newDataset, this.task, this.size)
+    return new ImageData(this.preprocessedDataset, this.task, this.size)
}
}
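
The preprocessing getter above folds the selected preprocessing steps into a single function. A self-contained sketch of the same reduce pattern with plain functions, just to make the application order explicit (steps run in list order, first to last):

// Compose [f, g] into x => g(f(x)), exactly like the reduce in `preprocessing` above.
type Step = (x: number) => number

const steps: Step[] = [(x) => x + 1, (x) => x * 2]
const composed = steps.reduce((acc: Step, fn) => (x) => fn(acc(x)))
console.log(composed(3)) // (3 + 1) * 2 = 8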
77 changes: 0 additions & 77 deletions discojs/discojs-core/src/dataset/data/preprocessing.ts

This file was deleted.

@@ -0,0 +1,41 @@
import { tf } from '../../..'

import { List } from 'immutable'

export enum ImagePreprocessing {
  Resize,
  Normalize
}

interface ImageEntry extends tf.TensorContainerObject {
  xs: tf.Tensor3D | tf.Tensor4D
  ys: tf.Tensor1D | number | undefined
}

const resize = {
  type: ImagePreprocessing.Resize,
  apply: (entry: tf.TensorContainer): tf.TensorContainer => {
    const { xs, ys } = entry as ImageEntry
    return {
      xs,
      ys
    }
  }
}

const normalize = {
  type: ImagePreprocessing.Normalize,
  apply: (entry: tf.TensorContainer): tf.TensorContainer => {
    const { xs, ys } = entry as ImageEntry
    return {
      xs,
      ys
    }
  }
}

// Add your preprocessing here
export const AVAILABLE_PREPROCESSING = List.of(
  resize,
  normalize
).sortBy((e) => e.type)
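
Both apply functions above currently return the entry untouched. As an illustration only (this is not part of the commit), a filled-in normalize step could scale pixel values into [0, 1] while keeping the same { type, apply } shape:

// Hypothetical normalization step: divide uint8 pixel values by 255.
const normalizeExample = {
  type: ImagePreprocessing.Normalize,
  apply: (entry: tf.TensorContainer): tf.TensorContainer => {
    const { xs, ys } = entry as ImageEntry
    return {
      xs: xs.div(tf.scalar(255)),
      ys
    }
  }
}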
@@ -0,0 +1 @@
export { AVAILABLE_PREPROCESSING as IMAGE_PREPROCESSING, ImagePreprocessing } from './image_preprocessing'
@@ -0,0 +1 @@
export enum TextPreprocessing {}
13 changes: 1 addition & 12 deletions discojs/discojs-core/src/dataset/data/tabular_data.ts
@@ -1,5 +1,4 @@
import { Task } from '../..'
-import { getPreprocessTabular } from './preprocessing'
import { Dataset } from '../dataset'
import { Data } from './data'

@@ -23,16 +22,6 @@ export class TabularData extends Data {
}

batch (): Data {
-    const batchSize = this.task.trainingInformation.batchSize
-    const newDataset = batchSize === undefined ? this.dataset : this.dataset.batch(batchSize)
-
-    return new TabularData(newDataset, this.task, this.size)
-  }
-
-  async preprocess (): Promise<Data> {
-    let newDataset = this.dataset
-    const preprocessTabular = getPreprocessTabular(this.task)
-    newDataset = await preprocessTabular(newDataset)
-    return new TabularData(newDataset, this.task, this.size)
+    return new TabularData(this.batchedDataset, this.task, this.size)
}
}
7 changes: 7 additions & 0 deletions discojs/discojs-core/src/dataset/data/text_data.ts
@@ -0,0 +1,7 @@
import { Data } from './data'

export class TextData extends Data {
  batch (): TextData {
    return new TextData(this.batchedDataset, this.task, this.size)
  }
}
@@ -9,11 +9,8 @@ import { DataLoader, DataConfig } from '../data_loader'
const BUFFER_SIZE = 1000

export abstract class TabularLoader<Source> extends DataLoader<Source> {
-  private readonly delimiter: string
-
-  constructor (task: Task, delimiter: string) {
+  constructor (task: Task, public readonly delimiter = ',') {
     super(task)
-    this.delimiter = delimiter
}

/**
41 changes: 41 additions & 0 deletions discojs/discojs-core/src/dataset/data_loader/text_loader.ts
@@ -0,0 +1,41 @@
import { List } from 'immutable'

import { Task } from '../..'
import { DataLoader, DataConfig } from './data_loader'
import { Dataset } from '../dataset'
import { DataSplit } from '../data/data_split'
import { TextData } from '../data/text_data'

const BUFFER_SIZE = 50

export abstract class TextLoader<Source> extends DataLoader<Source> {
  constructor (
    task: Task,
    public readonly delimiter = ','
  ) {
    super(task)
  }

  abstract loadTextDatasetFrom (source: Source): Promise<Dataset>

  async load (source: Source, config?: DataConfig): Promise<Dataset> {
    const dataset = await this.loadTextDatasetFrom(source)
    return config?.shuffle ? dataset.shuffle(BUFFER_SIZE) : dataset
  }

  async loadAll (sources: Source[], config: DataConfig): Promise<DataSplit> {
    const datasets = await Promise.all(sources.map(async (source) =>
      await this.load(source, { ...config, shuffle: false })))
    let dataset = List(datasets).reduce((acc: Dataset, dataset) =>
      acc.concatenate(dataset))
    dataset = config?.shuffle ? dataset.shuffle(BUFFER_SIZE) : dataset
    const data = await TextData.init(
      dataset,
      this.task,
      undefined
    )
    return {
      train: data
    }
  }
}
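
To make the TextLoader contract concrete, here is a hypothetical minimal subclass (illustration only, not part of the commit) that treats each source string as a newline-separated corpus, so that loadAll can concatenate the pieces and wrap them in TextData as above. The import paths mirror the file's own and are assumptions.

import { tf } from '../..'
import { Dataset } from '../dataset'
import { TextLoader } from './text_loader'

// Hypothetical in-memory loader: one source = one newline-separated corpus.
export class MemoryTextLoader extends TextLoader<string> {
  async loadTextDatasetFrom (source: string): Promise<Dataset> {
    return tf.data.array(source.split('\n'))
  }
}

// Usage sketch, assuming `task` is a valid text Task:
// const split = await new MemoryTextLoader(task).loadAll(['hello world\nfoo bar'], { shuffle: true })
// split.train is then a TextData instance over the concatenated lines.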
@@ -1,6 +1,6 @@
import fs from 'fs'

-import { tf, data } from '../..'
+import { tf, data } from 'core'

export class NodeImageLoader extends data.ImageLoader<string> {
async readImageFrom (source: string): Promise<tf.Tensor3D> {
@@ -1,4 +1,4 @@
-import { tf, data } from '../..'
+import { tf, data } from 'core'

export class NodeTabularLoader extends data.TabularLoader<string> {
loadTabularDatasetFrom (source: string, csvConfig: Record<string, unknown>): tf.data.CSVDataset {
30 changes: 30 additions & 0 deletions discojs/discojs-node/src/dataset/data_loader/text_loader.ts
@@ -0,0 +1,30 @@
import fs from 'node:fs'

import split2 from 'split2'

import { tf } from '../..'
import { TextLoader } from 'core/dataset/data_loader/text_loader'
import { Dataset } from 'core/dataset'
import { DataConfig } from 'core/dataset/data_loader'

export class NodeTextLoader extends TextLoader<string> {
  async loadTextDatasetFrom (source: string, config?: DataConfig): Promise<Dataset> {
    const prefix = 'file://'
    if (source.slice(0, 7) !== prefix) {
      source = prefix + source
    }
    // create stream being read by generator
    const stream = fs.createReadStream(source, { encoding: 'utf-8' })
    // eslint-disable-next-line @typescript-eslint/no-this-alias
    const self = this

    async function * dataGenerator (): AsyncGenerator<tf.TensorContainer> {
      // TODO @s314cy: handle labelled sources
      const withLabels = config?.labels !== undefined
      // split the stream into lines and tokenize each one; yielding from a
      // stream event callback is invalid, so iterate the piped stream instead
      for await (const line of stream.pipe(split2())) {
        yield self.tokenize(line)
      }
    }

    return tf.data.generator(dataGenerator)
  }
}
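
A possible usage sketch on Node: it assumes a valid task and that TextLoader provides the tokenize method the generator above calls, neither of which appears in this diff. The corpus path is a placeholder and the import paths mirror the surrounding files.

import { Task } from 'core'
import { NodeTextLoader } from './text_loader'

// Usage sketch: stream a local text file into a tf.data.Dataset of tokenized lines.
async function loadCorpus (task: Task): Promise<void> {
  const loader = new NodeTextLoader(task)
  const dataset = await loader.load('path/to/corpus.txt', { shuffle: true })
  await dataset.take(3).forEachAsync((sample) => console.log(sample))
}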
28 changes: 28 additions & 0 deletions discojs/discojs-web/src/dataset/data_loader/text_loader.ts
@@ -0,0 +1,28 @@
import { DataConfig } from '@/core/dataset/data_loader'
import { tf } from '../..'
import { Dataset } from 'core/dataset'
import { TextLoader } from 'core/dataset/data_loader/text_loader'

export class WebTextLoader extends TextLoader<File> {
  async loadTextDatasetFrom (source: File, config?: DataConfig): Promise<Dataset> {
    const labels = config?.labels
    if (labels !== undefined) {
      return new tf.data.CSVDataset(new tf.data.FileDataSource(source), {
        delimiter: this.delimiter,
        columnNames: ['sample', 'label'],
        columnConfigs: {
          sample: {
            required: true,
            isLabel: false
          },
          label: {
            required: true,
            isLabel: true
          }
        }
      })
    } else {
      return new tf.data.TextLineDataset(new tf.data.FileDataSource(source))
    }
  }
}
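
A browser-side usage sketch: a plain text file becomes a TextLineDataset with one line per entry, while passing labels switches to the CSV branch with sample/label columns. It assumes `file` comes from an <input type="file"> element and `task` is a valid text Task; the import paths mirror the surrounding files and are assumptions.

import { Task } from '@/core'
import { WebTextLoader } from './text_loader'

// Usage sketch: load the same File with and without labels.
async function loadFromInput (task: Task, file: File): Promise<void> {
  const loader = new WebTextLoader(task)

  const lines = await loader.load(file)
  await lines.take(3).forEachAsync((line) => console.log(line))

  const labelled = await loader.load(file, { labels: ['label'] })
  await labelled.take(3).forEachAsync((row) => console.log(row))
}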
