Skip to content

Commit

Permalink
discojs-core/data_loader: add text
Browse files Browse the repository at this point in the history
  • Loading branch information
s314cy authored and tharvik committed Mar 18, 2024
1 parent 5b80f12 commit 72f1629
Show file tree
Hide file tree
Showing 6 changed files with 46 additions and 41 deletions.
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import type { Dataset, DataSplit } from '..'
import type { DataSplit, Dataset } from '..'

export interface DataConfig { features?: string[], labels?: string[], shuffle?: boolean, validationSplit?: number, inference?: boolean }

Expand Down
4 changes: 3 additions & 1 deletion discojs/discojs-core/src/dataset/data_loader/index.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
export { type DataConfig, DataLoader } from './data_loader'
export type { DataConfig } from './data_loader'
export { DataLoader } from './data_loader'

export { ImageLoader } from './image_loader'
export { TabularLoader } from './tabular_loader'
export { TextLoader } from './text_loader'
34 changes: 28 additions & 6 deletions discojs/discojs-core/src/dataset/data_loader/text_loader.ts
Original file line number Diff line number Diff line change
@@ -1,12 +1,34 @@
import type { Dataset } from '..'
import type { Task } from '../..'

import { TabularLoader } from './tabular_loader'
import type { DataSplit, Dataset } from '..'
import { TextData } from '..'

import { DataLoader } from '.'

/**
* Text data loader whose instantiable implementation is delegated by the platform-dependent Disco subprojects, namely,
* @epfml/discojs-web and @epfml/discojs-node. Loads data from files whose entries are line-separated and each consist of
* a sentence-like sample associated to an optional label.
* @epfml/discojs-web and @epfml/discojs-node.
*/
export abstract class TextLoader<Source> extends TabularLoader<Source> {
abstract loadDatasetFrom (source: Source, config: Record<string, unknown>): Promise<Dataset>
export abstract class TextLoader<S> extends DataLoader<S> {
constructor (
private readonly task: Task
) {
super()
}

abstract loadDatasetFrom (source: S): Promise<Dataset>

async load (source: S): Promise<Dataset> {
return await this.loadDatasetFrom(source)
}

async loadAll (sources: S[]): Promise<DataSplit> {
const concatenated =
(await Promise.all(sources.map(async (src) => await this.load(src))))
.reduce((acc, dataset) => acc.concatenate(dataset))

return {
train: await TextData.init(concatenated, this.task)
}
}
}
1 change: 1 addition & 0 deletions discojs/discojs-node/src/data/index.ts
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
export { ImageLoader as NodeImageLoader } from './image_loader'
export { TabularLoader as NodeTabularLoader } from './tabular_loader'
export { TextLoader as NodeTextLoader } from './text_loader'
38 changes: 11 additions & 27 deletions discojs/discojs-node/src/data/text_loader.ts
Original file line number Diff line number Diff line change
@@ -1,30 +1,14 @@
// import fs from 'node:fs'
import fs from 'node:fs/promises'
import { data as tfData } from '@tensorflow/tfjs-node'

// import split2 from 'split2'
import { data } from '@epfml/discojs-core'

// import { tf } from '../..'
// import { TextLoader } from '../../core/dataset/data_loader/text_loader'
// import { Dataset } from '../../core/dataset'
// import { DataConfig } from '../../core/dataset/data_loader'
export class TextLoader extends data.TextLoader<string> {
async loadDatasetFrom (source: string): Promise<data.Dataset> {
// TODO sure, good idea to load the whole dataset in memory #irony
const content = await fs.readFile(source)
const file = new tfData.FileDataSource(content)

// export class NodeTextLoader extends TextLoader<string> {
// async loadDatasetFrom (source: string, config?: DataConfig): Promise<Dataset> {
// const prefix = 'file://'
// if (source.slice(0, 7) !== prefix) {
// source = prefix + source
// }
// // create stream being read by generator
// const stream = fs.createReadStream(source, { encoding: 'utf-8' })
// // eslint-disable-next-line @typescript-eslint/no-this-alias
// const self = this

// async function * dataGenerator (): AsyncGenerator<tf.TensorContainer> {
// // TODO @s314cy
// const withLabels = config?.labels !== undefined
// stream.pipe(split2())
// stream.on('data', (data) => yield self.tokenize(data))
// }

// return tf.data.generator(dataGenerator)
// }
// }
return new tfData.TextLineDataset(file)
}
}
8 changes: 2 additions & 6 deletions discojs/discojs-web/src/data/text_loader.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,8 @@ import tf from '@tensorflow/tfjs'
import { data } from '@epfml/discojs-core'

export class TextLoader extends data.TextLoader<File> {
async loadDatasetFrom (source: File, config?: Record<string, unknown>): Promise<data.Dataset> {
async loadDatasetFrom (source: File): Promise<data.Dataset> {
const file = new tf.data.FileDataSource(source)
if (config !== undefined) {
return new tf.data.CSVDataset(file, config)
} else {
return new tf.data.TextLineDataset(file)
}
return new tf.data.TextLineDataset(file)
}
}

0 comments on commit 72f1629

Please sign in to comment.