Add an example and type enhancement for TextStreamer #1066

Merged · 8 commits · Dec 3, 2024
64 changes: 64 additions & 0 deletions docs/source/pipelines.md
@@ -148,6 +148,70 @@ Cheddar is my go-to for any occasion or mood;
It adds depth and richness without being overpowering its taste buds alone
```

### Streaming

Some pipelines, such as `text-generation` and `automatic-speech-recognition`, support streaming output. This is achieved using the `TextStreamer` class. For example, when using a chat model like `Qwen2.5-Coder-0.5B-Instruct`, you can provide a callback function that is called with each piece of generated text (if unset, new tokens are printed to the console).

```js
import { pipeline, TextStreamer } from "@huggingface/transformers";

// Create a text generation pipeline
const generator = await pipeline(
"text-generation",
"onnx-community/Qwen2.5-Coder-0.5B-Instruct",
{ dtype: "q4" },
);

// Define the list of messages
const messages = [
{ role: "system", content: "You are a helpful assistant." },
{ role: "user", content: "Write a quick sort algorithm." },
];

// Create text streamer
const streamer = new TextStreamer(generator.tokenizer, {
skip_prompt: true,
// Optionally, do something with the text (e.g., write to a textbox)
// callback_function: (text) => { /* Do something with text */ },
});

// Generate a response
const result = await generator(messages, { max_new_tokens: 512, do_sample: false, streamer });
```

Logging `result[0].generated_text` to the console gives:


<details>
<summary>Click to view the console output</summary>
<pre>
Here's a simple implementation of the quick sort algorithm in Python:
```python
def quick_sort(arr):
if len(arr) <= 1:
return arr
pivot = arr[len(arr) // 2]
left = [x for x in arr if x < pivot]
middle = [x for x in arr if x == pivot]
right = [x for x in arr if x > pivot]
return quick_sort(left) + middle + quick_sort(right)
# Example usage:
arr = [3, 6, 8, 10, 1, 2]
sorted_arr = quick_sort(arr)
print(sorted_arr)
```
### Explanation:
- **Base Case**: If the array has less than or equal to one element (i.e., `len(arr)` is less than or equal to `1`), it is already sorted and can be returned as is.
- **Pivot Selection**: The pivot is chosen as the middle element of the array.
- **Partitioning**: The array is partitioned into three parts: elements less than the pivot (`left`), elements equal to the pivot (`middle`), and elements greater than the pivot (`right`). These partitions are then recursively sorted.
- **Recursive Sorting**: The subarrays are sorted recursively using `quick_sort`.
This approach ensures that each recursive call reduces the problem size by half until it reaches a base case.
</pre>
</details>

This streaming feature lets you process the output as it is generated, rather than waiting for the entire sequence to finish before processing it.
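
For instance, instead of printing to the console, you can pass a `callback_function`, as hinted in the comment in the example above. Below is a minimal sketch that simply accumulates the streamed pieces into a string; the accumulation logic and variable names are illustrative, not part of the library:

```js
import { pipeline, TextStreamer } from "@huggingface/transformers";

const generator = await pipeline(
  "text-generation",
  "onnx-community/Qwen2.5-Coder-0.5B-Instruct",
  { dtype: "q4" },
);

const messages = [
  { role: "system", content: "You are a helpful assistant." },
  { role: "user", content: "Write a quick sort algorithm." },
];

// Collect each piece of streamed text instead of printing it
let streamed = "";
const streamer = new TextStreamer(generator.tokenizer, {
  skip_prompt: true,
  callback_function: (text) => {
    streamed += text; // e.g., update a textbox here instead
  },
});

await generator(messages, { max_new_tokens: 512, do_sample: false, streamer });
console.log(streamed); // Full text assembled from the streamed pieces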


For more information on the available options for each pipeline, refer to the [API Reference](./api/pipelines).
If you would like more control over the inference process, you can use the [`AutoModel`](./api/models), [`AutoTokenizer`](./api/tokenizers), or [`AutoProcessor`](./api/processors) classes instead.

7 changes: 7 additions & 0 deletions src/generation/configuration_utils.js
@@ -259,6 +259,13 @@ export class GenerationConfig {
*/
suppress_tokens = null;

/**
* A streamer that will be used to stream the generated text.
* @type {import('./streamers.js').TextStreamer}
* @default null
*/
streamer = null;

/**
* A list of tokens that will be suppressed at the beginning of the generation.
* The `SuppressBeginTokens` logit processor will set their log probs to `-inf` so that they are not sampled.
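
Since `streamer` is now a generation option, it can also be passed when calling `generate` directly on a model, outside the pipeline API. A sketch of that flow, assuming the usual `AutoTokenizer`/`AutoModelForCausalLM` loading pattern (the model name is reused from the docs example; for a chat model you would normally apply the chat template first):

```js
import { AutoModelForCausalLM, AutoTokenizer, TextStreamer } from "@huggingface/transformers";

const model_id = "onnx-community/Qwen2.5-Coder-0.5B-Instruct";
const tokenizer = await AutoTokenizer.from_pretrained(model_id);
const model = await AutoModelForCausalLM.from_pretrained(model_id, { dtype: "q4" });

// Tokenize a plain prompt
const inputs = tokenizer("Write a quick sort algorithm.");

// The streamer is picked up from the generation options, like any other generation parameter
const streamer = new TextStreamer(tokenizer, { skip_prompt: true });
const output = await model.generate({ ...inputs, max_new_tokens: 512, streamer });
```
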
9 changes: 7 additions & 2 deletions src/generation/streamers.js
@@ -34,7 +34,12 @@ const stdout_write = apis.IS_PROCESS_AVAILABLE
export class TextStreamer extends BaseStreamer {
/**
*
* @param {import('../tokenizers.js').PreTrainedTokenizer} tokenizer
* @param {Object} options
* @param {boolean} [options.skip_prompt=false] Whether to skip the prompt tokens
* @param {function(string): void} [options.callback_function=null] Function to call when a piece of text is ready to display
* @param {function(bigint[]): void} [options.token_callback_function=null] Function to call when a new token is generated
* @param {Object} [options.decode_kwargs={}] Additional keyword arguments to pass to the tokenizer's decode method
*/
constructor(tokenizer, {
skip_prompt = false,
@@ -143,7 +148,7 @@ export class WhisperTextStreamer extends TextStreamer
* @param {Object} options
* @param {boolean} [options.skip_prompt=false] Whether to skip the prompt tokens
* @param {function(string): void} [options.callback_function=null] Function to call when a piece of text is ready to display
* @param {function(string): void} [options.token_callback_function=null] Function to call when a new token is generated
* @param {function(bigint[]): void} [options.token_callback_function=null] Function to call when a new token is generated
* @param {function(number): void} [options.on_chunk_start=null] Function to call when a new chunk starts
* @param {function(number): void} [options.on_chunk_end=null] Function to call when a chunk ends
* @param {function(): void} [options.on_finalize=null] Function to call when the stream is finalized
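
Based on the options documented in the JSDoc above, a `WhisperTextStreamer` can be wired into an `automatic-speech-recognition` pipeline in the same way as `TextStreamer`. A hedged sketch, assuming the pipeline forwards the `streamer` option to `generate` as in the text-generation example; the model name and audio URL are placeholders:

```js
import { pipeline, WhisperTextStreamer } from "@huggingface/transformers";

// Create an automatic-speech-recognition pipeline
const transcriber = await pipeline(
  "automatic-speech-recognition",
  "onnx-community/whisper-tiny.en", // assumed model; substitute your own
);

// Stream partial transcriptions and chunk boundaries as they are produced
const streamer = new WhisperTextStreamer(transcriber.tokenizer, {
  on_chunk_start: (t) => console.log(`[chunk start @ ${t}]`),
  callback_function: (text) => console.log(text),
  on_chunk_end: (t) => console.log(`[chunk end @ ${t}]`),
  on_finalize: () => console.log("[done]"),
});

const output = await transcriber("https://example.com/audio.wav", {
  return_timestamps: true, // assumption: timestamps enabled so chunk callbacks fire
  streamer,
});
```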