-
Notifications
You must be signed in to change notification settings - Fork 26
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing 6 changed files with 46 additions and 41 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,6 @@ | ||
export { type DataConfig, DataLoader } from './data_loader' | ||
export type { DataConfig } from './data_loader' | ||
export { DataLoader } from './data_loader' | ||
|
||
export { ImageLoader } from './image_loader' | ||
export { TabularLoader } from './tabular_loader' | ||
export { TextLoader } from './text_loader' |
34 changes: 28 additions & 6 deletions
34
discojs/discojs-core/src/dataset/data_loader/text_loader.ts
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,12 +1,34 @@ | ||
import type { Dataset } from '..' | ||
import type { Task } from '../..' | ||
|
||
import { TabularLoader } from './tabular_loader' | ||
import type { DataSplit, Dataset } from '..' | ||
import { TextData } from '..' | ||
|
||
import { DataLoader } from '.' | ||
|
||
/** | ||
* Text data loader whose instantiable implementation is delegated by the platform-dependent Disco subprojects, namely, | ||
* @epfml/discojs-web and @epfml/discojs-node. Loads data from files whose entries are line-separated and each consist of | ||
* a sentence-like sample associated to an optional label. | ||
* @epfml/discojs-web and @epfml/discojs-node. | ||
*/ | ||
export abstract class TextLoader<Source> extends TabularLoader<Source> { | ||
abstract loadDatasetFrom (source: Source, config: Record<string, unknown>): Promise<Dataset> | ||
export abstract class TextLoader<S> extends DataLoader<S> { | ||
constructor ( | ||
private readonly task: Task | ||
) { | ||
super() | ||
} | ||
|
||
abstract loadDatasetFrom (source: S): Promise<Dataset> | ||
|
||
async load (source: S): Promise<Dataset> { | ||
return await this.loadDatasetFrom(source) | ||
} | ||
|
||
async loadAll (sources: S[]): Promise<DataSplit> { | ||
const concatenated = | ||
(await Promise.all(sources.map(async (src) => await this.load(src)))) | ||
.reduce((acc, dataset) => acc.concatenate(dataset)) | ||
|
||
return { | ||
train: await TextData.init(concatenated, this.task) | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,3 @@ | ||
export { ImageLoader as NodeImageLoader } from './image_loader' | ||
export { TabularLoader as NodeTabularLoader } from './tabular_loader' | ||
export { TextLoader as NodeTextLoader } from './text_loader' |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,30 +1,14 @@ | ||
// import fs from 'node:fs' | ||
import fs from 'node:fs/promises' | ||
import { data as tfData } from '@tensorflow/tfjs-node' | ||
|
||
// import split2 from 'split2' | ||
import { data } from '@epfml/discojs-core' | ||
|
||
// import { tf } from '../..' | ||
// import { TextLoader } from '../../core/dataset/data_loader/text_loader' | ||
// import { Dataset } from '../../core/dataset' | ||
// import { DataConfig } from '../../core/dataset/data_loader' | ||
export class TextLoader extends data.TextLoader<string> { | ||
async loadDatasetFrom (source: string): Promise<data.Dataset> { | ||
// TODO: this reads the whole dataset into memory, which will not scale; consider streaming the file instead | ||
const content = await fs.readFile(source) | ||
const file = new tfData.FileDataSource(content) | ||
|
||
// export class NodeTextLoader extends TextLoader<string> { | ||
// async loadDatasetFrom (source: string, config?: DataConfig): Promise<Dataset> { | ||
// const prefix = 'file://' | ||
// if (source.slice(0, 7) !== prefix) { | ||
// source = prefix + source | ||
// } | ||
// // create stream being read by generator | ||
// const stream = fs.createReadStream(source, { encoding: 'utf-8' }) | ||
// // eslint-disable-next-line @typescript-eslint/no-this-alias | ||
// const self = this | ||
|
||
// async function * dataGenerator (): AsyncGenerator<tf.TensorContainer> { | ||
// // TODO @s314cy | ||
// const withLabels = config?.labels !== undefined | ||
// stream.pipe(split2()) | ||
// stream.on('data', (data) => yield self.tokenize(data)) | ||
// } | ||
|
||
// return tf.data.generator(dataGenerator) | ||
// } | ||
// } | ||
return new tfData.TextLineDataset(file) | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters