From 823fc175f4298dbd91c15a50a234db74738e0398 Mon Sep 17 00:00:00 2001
From: Tarek Mahmoud Sayed <10833894+tarekgh@users.noreply.github.com>
Date: Fri, 11 Oct 2024 16:06:22 -0700
Subject: [PATCH] Misc Changes (#7264)
* Add o1 model support
* Replace Usage of tuples with Range in EncodedToken and Remove TorchSharp Range/Index implementation
* Rename SentencePieceBpeTokenizer to allow adding more models to it in the future.
* Make Tokenizer.Decode return a non-nullable string
* Make BPE tokenizer support added tokens
* add net9 package source to the nuget.config file
* Rename TiktokenPreTokenizer to RegexPreTokenizer
---
NuGet.config | 4 +
.../Microsoft.ML.AutoML.Samples.csproj | 1 +
eng/Versions.props | 1 +
.../Microsoft.ML.AutoML.Interactive.csproj | 3 +-
.../Microsoft.ML.GenAI.Core.csproj | 4 +
.../Pipeline/CausalLMPipeline.cs | 6 +-
.../LlamaTokenizerHelper.cs | 2 +-
.../Microsoft.ML.GenAI.LLaMA.csproj | 4 +
.../Microsoft.ML.GenAI.Mistral.csproj | 4 +
.../Microsoft.ML.GenAI.Phi.csproj | 6 +-
src/Microsoft.ML.Tokenizers/EncodedToken.cs | 6 +-
.../Microsoft.ML.Tokenizers.csproj | 1 +
.../Model/BPETokenizer.cs | 68 ++++++--
.../Model/CodeGenTokenizer.cs | 40 ++---
.../Model/EnglishRobertaTokenizer.cs | 20 +--
.../Model/LlamaTokenizer.cs | 6 +-
.../Model/Phi2Tokenizer.cs | 2 +-
...Tokenizer.cs => SentencePieceTokenizer.cs} | 32 ++--
.../Model/TiktokenTokenizer.cs | 20 ++-
src/Microsoft.ML.Tokenizers/Model/Word.cs | 2 +-
.../PreTokenizer/PreTokenizer.cs | 28 +++-
...enPreTokenizer.cs => RegexPreTokenizer.cs} | 6 +-
.../PreTokenizer/WhiteSpacePreTokenizer.cs | 61 --------
src/Microsoft.ML.Tokenizers/Tokenizer.cs | 6 +-
.../AutoFormerV2/Anchors.cs | 12 +-
.../AutoFormerV2/Attention.cs | 2 +-
.../AutoFormerV2/AutoFormerV2Block.cs | 2 +-
.../AutoFormerV2/ObjectDetectionTrainer.cs | 14 +-
src/Microsoft.ML.TorchSharp/Loss/FocalLoss.cs | 40 ++---
.../Microsoft.ML.TorchSharp.csproj | 7 +-
.../Utils/ImageUtils.cs | 46 +++---
src/Microsoft.ML.TorchSharp/Utils/Index.cs | 145 ------------------
src/Microsoft.ML.TorchSharp/Utils/Range.cs | 141 -----------------
.../Utils/RangeUtil.cs | 19 +++
.../Microsoft.ML.AutoML.Tests.csproj | 4 +
.../Microsoft.ML.CodeGenerator.Tests.csproj | 4 +
.../Microsoft.ML.Fairlearn.Tests.csproj | 4 +
.../Microsoft.ML.GenAI.Core.Tests.csproj | 1 +
.../Microsoft.ML.GenAI.LLaMA.Tests.csproj | 4 +
.../Microsoft.ML.GenAI.Mistral.Tests.csproj | 4 +
.../Microsoft.ML.GenAI.Phi.Tests.csproj | 1 +
.../Microsoft.ML.Tokenizers.Data.Tests.csproj | 4 +
.../Microsoft.ML.Tokenizers.Tests/BpeTests.cs | 66 +++++++-
.../CodeGenTests.cs | 84 +++++-----
.../EnglishRobertaTests.cs | 6 +-
.../LlamaTests.cs | 14 +-
.../Microsoft.ML.Tokenizers.Tests.csproj | 4 +
.../PreTokenizerTests.cs | 6 +-
.../TiktokenTests.cs | 21 +--
.../TokenizerTests.cs | 6 +-
.../Microsoft.ML.TorchSharp.Tests.csproj | 4 +
51 files changed, 433 insertions(+), 565 deletions(-)
rename src/Microsoft.ML.Tokenizers/Model/{SentencePieceBpeTokenizer.cs => SentencePieceTokenizer.cs} (98%)
rename src/Microsoft.ML.Tokenizers/PreTokenizer/{TiktokenPreTokenizer.cs => RegexPreTokenizer.cs} (95%)
delete mode 100644 src/Microsoft.ML.Tokenizers/PreTokenizer/WhiteSpacePreTokenizer.cs
delete mode 100644 src/Microsoft.ML.TorchSharp/Utils/Index.cs
delete mode 100644 src/Microsoft.ML.TorchSharp/Utils/Range.cs
create mode 100644 src/Microsoft.ML.TorchSharp/Utils/RangeUtil.cs
diff --git a/NuGet.config b/NuGet.config
index 5f023aa721..c60a5b8571 100644
--- a/NuGet.config
+++ b/NuGet.config
@@ -15,6 +15,7 @@
+
@@ -47,6 +48,9 @@
+
+
+
diff --git a/docs/samples/Microsoft.ML.AutoML.Samples/Microsoft.ML.AutoML.Samples.csproj b/docs/samples/Microsoft.ML.AutoML.Samples/Microsoft.ML.AutoML.Samples.csproj
index 628cbe5293..464a2cedd7 100644
--- a/docs/samples/Microsoft.ML.AutoML.Samples/Microsoft.ML.AutoML.Samples.csproj
+++ b/docs/samples/Microsoft.ML.AutoML.Samples/Microsoft.ML.AutoML.Samples.csproj
@@ -8,6 +8,7 @@
None
+ true
diff --git a/eng/Versions.props b/eng/Versions.props
index 12eda87457..48c8bb2e1c 100644
--- a/eng/Versions.props
+++ b/eng/Versions.props
@@ -41,6 +41,7 @@
3.27.1
3.3.5
1.1.1
+ 9.0.0-rc.1.24431.7
3.3.4
4.9.2
1.0.0-beta.24375.2
diff --git a/src/Microsoft.ML.AutoML.Interactive/Microsoft.ML.AutoML.Interactive.csproj b/src/Microsoft.ML.AutoML.Interactive/Microsoft.ML.AutoML.Interactive.csproj
index c391b0a00b..2ae1ca8467 100644
--- a/src/Microsoft.ML.AutoML.Interactive/Microsoft.ML.AutoML.Interactive.csproj
+++ b/src/Microsoft.ML.AutoML.Interactive/Microsoft.ML.AutoML.Interactive.csproj
@@ -4,9 +4,10 @@
net6.0
false
$(NoWarn)
-
+
None
+ true
diff --git a/src/Microsoft.ML.GenAI.Core/Microsoft.ML.GenAI.Core.csproj b/src/Microsoft.ML.GenAI.Core/Microsoft.ML.GenAI.Core.csproj
index 0486831b27..59cc59edc7 100644
--- a/src/Microsoft.ML.GenAI.Core/Microsoft.ML.GenAI.Core.csproj
+++ b/src/Microsoft.ML.GenAI.Core/Microsoft.ML.GenAI.Core.csproj
@@ -7,6 +7,10 @@
preview
+
+ true
+
+
diff --git a/src/Microsoft.ML.GenAI.Core/Pipeline/CausalLMPipeline.cs b/src/Microsoft.ML.GenAI.Core/Pipeline/CausalLMPipeline.cs
index c368378337..13c598b4ec 100644
--- a/src/Microsoft.ML.GenAI.Core/Pipeline/CausalLMPipeline.cs
+++ b/src/Microsoft.ML.GenAI.Core/Pipeline/CausalLMPipeline.cs
@@ -255,7 +255,7 @@ public virtual IEnumerable GenerateStreaming(
return tokens
// Skip the first _ token automatically added by tokenizer
- .Where(t => t.Offset != (0, 0))
+ .Where(t => !t.Offset.Equals(new Range(0, 0)))
.Select(t => t.Id)
.ToArray();
}));
@@ -268,13 +268,13 @@ public virtual IEnumerable GenerateStreaming(
var tokenIds = token[0].to_type(ScalarType.Int32).data().ToArray();
var duplicateTokenString = this.Tokenizer switch
{
- SentencePieceBpeTokenizer bpeTokenizer => bpeTokenizer.Decode(tokenIds.Concat(tokenIds), considerSpecialTokens: true) ?? throw new InvalidOperationException("Failed to decode token ids"),
+ SentencePieceTokenizer bpeTokenizer => bpeTokenizer.Decode(tokenIds.Concat(tokenIds), considerSpecialTokens: true) ?? throw new InvalidOperationException("Failed to decode token ids"),
_ => this.Tokenizer.Decode(tokenIds.Concat(tokenIds)) ?? throw new InvalidOperationException("Failed to decode token ids"),
};
var tokenString = this.Tokenizer switch
{
- SentencePieceBpeTokenizer bpeTokenizer => bpeTokenizer.Decode(tokenIds, considerSpecialTokens: true) ?? throw new InvalidOperationException("Failed to decode token ids"),
+ SentencePieceTokenizer bpeTokenizer => bpeTokenizer.Decode(tokenIds, considerSpecialTokens: true) ?? throw new InvalidOperationException("Failed to decode token ids"),
_ => this.Tokenizer.Decode(tokenIds) ?? throw new InvalidOperationException("Failed to decode token ids"),
};
diff --git a/src/Microsoft.ML.GenAI.LLaMA/LlamaTokenizerHelper.cs b/src/Microsoft.ML.GenAI.LLaMA/LlamaTokenizerHelper.cs
index ea6f49edf7..489acb6524 100644
--- a/src/Microsoft.ML.GenAI.LLaMA/LlamaTokenizerHelper.cs
+++ b/src/Microsoft.ML.GenAI.LLaMA/LlamaTokenizerHelper.cs
@@ -49,7 +49,7 @@ public static TiktokenTokenizer FromPretrained(
string modelFile = "tokenizer.model")
{
var modelFilePath = Path.Join(modelWeightFolder, modelFile);
- var preTokenizer = new TiktokenPreTokenizer(new Regex(_re), _specialTokens);
+ var preTokenizer = new RegexPreTokenizer(new Regex(_re), _specialTokens);
return TiktokenTokenizer.Create(File.OpenRead(modelFilePath), preTokenizer, normalizer: null, specialTokens: _specialTokens);
}
}
diff --git a/src/Microsoft.ML.GenAI.LLaMA/Microsoft.ML.GenAI.LLaMA.csproj b/src/Microsoft.ML.GenAI.LLaMA/Microsoft.ML.GenAI.LLaMA.csproj
index 9fd5d267ac..81b334564e 100644
--- a/src/Microsoft.ML.GenAI.LLaMA/Microsoft.ML.GenAI.LLaMA.csproj
+++ b/src/Microsoft.ML.GenAI.LLaMA/Microsoft.ML.GenAI.LLaMA.csproj
@@ -7,6 +7,10 @@
true
+
+ true
+
+
diff --git a/src/Microsoft.ML.GenAI.Mistral/Microsoft.ML.GenAI.Mistral.csproj b/src/Microsoft.ML.GenAI.Mistral/Microsoft.ML.GenAI.Mistral.csproj
index 6dbf9f1aa5..4d0a2fb4b1 100644
--- a/src/Microsoft.ML.GenAI.Mistral/Microsoft.ML.GenAI.Mistral.csproj
+++ b/src/Microsoft.ML.GenAI.Mistral/Microsoft.ML.GenAI.Mistral.csproj
@@ -7,6 +7,10 @@
true
+
+ true
+
+
diff --git a/src/Microsoft.ML.GenAI.Phi/Microsoft.ML.GenAI.Phi.csproj b/src/Microsoft.ML.GenAI.Phi/Microsoft.ML.GenAI.Phi.csproj
index b614d2f73a..0e2f8021a2 100644
--- a/src/Microsoft.ML.GenAI.Phi/Microsoft.ML.GenAI.Phi.csproj
+++ b/src/Microsoft.ML.GenAI.Phi/Microsoft.ML.GenAI.Phi.csproj
@@ -7,6 +7,10 @@
true
+
+ true
+
+
@@ -23,5 +27,5 @@
-
+
diff --git a/src/Microsoft.ML.Tokenizers/EncodedToken.cs b/src/Microsoft.ML.Tokenizers/EncodedToken.cs
index 06a00c9126..e6f3411b14 100644
--- a/src/Microsoft.ML.Tokenizers/EncodedToken.cs
+++ b/src/Microsoft.ML.Tokenizers/EncodedToken.cs
@@ -2,6 +2,8 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
+using System;
+
namespace Microsoft.ML.Tokenizers
{
///
@@ -23,7 +25,7 @@ public readonly struct EncodedToken
///
/// Gets the offset mapping to the original string.
///
- public (int Index, int Length) Offset { get; }
+ public Range Offset { get; }
///
/// Construct a new Token object using the token value, Id, and the offset mapping to the original string.
@@ -31,7 +33,7 @@ public readonly struct EncodedToken
/// The Id value associated to the token.
/// The token string value.
/// The offset mapping to the original string.
- public EncodedToken(int id, string value, (int, int) offset)
+ public EncodedToken(int id, string value, Range offset)
{
Id = id;
Offset = offset;
diff --git a/src/Microsoft.ML.Tokenizers/Microsoft.ML.Tokenizers.csproj b/src/Microsoft.ML.Tokenizers/Microsoft.ML.Tokenizers.csproj
index 93a6cbb644..56686641b6 100644
--- a/src/Microsoft.ML.Tokenizers/Microsoft.ML.Tokenizers.csproj
+++ b/src/Microsoft.ML.Tokenizers/Microsoft.ML.Tokenizers.csproj
@@ -23,6 +23,7 @@
+
diff --git a/src/Microsoft.ML.Tokenizers/Model/BPETokenizer.cs b/src/Microsoft.ML.Tokenizers/Model/BPETokenizer.cs
index d85464ba39..6b6ec7a234 100644
--- a/src/Microsoft.ML.Tokenizers/Model/BPETokenizer.cs
+++ b/src/Microsoft.ML.Tokenizers/Model/BPETokenizer.cs
@@ -29,6 +29,13 @@ public sealed class BpeTokenizer : Tokenizer
private int? _unknownTokenId;
private readonly PreTokenizer? _preTokenizer;
private readonly Normalizer? _normalizer;
+ private readonly Dictionary? _addedTokens;
+ private readonly Dictionary? _addedTokensReverse;
+
+ ///
+ /// Gets the added tokens.
+ ///
+ public IReadOnlyDictionary? AddedTokens { get; }
///
/// Gets or Sets unknown token. The unknown token to be used when we encounter an unknown char
@@ -80,7 +87,7 @@ private set
/// The JSON file path containing the dictionary of string keys and their ids.
/// The file path containing the tokens's pairs list.
public static BpeTokenizer Create(string vocabFile, string? mergesFile)
- => Create(vocabFile, mergesFile, preTokenizer: WhiteSpacePreTokenizer.Instance, normalizer: null, unknownToken: null, continuingSubwordPrefix: null, endOfWordSuffix: null, fuseUnknownTokens: false);
+ => Create(vocabFile, mergesFile, preTokenizer: PreTokenizer.CreateWhiteSpace(), normalizer: null, unknownToken: null, continuingSubwordPrefix: null, endOfWordSuffix: null, fuseUnknownTokens: false);
///
/// Create a new Bpe tokenizer object to use for text encoding.
@@ -89,6 +96,7 @@ public static BpeTokenizer Create(string vocabFile, string? mergesFile)
/// The file path containing the tokens's pairs list.
/// The pre-tokenizer to use.
/// The normalizer to use.
+ /// The additional tokens to add to the vocabulary.
/// The unknown token to be used by the model.
/// The prefix to attach to sub-word units that don’t represent a beginning of word.
/// The suffix to attach to sub-word units that represent an end of word.
@@ -98,6 +106,7 @@ public static BpeTokenizer Create(
string? mergesFile,
PreTokenizer? preTokenizer = null,
Normalizer? normalizer = null,
+ IReadOnlyDictionary? addedTokens = null,
string? unknownToken = null,
string? continuingSubwordPrefix = null,
string? endOfWordSuffix = null,
@@ -113,7 +122,7 @@ public static BpeTokenizer Create(
(Dictionary? vocab, Vec<(string, string)> merges) result = ReadModelDataAsync(vocabStream, mergesStream, useAsync: false).GetAwaiter().GetResult();
- return new BpeTokenizer(result.vocab, result.merges, preTokenizer, normalizer, unknownToken, continuingSubwordPrefix, endOfWordSuffix, fuseUnknownTokens);
+ return new BpeTokenizer(result.vocab, result.merges, preTokenizer, normalizer, addedTokens, unknownToken, continuingSubwordPrefix, endOfWordSuffix, fuseUnknownTokens);
}
///
@@ -122,7 +131,7 @@ public static BpeTokenizer Create(
/// The JSON stream containing the dictionary of string keys and their ids.
/// The stream containing the tokens's pairs list.
public static BpeTokenizer Create(Stream vocabStream, Stream? mergesStream)
- => Create(vocabStream, mergesStream, preTokenizer: WhiteSpacePreTokenizer.Instance, normalizer: null, unknownToken: null, continuingSubwordPrefix: null, endOfWordSuffix: null, fuseUnknownTokens: false);
+ => Create(vocabStream, mergesStream, preTokenizer: PreTokenizer.CreateWhiteSpace(), normalizer: null, addedTokens: null, unknownToken: null, continuingSubwordPrefix: null, endOfWordSuffix: null, fuseUnknownTokens: false);
///
/// Create a new Bpe tokenizer object to use for text encoding.
@@ -131,6 +140,7 @@ public static BpeTokenizer Create(Stream vocabStream, Stream? mergesStream)
/// The stream containing the tokens's pairs list.
/// The pre-tokenizer to use.
/// The normalizer to use.
+ /// The additional tokens to add to the vocabulary.
/// The unknown token to be used by the model.
/// The prefix to attach to sub-word units that don’t represent a beginning of word.
/// The suffix to attach to sub-word units that represent an end of word.
@@ -140,6 +150,7 @@ public static BpeTokenizer Create(
Stream? mergesStream,
PreTokenizer? preTokenizer = null,
Normalizer? normalizer = null,
+ IReadOnlyDictionary? addedTokens = null,
string? unknownToken = null,
string? continuingSubwordPrefix = null,
string? endOfWordSuffix = null,
@@ -152,7 +163,7 @@ public static BpeTokenizer Create(
(Dictionary? vocab, Vec<(string, string)> merges) result = ReadModelDataAsync(vocabStream, mergesStream, useAsync: false).GetAwaiter().GetResult();
- return new BpeTokenizer(result.vocab, result.merges, preTokenizer, normalizer, unknownToken, continuingSubwordPrefix, endOfWordSuffix, fuseUnknownTokens);
+ return new BpeTokenizer(result.vocab, result.merges, preTokenizer, normalizer, addedTokens, unknownToken, continuingSubwordPrefix, endOfWordSuffix, fuseUnknownTokens);
}
///
@@ -162,6 +173,7 @@ public static BpeTokenizer Create(
/// The stream containing the tokens's pairs list.
/// The pre-tokenizer to use.
/// The normalizer to use.
+ /// The additional tokens to add to the vocabulary.
/// The unknown token to be used by the model.
/// The prefix to attach to sub-word units that don’t represent a beginning of word.
/// The suffix to attach to sub-word units that represent an end of word.
@@ -171,6 +183,7 @@ public static async Task CreateAsync(
Stream? mergesStream,
PreTokenizer? preTokenizer = null,
Normalizer? normalizer = null,
+ IReadOnlyDictionary? addedTokens = null,
string? unknownToken = null,
string? continuingSubwordPrefix = null,
string? endOfWordSuffix = null,
@@ -183,7 +196,7 @@ public static async Task CreateAsync(
(Dictionary? vocab, Vec<(string, string)> merges) result = await ReadModelDataAsync(vocabStream, mergesStream, useAsync: true).ConfigureAwait(false);
- return new BpeTokenizer(result.vocab, result.merges, preTokenizer, normalizer, unknownToken, continuingSubwordPrefix, endOfWordSuffix, fuseUnknownTokens);
+ return new BpeTokenizer(result.vocab, result.merges, preTokenizer, normalizer, addedTokens, unknownToken, continuingSubwordPrefix, endOfWordSuffix, fuseUnknownTokens);
}
///
@@ -193,16 +206,26 @@ public static async Task CreateAsync(
/// The pairs list help in merging tokens during the encoding process.
/// The pre-tokenizer to use.
/// The normalizer to use.
+ /// The additional tokens to add to the vocabulary.
/// The unknown token to be used by the model.
/// The prefix to attach to sub-word units that don’t represent a beginning of word.
/// The suffix to attach to sub-word units that represent an end of word.
/// Indicate whether allowing multiple unknown tokens get fused.
- private BpeTokenizer(Dictionary? vocab, Vec<(string, string)> merges, PreTokenizer? preTokenizer, Normalizer? normalizer, string? unknownToken, string? continuingSubwordPrefix, string? endOfWordSuffix, bool fuseUnknownTokens)
+ private BpeTokenizer(
+ Dictionary? vocab,
+ Vec<(string, string)> merges,
+ PreTokenizer? preTokenizer,
+ Normalizer? normalizer,
+ IReadOnlyDictionary? addedTokens,
+ string? unknownToken,
+ string? continuingSubwordPrefix,
+ string? endOfWordSuffix,
+ bool fuseUnknownTokens)
{
FuseUnknownTokens = fuseUnknownTokens;
ContinuingSubwordPrefix = continuingSubwordPrefix;
EndOfWordSuffix = endOfWordSuffix;
- _preTokenizer = preTokenizer ?? WhiteSpacePreTokenizer.Instance; // Default to WhiteSpace pre-tokenizer
+ _preTokenizer = preTokenizer ?? PreTokenizer.CreateWhiteSpace(); // Default to WhiteSpace pre-tokenizer
_normalizer = normalizer;
_vocab = vocab ?? new Dictionary();
@@ -215,6 +238,13 @@ private BpeTokenizer(Dictionary? vocab, Vec<(string,
VocabReverse.Add(kvp.Value, kvp.Key.Data!);
}
+ if (addedTokens is not null)
+ {
+ AddedTokens = addedTokens;
+ _addedTokens = addedTokens.ToDictionary(kvp => new StringSpanOrdinalKey(kvp.Key), kvp => (kvp.Value, kvp.Key));
+ _addedTokensReverse = addedTokens.ToDictionary(kvp => kvp.Value, kvp => kvp.Key);
+ }
+
UnknownToken = unknownToken;
int prefixLen = ContinuingSubwordPrefix is null ? 0 : ContinuingSubwordPrefix.Length;
@@ -568,7 +598,7 @@ private int LastIndexOf(string? text, ReadOnlySpan textSpan, int maxTokenC
///
/// The list of ids that we want to decode.
/// The decoded string.
- public override string? Decode(IEnumerable ids) => Decode(ids, considerSpecialTokens: true);
+ public override string Decode(IEnumerable ids) => Decode(ids, considerSpecialTokens: true);
///
/// Decode the given ids, back to a String.
@@ -576,7 +606,7 @@ private int LastIndexOf(string? text, ReadOnlySpan textSpan, int maxTokenC
/// The list of ids that we want to decode.
/// Indicate whether to consider special tokens or not.
/// The decoded string.
- public string? Decode(IEnumerable ids, bool considerSpecialTokens)
+ public string Decode(IEnumerable ids, bool considerSpecialTokens)
{
if (ids is null)
{
@@ -936,6 +966,12 @@ internal Word MergeWord(ReadOnlySpan w, ref PriorityQueue? priority
internal void EncodeWithCache(ReadOnlySpan text, List tokens, int offset, ref PriorityQueue? priorityQueue)
{
+ if (_addedTokens?.TryGetValue(text, out (int addedTokenId, string addedToken) value) is true)
+ {
+ tokens.Add(new EncodedToken(value.addedTokenId, value.addedToken, new Range(offset, offset + text.Length)));
+ return;
+ }
+
Word word;
if (Cache is not null)
{
@@ -1004,6 +1040,13 @@ internal int WordToIdsFromEnd(ref Word word, IList? accumulatedIds, out int
private int EncodeToIdsWithCache(ReadOnlySpan text, List? accumulatedIds, int maxTokens, out int charsConsumed, ref PriorityQueue? priorityQueue)
{
+ if (_addedTokens?.TryGetValue(text, out (int addedTokenId, string addedToken) value) is true && maxTokens > 0)
+ {
+ accumulatedIds?.Add(value.addedTokenId);
+ charsConsumed = text.Length;
+ return 1;
+ }
+
Word word;
if (Cache is not null)
@@ -1032,6 +1075,13 @@ internal int EncodeToIdsFromEndWithCache(ReadOnlySpan text, IList? ac
{
Word word;
+ if (_addedTokens?.TryGetValue(text, out (int addedTokenId, string addedToken) value) is true && maxTokens > 0)
+ {
+ accumulatedIds?.Add(value.addedTokenId);
+ textIndex = 0;
+ return 1;
+ }
+
if (Cache is not null)
{
if (Cache.TryGetValue(text, out Word hit))
diff --git a/src/Microsoft.ML.Tokenizers/Model/CodeGenTokenizer.cs b/src/Microsoft.ML.Tokenizers/Model/CodeGenTokenizer.cs
index fbfbba7f7e..c1fd6bb1ca 100644
--- a/src/Microsoft.ML.Tokenizers/Model/CodeGenTokenizer.cs
+++ b/src/Microsoft.ML.Tokenizers/Model/CodeGenTokenizer.cs
@@ -376,7 +376,7 @@ private EncodeResults EncodeToTokens(string? text, scoped ReadOnly
List tokens = new();
if (addBos && BeginningOfSentenceId.HasValue)
{
- tokens.Add(new EncodedToken(BeginningOfSentenceId.Value, BeginningOfSentenceToken!, (0, 0)));
+ tokens.Add(new EncodedToken(BeginningOfSentenceId.Value, BeginningOfSentenceToken!, new Range(0, 0)));
}
PriorityQueue agenda = new(textSpanToEncode.Length);
@@ -395,7 +395,8 @@ private EncodeResults EncodeToTokens(string? text, scoped ReadOnly
if (addEos && EndOfSentenceId.HasValue)
{
- tokens.Add(new EncodedToken(EndOfSentenceId.Value, EndOfSentenceToken!, (addPrefixSpace ? Math.Max(0, textSpanToEncode.Length - 1) : textSpanToEncode.Length, 0)));
+ int index = addPrefixSpace ? Math.Max(0, textSpanToEncode.Length - 1) : textSpanToEncode.Length;
+ tokens.Add(new EncodedToken(EndOfSentenceId.Value, EndOfSentenceToken!, new Range(index, index)));
}
return new EncodeResults { Tokens = tokens, NormalizedText = normalizedString, CharsConsumed = textSpanToEncode.Length };
@@ -427,7 +428,8 @@ private void EncodeInternal(string? text, scoped ReadOnlySpan textSpan, Li
if (_addedTokens is not null && _addedTokens.TryGetValue(textSpan, out (int addedTokenId, string addedToken) value))
{
- tokens.Add(new EncodedToken(value.addedTokenId, value.addedToken, ((addPrefixSpace && offset > 0) ? offset - 1 : offset, (addPrefixSpace && offset == 0) ? textSpan.Length - 1 : textSpan.Length)));
+ int index = (addPrefixSpace && offset > 0) ? offset - 1 : offset;
+ tokens.Add(new EncodedToken(value.addedTokenId, value.addedToken, new Range(index, index + ((addPrefixSpace && offset == 0) ? textSpan.Length - 1 : textSpan.Length))));
return;
}
@@ -1027,11 +1029,11 @@ private int EncodeToIdsResult(List tokens, IList? accumulated
for (tokenCount = 0; tokenCount < maxTokens; tokenCount++)
{
// maxTokens is less than tokens.Count, so it is safe to index maxTokens.
- if (tokens[tokenCount].Offset.Index == tokens[tokenCount + 1].Offset.Index)
+ if (tokens[tokenCount].Offset.Start.Value == tokens[tokenCount + 1].Offset.Start.Value)
{
// Ensure we'll not break the text in the middle of a code-point
int j = tokenCount + 2;
- while (j < tokens.Count && tokens[j].Offset.Index == tokens[tokenCount].Offset.Index)
+ while (j < tokens.Count && tokens[j].Offset.Start.Value == tokens[tokenCount].Offset.Start.Value)
{
j++;
}
@@ -1042,7 +1044,7 @@ private int EncodeToIdsResult(List tokens, IList? accumulated
for (int k = tokenCount; k < j; k++)
{
accumulatedIds?.Add(tokens[k].Id);
- charsConsumed += tokens[k].Offset.Length;
+ charsConsumed += tokens[k].Offset.End.Value - tokens[k].Offset.Start.Value;
}
tokenCount = j - 1;
}
@@ -1054,7 +1056,7 @@ private int EncodeToIdsResult(List tokens, IList? accumulated
else
{
accumulatedIds?.Add(tokens[tokenCount].Id);
- charsConsumed += tokens[tokenCount].Offset.Length;
+ charsConsumed += tokens[tokenCount].Offset.End.Value - tokens[tokenCount].Offset.Start.Value;
}
}
@@ -1082,7 +1084,7 @@ private int EncodeToIdsFromEndResult(List tokens, IList? accu
int index = tokens.Count - maxTokens;
// avoid breaking the text in the middle of a code-point
- while (index < tokens.Count && tokens[index].Offset.Index == tokens[index - 1].Offset.Index)
+ while (index < tokens.Count && tokens[index].Offset.Start.Value == tokens[index - 1].Offset.Start.Value)
{
index++;
}
@@ -1090,7 +1092,7 @@ private int EncodeToIdsFromEndResult(List tokens, IList? accu
for (int i = index; i < tokens.Count; i++)
{
accumulatedIds?.Add(tokens[i].Id);
- textIndex -= tokens[i].Offset.Length;
+ textIndex -= tokens[i].Offset.End.Value - tokens[i].Offset.Start.Value;
}
return tokens.Count - index;
@@ -1229,7 +1231,7 @@ private int EncodeToIdsFromEndInternal(string? text, scoped ReadOnlySpan t
///
/// The list of ids that we want to decode.
/// The decoded string.
- public override string? Decode(IEnumerable ids) => Decode(ids, hasPrefixSpace: AddPrefixSpace, considerSpecialTokens: false);
+ public override string Decode(IEnumerable ids) => Decode(ids, hasPrefixSpace: AddPrefixSpace, considerSpecialTokens: false);
///
/// Decode the given ids, back to a String.
@@ -1238,7 +1240,7 @@ private int EncodeToIdsFromEndInternal(string? text, scoped ReadOnlySpan t
/// Indicate whether the encoded string has a leading space.
/// Indicate whether to consider special tokens during decoding.
/// The decoded string.
- public string? Decode(IEnumerable ids, bool hasPrefixSpace, bool considerSpecialTokens)
+ public string Decode(IEnumerable ids, bool hasPrefixSpace, bool considerSpecialTokens)
{
if (ids is null)
{
@@ -1590,11 +1592,12 @@ private static void AppendTokenWithOffsetAdjusting(IReadOnlyList t
{
if (tokensToAdd.Count > 0)
{
- tokens.Add(new EncodedToken(tokensToAdd[0].Id, tokensToAdd[0].Value, (offset == 0 ? tokensToAdd[0].Offset.Index : tokensToAdd[0].Offset.Index + offset - 1, offset == 0 ? tokensToAdd[0].Offset.Length - 1 : tokensToAdd[0].Offset.Length)));
+ (int s, int e) r = offset == 0 ? (tokensToAdd[0].Offset.Start.Value, tokensToAdd[0].Offset.End.Value - 1) : (tokensToAdd[0].Offset.Start.Value + offset - 1, tokensToAdd[0].Offset.End.Value + offset - 1);
+ tokens.Add(new EncodedToken(tokensToAdd[0].Id, tokensToAdd[0].Value, new Range(r.s, r.e)));
for (int i = 1; i < tokensToAdd.Count; i++)
{
- tokens.Add(new EncodedToken(tokensToAdd[i].Id, tokensToAdd[i].Value, (tokensToAdd[i].Offset.Index + offset - 1, tokensToAdd[i].Offset.Length)));
+ tokens.Add(new EncodedToken(tokensToAdd[i].Id, tokensToAdd[i].Value, new Range(tokensToAdd[i].Offset.Start.Value + offset - 1, tokensToAdd[i].Offset.End.Value + offset - 1)));
}
}
}
@@ -1602,7 +1605,7 @@ private static void AppendTokenWithOffsetAdjusting(IReadOnlyList t
{
foreach (EncodedToken t in tokensToAdd)
{
- tokens.Add(new EncodedToken(t.Id, t.Value, (t.Offset.Index + offset, t.Offset.Length)));
+ tokens.Add(new EncodedToken(t.Id, t.Value, new Range(t.Offset.Start.Value + offset, t.Offset.End.Value + offset)));
}
}
}
@@ -1622,7 +1625,7 @@ private List EncodeToTokens(Span text, Span mapping, Re
char c = text[0];
string[] charToString = ByteToUnicodeEncoding.Instance.CharToString;
string tokenValue = (uint)c < charToString.Length ? charToString[c] : c.ToString();
- return new List { new EncodedToken(_vocab[new StringSpanOrdinalKey(tokenValue)].Id, tokenValue, (mapping[0], 1)) };
+ return new List { new EncodedToken(_vocab[new StringSpanOrdinalKey(tokenValue)].Id, tokenValue, new Range(mapping[0], mapping[0] + 1)) };
}
BpeSymbol[] symbols = ArrayPool.Shared.Rent(text.Length);
@@ -1694,9 +1697,8 @@ private List EncodeToTokens(Span text, Span mapping, Re
static EncodedToken GetToken(int id, string token, int index, int length, ReadOnlySpan originalText, Span mapping)
{
- int tokenStartIndex = mapping[index];
- int tokenLength = (index + length < mapping.Length ? mapping[index + length] - tokenStartIndex : originalText.Length - tokenStartIndex);
- return new EncodedToken(id, token, (tokenStartIndex, tokenLength));
+ int endIndex = index + length < mapping.Length ? mapping[index + length] : originalText.Length;
+ return new EncodedToken(id, token, new Range(mapping[index], endIndex));
}
void TryMerge(int left, int right, ReadOnlySpan textSpan)
@@ -1892,7 +1894,7 @@ public static CodeGenTokenizer Create(
return new CodeGenTokenizer(
vocabStream,
mergesStream,
- new TiktokenPreTokenizer(TiktokenTokenizer.P50kBaseRegex(), CodeGenTokenizer.CodeGenAddedTokens),
+ new RegexPreTokenizer(TiktokenTokenizer.P50kBaseRegex(), CodeGenTokenizer.CodeGenAddedTokens),
normalizer: null,
CodeGenTokenizer.CodeGenAddedTokens,
addPrefixSpace: addPrefixSpace,
diff --git a/src/Microsoft.ML.Tokenizers/Model/EnglishRobertaTokenizer.cs b/src/Microsoft.ML.Tokenizers/Model/EnglishRobertaTokenizer.cs
index e1cc47e13f..85f921ff0f 100644
--- a/src/Microsoft.ML.Tokenizers/Model/EnglishRobertaTokenizer.cs
+++ b/src/Microsoft.ML.Tokenizers/Model/EnglishRobertaTokenizer.cs
@@ -325,7 +325,7 @@ protected override EncodeResults EncodeToTokens(string? text, Read
{
foreach (EncodedToken t in EncodeInternal(textSpanToEncode.Slice(split.Offset, split.Length)))
{
- tokens.Add(new EncodedToken(t.Id, t.Value, (split.Offset + t.Offset.Index, t.Offset.Length)));
+ tokens.Add(new EncodedToken(t.Id, t.Value, new Range(split.Offset + t.Offset.Start.Value, split.Offset + t.Offset.End.Value)));
}
}
@@ -597,14 +597,14 @@ private int EncodeToIdsResult(List tokens, IList? accumulated
for (int i = 0; i < maxTokens; i++)
{
accumulatedIds.Add(tokens[i].Id);
- charsConsumed += tokens[i].Offset.Length;
+ charsConsumed += tokens[i].Offset.End.Value - tokens[i].Offset.Start.Value;
}
}
else
{
for (int i = 0; i < maxTokens; i++)
{
- charsConsumed += tokens[i].Offset.Length;
+ charsConsumed += tokens[i].Offset.End.Value - tokens[i].Offset.Start.Value;
}
}
@@ -634,14 +634,14 @@ private int EncodeToIdsFromEndResult(List tokens, IList? accu
for (int i = tokens.Count - maxTokens; i < tokens.Count; i++)
{
accumulatedIds.Add(tokens[i].Id);
- textIndex -= tokens[i].Offset.Length;
+ textIndex -= tokens[i].Offset.End.Value - tokens[i].Offset.Start.Value;
}
}
else
{
for (int i = tokens.Count - maxTokens; i < tokens.Count; i++)
{
- textIndex -= tokens[i].Offset.Length;
+ textIndex -= tokens[i].Offset.End.Value - tokens[i].Offset.Start.Value;
}
}
@@ -750,7 +750,7 @@ private int EncodeToIdsFromEndInternal(ReadOnlySpan text, IList? accu
///
/// The list of ids that we want to decode.
/// The decoded string.
- public override string? Decode(IEnumerable ids)
+ public override string Decode(IEnumerable ids)
{
if (ids is null)
{
@@ -905,7 +905,7 @@ private IReadOnlyList ModifyTokenListOffsets(IReadOnlyList list = new List(tokens.Count);
for (int j = 0; j < i; j++)
@@ -915,7 +915,7 @@ private IReadOnlyList ModifyTokenListOffsets(IReadOnlyList EncodeToTokens(Span token, Span indexMappi
{
Debug.Assert(token[0] < charToString.Length);
string tokenValue = charToString[token[0]];
- return new List { new EncodedToken(_vocab[new StringSpanOrdinalKey(tokenValue)], tokenValue, (indexMapping[0], 1)) };
+ return new List { new EncodedToken(_vocab[new StringSpanOrdinalKey(tokenValue)], tokenValue, new Range(indexMapping[0], indexMapping[0] + 1)) };
}
List word = new(token.Length);
@@ -1036,7 +1036,7 @@ private List EncodeToTokens(Span token, Span indexMappi
foreach (string w in word)
{
- tokens.Add(new EncodedToken(_vocab[new StringSpanOrdinalKey(w)], w, (indexMapping[index], w.Length)));
+ tokens.Add(new EncodedToken(_vocab[new StringSpanOrdinalKey(w)], w, new Range(indexMapping[index], indexMapping[index] + w.Length)));
index += w.Length;
}
diff --git a/src/Microsoft.ML.Tokenizers/Model/LlamaTokenizer.cs b/src/Microsoft.ML.Tokenizers/Model/LlamaTokenizer.cs
index 2406ab50fb..fe58b7bde1 100644
--- a/src/Microsoft.ML.Tokenizers/Model/LlamaTokenizer.cs
+++ b/src/Microsoft.ML.Tokenizers/Model/LlamaTokenizer.cs
@@ -12,16 +12,16 @@ namespace Microsoft.ML.Tokenizers
// SentencePiece is under the Apache License 2.0 https://github.com/google/sentencepiece/blob/master/LICENSE
///
- /// LlamaTokenizer is SentencePieceBpeTokenizer which is implemented based on https://github.com/google/sentencepiece.
+ /// LlamaTokenizer is SentencePieceTokenizer which is implemented based on https://github.com/google/sentencepiece.
///
- public sealed class LlamaTokenizer : SentencePieceBpeTokenizer
+ public sealed class LlamaTokenizer : SentencePieceTokenizer
{
internal LlamaTokenizer(ModelProto modelProto, bool addBos, bool addEos, IReadOnlyDictionary? addedTokens = null) : base(modelProto, addBos, addEos, addedTokens)
{
}
///
- /// Create from the given model stream a LlamaTokenizer which is based on SentencePieceBpeTokenizer. The model stream should contain the SentencePiece Bpe model according to
+ /// Create from the given model stream a LlamaTokenizer which is based on SentencePieceTokenizer. The model stream should contain the SentencePiece Bpe model according to
/// https://github.com/google/sentencepiece/blob/master/src/sentencepiece_model.proto specification.
///
/// The stream containing the SentencePiece Bpe model.
diff --git a/src/Microsoft.ML.Tokenizers/Model/Phi2Tokenizer.cs b/src/Microsoft.ML.Tokenizers/Model/Phi2Tokenizer.cs
index 64985bcc9d..b2229482fa 100644
--- a/src/Microsoft.ML.Tokenizers/Model/Phi2Tokenizer.cs
+++ b/src/Microsoft.ML.Tokenizers/Model/Phi2Tokenizer.cs
@@ -113,7 +113,7 @@ internal Phi2Tokenizer(
}
return new Phi2Tokenizer(
- vocabStream, mergesStream, new TiktokenPreTokenizer(TiktokenTokenizer.P50kBaseRegex(), CodeGenTokenizer.CodeGenAddedTokens), normalizer: null,
+ vocabStream, mergesStream, new RegexPreTokenizer(TiktokenTokenizer.P50kBaseRegex(), CodeGenTokenizer.CodeGenAddedTokens), normalizer: null,
CodeGenTokenizer.CodeGenAddedTokens, addPrefixSpace: addPrefixSpace, addBeginningOfSentence: addBeginOfSentence, addEndOfSentence: addEndOfSentence);
}
}
diff --git a/src/Microsoft.ML.Tokenizers/Model/SentencePieceBpeTokenizer.cs b/src/Microsoft.ML.Tokenizers/Model/SentencePieceTokenizer.cs
similarity index 98%
rename from src/Microsoft.ML.Tokenizers/Model/SentencePieceBpeTokenizer.cs
rename to src/Microsoft.ML.Tokenizers/Model/SentencePieceTokenizer.cs
index 45a58c84a4..b89606ba8d 100644
--- a/src/Microsoft.ML.Tokenizers/Model/SentencePieceBpeTokenizer.cs
+++ b/src/Microsoft.ML.Tokenizers/Model/SentencePieceTokenizer.cs
@@ -22,7 +22,7 @@ namespace Microsoft.ML.Tokenizers
///
/// SentencePieceBpe is a tokenizer that splits the input into tokens using the SentencePiece Bpe model.
///
- public class SentencePieceBpeTokenizer : Tokenizer
+ public class SentencePieceTokenizer : Tokenizer
{
private const int UninitializedId = -2; // indicate if the symbol contains uninitialized id.
private readonly Dictionary _vocab = new();
@@ -36,14 +36,14 @@ public class SentencePieceBpeTokenizer : Tokenizer
private readonly Dictionary? _specialTokens;
private readonly Dictionary? _specialTokensReverse;
- internal SentencePieceBpeTokenizer(ModelProto modelProto, bool addBos, bool addEos, IReadOnlyDictionary? specialTokens = null) :
+ internal SentencePieceTokenizer(ModelProto modelProto, bool addBos, bool addEos, IReadOnlyDictionary? specialTokens = null) :
this(modelProto is null ? throw new ArgumentNullException(nameof(modelProto)) : modelProto, specialTokens)
{
AddBeginningOfSentence = addBos;
AddEndOfSentence = addEos;
}
- private SentencePieceBpeTokenizer(ModelProto modelProto, IReadOnlyDictionary? specialTokens)
+ private SentencePieceTokenizer(ModelProto modelProto, IReadOnlyDictionary? specialTokens)
{
for (int i = 0; i < modelProto.Pieces.Count; i++)
{
@@ -272,7 +272,7 @@ private void EncodeWithSpecialTokens(ReadOnlySpan text, bool addBeginOfSen
if (addBeginOfSentence)
{
- tokens.Add(new EncodedToken(BeginningOfSentenceId, BeginningOfSentenceToken, (0, 0)));
+ tokens.Add(new EncodedToken(BeginningOfSentenceId, BeginningOfSentenceToken, new Range(0, 0)));
}
int currentOffset = 0;
@@ -286,7 +286,7 @@ private void EncodeWithSpecialTokens(ReadOnlySpan text, bool addBeginOfSen
if (_specialTokens!.TryGetValue(text.Slice(Offset, Length), out int id))
{
- tokens.Add(new EncodedToken(id, _specialTokensReverse![id], (Offset, Length)));
+ tokens.Add(new EncodedToken(id, _specialTokensReverse![id], new Range(Offset, Offset + Length)));
}
currentOffset = Offset + Length;
@@ -299,7 +299,7 @@ private void EncodeWithSpecialTokens(ReadOnlySpan text, bool addBeginOfSen
if (addEndOfSentence)
{
- tokens.Add(new EncodedToken(EndOfSentenceId, EndOfSentenceToken, (text.Length, 0)));
+ tokens.Add(new EncodedToken(EndOfSentenceId, EndOfSentenceToken, new Range(text.Length, text.Length)));
}
}
@@ -319,7 +319,7 @@ private void EncodeInternal(ReadOnlySpan text, bool addBeginOfSentence, bo
if (addBeginOfSentence)
{
- tokens.Add(new EncodedToken(BeginningOfSentenceId, BeginningOfSentenceToken, (0, 0)));
+ tokens.Add(new EncodedToken(BeginningOfSentenceId, BeginningOfSentenceToken, new Range(0, 0)));
}
for (int index = 0; (uint)index < (uint)symbols.Length; index = symbols[index].next)
@@ -352,7 +352,7 @@ private void EncodeInternal(ReadOnlySpan text, bool addBeginOfSentence, bo
tokens.Add(new EncodedToken(
id,
GetTokenString(id, symbols[index].pieceSpan.Index, symbols[index].pieceSpan.Length, text),
- (symbols[index].pieceSpan.Index, symbols[index].pieceSpan.Length)));
+ new Range(symbols[index].pieceSpan.Index, symbols[index].pieceSpan.Index + symbols[index].pieceSpan.Length)));
}
continue;
}
@@ -364,7 +364,7 @@ private void EncodeInternal(ReadOnlySpan text, bool addBeginOfSentence, bo
if (addEndOfSentence)
{
- tokens.Add(new EncodedToken(EndOfSentenceId, EndOfSentenceToken, (text.Length, 0)));
+ tokens.Add(new EncodedToken(EndOfSentenceId, EndOfSentenceToken, new Range(text.Length, text.Length)));
}
return;
@@ -381,7 +381,7 @@ void EncodeAsBytes(ReadOnlySpan text, int index)
if (_vocabReverse.TryGetValue(id, out string? token))
{
- tokens.Add(new EncodedToken(id, token, (index + i, 1)));
+ tokens.Add(new EncodedToken(id, token, new Range(index + i, index + i + 1)));
}
}
else
@@ -405,7 +405,7 @@ void EncodeAsBytes(ReadOnlySpan text, int index)
if (_vocabReverse.TryGetValue(id, out string? token))
{
- tokens.Add(new EncodedToken(id, token, (index + i, length)));
+ tokens.Add(new EncodedToken(id, token, new Range(index + i, index + i + length)));
}
length = 0;
@@ -433,7 +433,7 @@ void Segment((int Index, int Length) pieceSpan, ReadOnlySpan text)
revMerge is null ||
!revMerge.TryGetValue((pieceSpan.Index, pieceSpan.Length), out (int LeftIndex, int LeftLen, int RightIndex, int RightLen) merge))
{
- tokens.Add(new EncodedToken(id.Id, text.Slice(pieceSpan.Index, pieceSpan.Length).ToString(), (pieceSpan.Index, pieceSpan.Length)));
+ tokens.Add(new EncodedToken(id.Id, text.Slice(pieceSpan.Index, pieceSpan.Length).ToString(), new Range(pieceSpan.Index, pieceSpan.Index + pieceSpan.Length)));
return;
}
@@ -1526,7 +1526,7 @@ revMerge is null ||
///
/// The list of ids that we want to decode.
/// The decoded string.
- public override string? Decode(IEnumerable ids)
+ public override string Decode(IEnumerable ids)
=> Decode(ids, considerSpecialTokens: false);
///
@@ -1535,7 +1535,7 @@ revMerge is null ||
/// The list of ids that we want to decode.
/// Indicate whether to consider special tokens during decoding.
/// The decoded string.
- public string? Decode(IEnumerable ids, bool considerSpecialTokens)
+ public string Decode(IEnumerable ids, bool considerSpecialTokens)
{
if (ids is null)
{
@@ -1735,7 +1735,7 @@ static void AppendTokenWithCheckingPrefix(bool addDummyPrefix, bool treatWhitesp
prefixRemoved = true;
}
- static void TryDecodeAsSpecialToken(SentencePieceBpeTokenizer tokenizer, int id, bool considerSpecialTokens, ref ValueStringBuilder sb)
+ static void TryDecodeAsSpecialToken(SentencePieceTokenizer tokenizer, int id, bool considerSpecialTokens, ref ValueStringBuilder sb)
{
if (!considerSpecialTokens)
{
@@ -1979,7 +1979,7 @@ public OperationStatus Decode(IEnumerable ids, Span destination, bool
return OperationStatus.Done;
- static OperationStatus TryDecodeAsSpecialToken(SentencePieceBpeTokenizer tokenizer, int id, bool considerSpecialTokens, Span buffer, ref int charsWritten)
+ static OperationStatus TryDecodeAsSpecialToken(SentencePieceTokenizer tokenizer, int id, bool considerSpecialTokens, Span buffer, ref int charsWritten)
{
string? specialToken = null;
diff --git a/src/Microsoft.ML.Tokenizers/Model/TiktokenTokenizer.cs b/src/Microsoft.ML.Tokenizers/Model/TiktokenTokenizer.cs
index 47fc5971c0..b169b2234f 100644
--- a/src/Microsoft.ML.Tokenizers/Model/TiktokenTokenizer.cs
+++ b/src/Microsoft.ML.Tokenizers/Model/TiktokenTokenizer.cs
@@ -307,7 +307,7 @@ private void EncodeToTokens(ReadOnlySpan text, List tokens,
tokens.Add(new EncodedToken(
value[i].Id,
value[i].TokenLength == 0 ? string.Empty : text.Slice(value[i].TokenIndex, value[i].TokenLength).ToString(),
- (value[i].TokenIndex + offset, value[i].TokenLength)));
+ new Range(value[i].TokenIndex + offset, value[i].TokenIndex + offset + value[i].TokenLength)));
}
return;
@@ -316,7 +316,7 @@ private void EncodeToTokens(ReadOnlySpan text, List tokens,
// cache miss
if (_vocab.TryGetValue(text, out (int Id, string Token) mappedId))
{
- tokens.Add(new EncodedToken(mappedId.Id, mappedId.Token, (offset, mappedId.Token.Length)));
+ tokens.Add(new EncodedToken(mappedId.Id, mappedId.Token, new Range(offset, offset + mappedId.Token.Length)));
return;
}
@@ -348,7 +348,7 @@ private void EncodeToTokens(ReadOnlySpan text, List tokens,
tokens.Add(new EncodedToken(
encodedTokens[i].Id,
encodedTokens[i].TokenLength == 0 ? string.Empty : text.Slice(encodedTokens[i].TokenIndex, encodedTokens[i].TokenLength).ToString(),
- (encodedTokens[i].TokenIndex + offset, encodedTokens[i].TokenLength)));
+ new Range(encodedTokens[i].TokenIndex + offset, encodedTokens[i].TokenIndex + offset + encodedTokens[i].TokenLength)));
}
}
@@ -792,7 +792,7 @@ private int EncodeToIdsFromEndResult((int Id, int TokenIndex, int TokenLength)[]
///
/// The list of ids that we want to decode.
/// The decoded string.
- public override string? Decode(IEnumerable ids)
+ public override string Decode(IEnumerable ids)
{
// Tiktoken doesn't guarantee a one-to-one correspondence between IDs and UTF-16 words.
// Consequently, decoding individual IDs into UTF-16 string is not supported; instead, decoding all IDs must be performed collectively.
@@ -824,10 +824,6 @@ private int EncodeToIdsFromEndResult((int Id, int TokenIndex, int TokenLength)[]
tokenBytes.Span.CopyTo(utf8Bytes.Slice(utf8ByteCount));
utf8ByteCount += tokenBytes.Length;
}
- else
- {
- return null;
- }
}
return Helpers.GetString(utf8Bytes.Slice(0, utf8ByteCount));
@@ -1029,6 +1025,7 @@ private enum ModelEncoding
private static readonly (string Prefix, ModelEncoding Encoding)[] _modelPrefixToEncoding =
[
// chat
+ ( "o1-", ModelEncoding.O200kBase ), // e.g. o1-mini
( "gpt-4o-", ModelEncoding.O200kBase), // e.g., gpt-4o-2024-05-13
( "gpt-4-", ModelEncoding.Cl100kBase), // e.g., gpt-4-0314, etc., plus gpt-4-32k
( "gpt-3.5-", ModelEncoding.Cl100kBase), // e.g, gpt-3.5-turbo-0301, -0401, etc.
@@ -1040,6 +1037,7 @@ private static readonly (string Prefix, ModelEncoding Encoding)[] _modelPrefixTo
{
// chat
{ "gpt-4o", ModelEncoding.O200kBase },
+ { "o1", ModelEncoding.O200kBase },
{ "gpt-4", ModelEncoding.Cl100kBase },
{ "gpt-3.5-turbo", ModelEncoding.Cl100kBase },
{ "gpt-3.5-turbo-16k", ModelEncoding.Cl100kBase },
@@ -1239,7 +1237,7 @@ private static TiktokenTokenizer CreateForModel(
cache.encoder,
cache.decoder,
cache.vocab,
- new TiktokenPreTokenizer(tiktokenConfiguration.Regex, tiktokenConfiguration.SpecialTokens),
+ new RegexPreTokenizer(tiktokenConfiguration.Regex, tiktokenConfiguration.SpecialTokens),
tiktokenConfiguration.SpecialTokens,
normalizer,
LruCache.DefaultCacheSize);
@@ -1367,7 +1365,7 @@ public static TiktokenTokenizer CreateForModel(
}
return new TiktokenTokenizer(vocabStream,
- new TiktokenPreTokenizer(tiktokenConfiguration.Regex, tiktokenConfiguration.SpecialTokens),
+ new RegexPreTokenizer(tiktokenConfiguration.Regex, tiktokenConfiguration.SpecialTokens),
tiktokenConfiguration.SpecialTokens,
normalizer,
cacheSize);
@@ -1407,7 +1405,7 @@ public static async Task CreateForModelAsync(
}
return await CreateAsync(vocabStream,
- new TiktokenPreTokenizer(tiktokenConfiguration.Regex, tiktokenConfiguration.SpecialTokens),
+ new RegexPreTokenizer(tiktokenConfiguration.Regex, tiktokenConfiguration.SpecialTokens),
normalizer,
tiktokenConfiguration.SpecialTokens,
cacheSize, cancellationToken).ConfigureAwait(false);
diff --git a/src/Microsoft.ML.Tokenizers/Model/Word.cs b/src/Microsoft.ML.Tokenizers/Model/Word.cs
index 5acfd9ae4b..003243934c 100644
--- a/src/Microsoft.ML.Tokenizers/Model/Word.cs
+++ b/src/Microsoft.ML.Tokenizers/Model/Word.cs
@@ -296,7 +296,7 @@ public void ToTokens(SortedDictionary vocabReverse, List
- public abstract class PreTokenizer
+ public abstract partial class PreTokenizer
{
///
/// Get the offsets and lengths of the tokens relative to the .
@@ -40,6 +40,32 @@ public abstract class PreTokenizer
}
}
+ private const string WhiteSpacePattern = /*lang=regex*/ @"\w+|[^\w\s]+";
+ private static PreTokenizer? _whiteSpacePreTokenizer;
+#if NET7_0_OR_GREATER
+ [GeneratedRegex(WhiteSpacePattern)]
+ private static partial Regex WhiteSpaceRegex();
+#else
+ private static Regex WhiteSpaceRegex() => new Regex(WhiteSpacePattern, RegexOptions.Compiled);
+#endif
+
+ ///
+ /// Create a new instance of the class which split the text at the word boundary.
+ /// The word is a set of alphabet, numeric, and underscore characters.
+ ///
+ /// The dictionary containing the special tokens and their corresponding ids.
+ /// The pre-tokenizer that splits the text at the word boundary.
+ public static PreTokenizer CreateWhiteSpace(IReadOnlyDictionary? specialTokensEncoder = null)
+ {
+ if (specialTokensEncoder is null)
+ {
+ // return a singleton instance of the WhiteSpace pre-tokenizer
+ return _whiteSpacePreTokenizer ??= new RegexPreTokenizer(WhiteSpaceRegex(), null);
+ }
+
+ return new RegexPreTokenizer(WhiteSpaceRegex(), specialTokensEncoder);
+ }
+
internal static IEnumerable<(int Offset, int Length)> SplitText(ReadOnlySpan text, Regex regex)
{
#if NET7_0_OR_GREATER
diff --git a/src/Microsoft.ML.Tokenizers/PreTokenizer/TiktokenPreTokenizer.cs b/src/Microsoft.ML.Tokenizers/PreTokenizer/RegexPreTokenizer.cs
similarity index 95%
rename from src/Microsoft.ML.Tokenizers/PreTokenizer/TiktokenPreTokenizer.cs
rename to src/Microsoft.ML.Tokenizers/PreTokenizer/RegexPreTokenizer.cs
index 4050f75d07..9685e370b7 100644
--- a/src/Microsoft.ML.Tokenizers/PreTokenizer/TiktokenPreTokenizer.cs
+++ b/src/Microsoft.ML.Tokenizers/PreTokenizer/RegexPreTokenizer.cs
@@ -13,18 +13,18 @@ namespace Microsoft.ML.Tokenizers
///
/// The pre-tokenizer for Tiktoken tokenizer.
///
- public sealed class TiktokenPreTokenizer : PreTokenizer
+ public sealed partial class RegexPreTokenizer : PreTokenizer
{
private readonly Regex? _specialTokensRegex;
private readonly Regex _regex;
///
- /// Initializes a new instance of the class.
+ /// Initializes a new instance of the class.
///
/// The regex to use for splitting the text into smaller tokens in the pre-tokenization process.
/// The dictionary containing the special tokens and their corresponding ids.
/// When regex is null
- public TiktokenPreTokenizer(Regex regex, IReadOnlyDictionary? specialTokensEncoder)
+ public RegexPreTokenizer(Regex regex, IReadOnlyDictionary? specialTokensEncoder)
{
if (regex is null)
{
diff --git a/src/Microsoft.ML.Tokenizers/PreTokenizer/WhiteSpacePreTokenizer.cs b/src/Microsoft.ML.Tokenizers/PreTokenizer/WhiteSpacePreTokenizer.cs
deleted file mode 100644
index 4ba737d1bb..0000000000
--- a/src/Microsoft.ML.Tokenizers/PreTokenizer/WhiteSpacePreTokenizer.cs
+++ /dev/null
@@ -1,61 +0,0 @@
-// Licensed to the .NET Foundation under one or more agreements.
-// The .NET Foundation licenses this file to you under the MIT license.
-// See the LICENSE file in the project root for more information.
-
-using System;
-using System.Collections.Generic;
-using System.Text.RegularExpressions;
-
-namespace Microsoft.ML.Tokenizers
-{
- ///
- /// The pre-tokenizer which split the text at the word boundary.
- /// The word is a set of alphabet, numeric, and underscore characters.
- ///
- public sealed partial class WhiteSpacePreTokenizer : PreTokenizer
- {
- ///
- /// Gets a singleton instance of the WhiteSpace pre-tokenizer..
- ///
- public static WhiteSpacePreTokenizer Instance { get; } = new WhiteSpacePreTokenizer();
-
- private const string PretokenizePattern = /*lang=regex*/ @"\w+|[^\w\s]+";
-#if NET7_0_OR_GREATER
- [GeneratedRegex(PretokenizePattern)]
- private static partial Regex PretokenizeRegex();
-#else
- private static readonly Regex _regex = new Regex(PretokenizePattern, RegexOptions.Compiled);
- private static Regex PretokenizeRegex() => _regex;
-#endif
-
- ///
- /// Get the offsets and lengths of the tokens relative to the .
- ///
- /// The string to split into tokens.
- /// The offsets and lengths of the tokens, expressed as pairs, are relative to the original string.
- public override IEnumerable<(int Offset, int Length)> PreTokenize(string text)
- {
- if (string.IsNullOrEmpty(text))
- {
- return [];
- }
-
- return SplitText(text, PretokenizeRegex());
- }
-
- ///
- /// Get the offsets and lengths of the tokens relative to the .
- ///
- /// The string to split into tokens.
- /// The offsets and lengths of the tokens, expressed as pairs, are relative to the original string.
- public override IEnumerable<(int Offset, int Length)> PreTokenize(ReadOnlySpan text)
- {
- if (text.IsEmpty)
- {
- return [];
- }
-
- return SplitText(text, PretokenizeRegex());
- }
- }
-}
diff --git a/src/Microsoft.ML.Tokenizers/Tokenizer.cs b/src/Microsoft.ML.Tokenizers/Tokenizer.cs
index 4821a91984..f9e47707b0 100644
--- a/src/Microsoft.ML.Tokenizers/Tokenizer.cs
+++ b/src/Microsoft.ML.Tokenizers/Tokenizer.cs
@@ -241,7 +241,7 @@ protected virtual int GetIndexByTokenCount(string? text, ReadOnlySpan text
if (tokenCount > 0)
{
var token = tokens.Tokens[tokenCount - 1];
- return token.Offset.Index + token.Offset.Length;
+ return token.Offset.End.Value;
}
return 0;
@@ -251,7 +251,7 @@ protected virtual int GetIndexByTokenCount(string? text, ReadOnlySpan text
if (tokenCount > 0)
{
var token = tokens.Tokens[tokens.Tokens.Count - tokenCount];
- return token.Offset.Index;
+ return token.Offset.Start.Value;
}
return tokens.NormalizedText?.Length ?? textSpan.Length;
@@ -361,7 +361,7 @@ public int GetIndexByTokenCountFromEnd(ReadOnlySpan text, int maxTokenCoun
/// Types derived from may override this implementation to provide a more efficient implementation.
/// By default, it uses .
///
- public virtual string? Decode(IEnumerable ids)
+ public virtual string Decode(IEnumerable ids)
{
if (ids is null)
{
diff --git a/src/Microsoft.ML.TorchSharp/AutoFormerV2/Anchors.cs b/src/Microsoft.ML.TorchSharp/AutoFormerV2/Anchors.cs
index fdcbc070c8..081decbf07 100644
--- a/src/Microsoft.ML.TorchSharp/AutoFormerV2/Anchors.cs
+++ b/src/Microsoft.ML.TorchSharp/AutoFormerV2/Anchors.cs
@@ -103,18 +103,18 @@ private static Tensor GenerateAnchors(int baseSize = 16, double[] ratios = null,
var anchors = torch.zeros(new long[] { numAnchors, 4 }, dtype: torch.float32);
// scale base_size
- anchors[.., 2..] = baseSize * torch.tile(scales, new long[] { 2, ratios.Length }).transpose(1, 0);
+ anchors[RangeUtil.ToTensorIndex(..), RangeUtil.ToTensorIndex(2..)] = baseSize * torch.tile(scales, new long[] { 2, ratios.Length }).transpose(1, 0);
// compute areas of anchors
- var areas = torch.mul(anchors[.., 2], anchors[.., 3]);
+ var areas = torch.mul(anchors[RangeUtil.ToTensorIndex(..), 2], anchors[RangeUtil.ToTensorIndex(..), 3]);
// correct for ratios
- anchors[.., 2] = torch.sqrt(areas / torch.repeat_interleave(ratios, new long[] { scales.Length }));
- anchors[.., 3] = torch.mul(anchors[.., 2], torch.repeat_interleave(ratios, new long[] { scales.Length }));
+ anchors[RangeUtil.ToTensorIndex(..), 2] = torch.sqrt(areas / torch.repeat_interleave(ratios, new long[] { scales.Length }));
+ anchors[RangeUtil.ToTensorIndex(..), 3] = torch.mul(anchors[RangeUtil.ToTensorIndex(..), 2], torch.repeat_interleave(ratios, new long[] { scales.Length }));
// transform from (x_ctr, y_ctr, w, h) -> (x1, y1, x2, y2)
- anchors[.., torch.TensorIndex.Tensor(torch.tensor(new long[] { 0, 2 }, dtype: torch.int64))] -= torch.tile(anchors[.., 2] * 0.5, new long[] { 2, 1 }).T;
- anchors[.., torch.TensorIndex.Tensor(torch.tensor(new long[] { 1, 3 }, dtype: torch.int64))] -= torch.tile(anchors[.., 3] * 0.5, new long[] { 2, 1 }).T;
+ anchors[RangeUtil.ToTensorIndex(..), torch.TensorIndex.Tensor(torch.tensor(new long[] { 0, 2 }, dtype: torch.int64))] -= torch.tile(anchors[RangeUtil.ToTensorIndex(..), 2] * 0.5, new long[] { 2, 1 }).T;
+ anchors[RangeUtil.ToTensorIndex(..), torch.TensorIndex.Tensor(torch.tensor(new long[] { 1, 3 }, dtype: torch.int64))] -= torch.tile(anchors[RangeUtil.ToTensorIndex(..), 3] * 0.5, new long[] { 2, 1 }).T;
return anchors.MoveToOuterDisposeScope();
}
diff --git a/src/Microsoft.ML.TorchSharp/AutoFormerV2/Attention.cs b/src/Microsoft.ML.TorchSharp/AutoFormerV2/Attention.cs
index a44d64c506..d50791a965 100644
--- a/src/Microsoft.ML.TorchSharp/AutoFormerV2/Attention.cs
+++ b/src/Microsoft.ML.TorchSharp/AutoFormerV2/Attention.cs
@@ -113,7 +113,7 @@ public override Tensor forward(Tensor x, Tensor mask)
k = k.permute(0, 2, 1, 3);
v = v.permute(0, 2, 1, 3);
- var attn = (torch.matmul(q, k.transpose(-2, -1)) * this.scale) + this.attention_biases[.., this.attention_bias_idxs];
+ var attn = (torch.matmul(q, k.transpose(-2, -1)) * this.scale) + this.attention_biases[RangeUtil.ToTensorIndex(..), this.attention_bias_idxs];
if (!(mask is null))
{
long nW = mask.shape[0];
diff --git a/src/Microsoft.ML.TorchSharp/AutoFormerV2/AutoFormerV2Block.cs b/src/Microsoft.ML.TorchSharp/AutoFormerV2/AutoFormerV2Block.cs
index 6bba3fc596..28b9a948d9 100644
--- a/src/Microsoft.ML.TorchSharp/AutoFormerV2/AutoFormerV2Block.cs
+++ b/src/Microsoft.ML.TorchSharp/AutoFormerV2/AutoFormerV2Block.cs
@@ -127,7 +127,7 @@ public override Tensor forward(Tensor x, int h, int w, Tensor maskMatrix)
}
else
{
- x = x[.., ..h, ..w].contiguous();
+ x = x[RangeUtil.ToTensorIndex(..), RangeUtil.ToTensorIndex(..h), RangeUtil.ToTensorIndex(..w)].contiguous();
}
}
diff --git a/src/Microsoft.ML.TorchSharp/AutoFormerV2/ObjectDetectionTrainer.cs b/src/Microsoft.ML.TorchSharp/AutoFormerV2/ObjectDetectionTrainer.cs
index 6f3732c72b..735e135691 100644
--- a/src/Microsoft.ML.TorchSharp/AutoFormerV2/ObjectDetectionTrainer.cs
+++ b/src/Microsoft.ML.TorchSharp/AutoFormerV2/ObjectDetectionTrainer.cs
@@ -384,7 +384,7 @@ private bool TrainStep(IHost host,
var padW = 32 - (image.Width % 32);
var padH = 32 - (image.Height % 32);
using var transMidTensor = torch.zeros(1, 3, image.Height + padH, image.Width + padW, device: Device);
- transMidTensor[.., .., ..image.Height, ..image.Width] = reMidTensor / 255.0;
+ transMidTensor[RangeUtil.ToTensorIndex(..), RangeUtil.ToTensorIndex(..), RangeUtil.ToTensorIndex(..image.Height), RangeUtil.ToTensorIndex(..image.Width)] = reMidTensor / 255.0;
var imageTensor = Normalize(transMidTensor, Device);
VBuffer labels = default;
@@ -407,11 +407,11 @@ private bool TrainStep(IHost host,
long y1 = (long)boxValues[b++];
// Our labels are 1 based, the TorchSharp model is 0 based so subtract 1 to they align correctly.
long cl = labelValues[i] - 1;
- labelTensor[.., i, 0] = x0;
- labelTensor[.., i, 1] = y0;
- labelTensor[.., i, 2] = x1;
- labelTensor[.., i, 3] = y1;
- labelTensor[.., i, 4] = cl;
+ labelTensor[RangeUtil.ToTensorIndex(..), i, 0] = x0;
+ labelTensor[RangeUtil.ToTensorIndex(..), i, 1] = y0;
+ labelTensor[RangeUtil.ToTensorIndex(..), i, 2] = x1;
+ labelTensor[RangeUtil.ToTensorIndex(..), i, 3] = y1;
+ labelTensor[RangeUtil.ToTensorIndex(..), i, 4] = cl;
}
return (imageTensor.MoveToOuterDisposeScope(), labelTensor.MoveToOuterDisposeScope());
}
@@ -919,7 +919,7 @@ private Tensor PrepInputTensors(ref MLImage image, ValueGetter imageGet
var padW = 32 - (image.Width % 32);
var padH = 32 - (image.Height % 32);
var transMidTensor = torch.zeros(1, 3, image.Height + padH, image.Width + padW, device: _parent.Device);
- transMidTensor[.., .., ..image.Height, ..image.Width] = reMidTensor / 255.0;
+ transMidTensor[RangeUtil.ToTensorIndex(..), RangeUtil.ToTensorIndex(..), RangeUtil.ToTensorIndex(..image.Height), RangeUtil.ToTensorIndex(..image.Width)] = reMidTensor / 255.0;
var imageTensor = ObjectDetectionTrainer.Trainer.Normalize(transMidTensor, _parent.Device);
return imageTensor.MoveToOuterDisposeScope();
}
diff --git a/src/Microsoft.ML.TorchSharp/Loss/FocalLoss.cs b/src/Microsoft.ML.TorchSharp/Loss/FocalLoss.cs
index 3954677526..45ebeb4aae 100644
--- a/src/Microsoft.ML.TorchSharp/Loss/FocalLoss.cs
+++ b/src/Microsoft.ML.TorchSharp/Loss/FocalLoss.cs
@@ -40,20 +40,20 @@ public override Tensor forward(Tensor classifications, Tensor regressions, Tenso
var classificationLosses = new List();
var regressionLosses = new List();
- var anchor = anchors[0, .., ..];
+ var anchor = anchors[0, RangeUtil.ToTensorIndex(..), RangeUtil.ToTensorIndex(..)];
- var anchorWidths = anchor[.., 2] - anchor[.., 0];
- var anchorHeights = anchor[.., 3] - anchor[.., 1];
- var anchorCtrX = anchor[.., 0] + (0.5 * anchorWidths);
- var anchorCtrY = anchor[.., 1] + (0.5 * anchorHeights);
+ var anchorWidths = anchor[RangeUtil.ToTensorIndex(..), 2] - anchor[RangeUtil.ToTensorIndex(..), 0];
+ var anchorHeights = anchor[RangeUtil.ToTensorIndex(..), 3] - anchor[RangeUtil.ToTensorIndex(..), 1];
+ var anchorCtrX = anchor[RangeUtil.ToTensorIndex(..), 0] + (0.5 * anchorWidths);
+ var anchorCtrY = anchor[RangeUtil.ToTensorIndex(..), 1] + (0.5 * anchorHeights);
for (int j = 0; j < batchSize; ++j)
{
- var classification = classifications[j, .., ..];
- var regression = regressions[j, .., ..];
+ var classification = classifications[j, RangeUtil.ToTensorIndex(..), RangeUtil.ToTensorIndex(..)];
+ var regression = regressions[j, RangeUtil.ToTensorIndex(..), RangeUtil.ToTensorIndex(..)];
- var bboxAnnotation = annotations[j, .., ..];
- bboxAnnotation = bboxAnnotation[bboxAnnotation[.., 4] != -1];
+ var bboxAnnotation = annotations[j, RangeUtil.ToTensorIndex(..), RangeUtil.ToTensorIndex(..)];
+ bboxAnnotation = bboxAnnotation[bboxAnnotation[RangeUtil.ToTensorIndex(..), 4] != -1];
classification = torch.clamp(classification, 1e-4, 1.0 - 1e-4);
@@ -73,7 +73,7 @@ public override Tensor forward(Tensor classifications, Tensor regressions, Tenso
}
else
{
- var iou = CalcIou(anchors[0, .., ..], bboxAnnotation[.., ..4]); // num_anchors x num_annotations
+ var iou = CalcIou(anchors[0, RangeUtil.ToTensorIndex(..), RangeUtil.ToTensorIndex(..)], bboxAnnotation[RangeUtil.ToTensorIndex(..), RangeUtil.ToTensorIndex(..4)]); // num_anchors x num_annotations
var (iou_max, iou_argmax) = torch.max(iou, dim: 1); // num_anchors x 1
@@ -125,10 +125,10 @@ public override Tensor forward(Tensor classifications, Tensor regressions, Tenso
var anchorCtrXPi = anchorCtrX[positiveIndices];
var anchorCtrYPi = anchorCtrY[positiveIndices];
- var gtWidths = assignedAnnotations[.., 2] - assignedAnnotations[.., 0];
- var gtHeights = assignedAnnotations[.., 3] - assignedAnnotations[.., 1];
- var gtCtrX = assignedAnnotations[.., 0] + (0.5 * gtWidths);
- var gtCtrY = assignedAnnotations[.., 1] + (0.5 * gtHeights);
+ var gtWidths = assignedAnnotations[RangeUtil.ToTensorIndex(..), 2] - assignedAnnotations[RangeUtil.ToTensorIndex(..), 0];
+ var gtHeights = assignedAnnotations[RangeUtil.ToTensorIndex(..), 3] - assignedAnnotations[RangeUtil.ToTensorIndex(..), 1];
+ var gtCtrX = assignedAnnotations[RangeUtil.ToTensorIndex(..), 0] + (0.5 * gtWidths);
+ var gtCtrY = assignedAnnotations[RangeUtil.ToTensorIndex(..), 1] + (0.5 * gtHeights);
// clip widths to 1
gtWidths = torch.clamp(gtWidths, min: 1);
@@ -178,17 +178,17 @@ private object ToTensorIndex()
private static Tensor CalcIou(Tensor a, Tensor b)
{
- var area = (b[.., 2] - b[.., 0]) * (b[.., 3] - b[.., 1]);
+ var area = (b[RangeUtil.ToTensorIndex(..), 2] - b[RangeUtil.ToTensorIndex(..), 0]) * (b[RangeUtil.ToTensorIndex(..), 3] - b[RangeUtil.ToTensorIndex(..), 1]);
- var iw = torch.minimum(input: torch.unsqueeze(a[.., 2], dim: 1), b[.., 2]) -
- torch.maximum(input: torch.unsqueeze(a[.., 0], 1), b[.., 0]);
- var ih = torch.minimum(input: torch.unsqueeze(a[.., 3], dim: 1), b[.., 3]) -
- torch.maximum(input: torch.unsqueeze(a[.., 1], 1), b[.., 1]);
+ var iw = torch.minimum(input: torch.unsqueeze(a[RangeUtil.ToTensorIndex(..), 2], dim: 1), b[RangeUtil.ToTensorIndex(..), 2]) -
+ torch.maximum(input: torch.unsqueeze(a[RangeUtil.ToTensorIndex(..), 0], 1), b[RangeUtil.ToTensorIndex(..), 0]);
+ var ih = torch.minimum(input: torch.unsqueeze(a[RangeUtil.ToTensorIndex(..), 3], dim: 1), b[RangeUtil.ToTensorIndex(..), 3]) -
+ torch.maximum(input: torch.unsqueeze(a[RangeUtil.ToTensorIndex(..), 1], 1), b[RangeUtil.ToTensorIndex(..), 1]);
iw = torch.clamp(iw, min: 0);
ih = torch.clamp(ih, min: 0);
- var ua = torch.unsqueeze((a[.., 2] - a[.., 0]) * (a[.., 3] - a[.., 1]), dim: 1) + area - (iw * ih);
+ var ua = torch.unsqueeze((a[RangeUtil.ToTensorIndex(..), 2] - a[RangeUtil.ToTensorIndex(..), 0]) * (a[RangeUtil.ToTensorIndex(..), 3] - a[RangeUtil.ToTensorIndex(..), 1]), dim: 1) + area - (iw * ih);
ua = torch.clamp(ua, min: 1e-8);
var intersection = iw * ih;
diff --git a/src/Microsoft.ML.TorchSharp/Microsoft.ML.TorchSharp.csproj b/src/Microsoft.ML.TorchSharp/Microsoft.ML.TorchSharp.csproj
index 698dbfd623..c347333d27 100644
--- a/src/Microsoft.ML.TorchSharp/Microsoft.ML.TorchSharp.csproj
+++ b/src/Microsoft.ML.TorchSharp/Microsoft.ML.TorchSharp.csproj
@@ -19,6 +19,7 @@
+
@@ -32,13 +33,13 @@
dict.txt
-
+
encoder.json
-
+
vocab.bpe
-
+
diff --git a/src/Microsoft.ML.TorchSharp/Utils/ImageUtils.cs b/src/Microsoft.ML.TorchSharp/Utils/ImageUtils.cs
index 7d2e0d3850..cd158fa5d8 100644
--- a/src/Microsoft.ML.TorchSharp/Utils/ImageUtils.cs
+++ b/src/Microsoft.ML.TorchSharp/Utils/ImageUtils.cs
@@ -50,7 +50,7 @@ public static void Postprocess(Tensor imgBatch, Tensor classification, Tensor re
for (int i = 0; i < classification.shape[2]; ++i)
{
- var scores1 = torch.squeeze(classification[.., .., i], null);
+ var scores1 = torch.squeeze(classification[RangeUtil.ToTensorIndex(..), RangeUtil.ToTensorIndex(..), i], null);
var scoresOverThresh = scores1 > 0.05;
if (scoresOverThresh.sum().ToSingle() == 0)
{
@@ -108,16 +108,16 @@ private static Tensor Nms(Tensor boxes, Tensor scores, double iouThreshold = 0.5
using (var nmsScope = torch.NewDisposeScope())
{
// boxes: Tensor [N,4],scores: Tensor [N,]
- var x1 = boxes[.., 0];
- var y1 = boxes[.., 1];
- var x2 = boxes[.., 2];
- var y2 = boxes[.., 3];
+ var x1 = boxes[RangeUtil.ToTensorIndex(..), 0];
+ var y1 = boxes[RangeUtil.ToTensorIndex(..), 1];
+ var x2 = boxes[RangeUtil.ToTensorIndex(..), 2];
+ var y2 = boxes[RangeUtil.ToTensorIndex(..), 3];
var areas = (x2 - x1) * (y2 - y1); // [N,]
var (_, _order) = scores.sort(0, descending: true);
var keep = new List();
- var order = _order[..];
+ var order = _order[RangeUtil.ToTensorIndex(..)];
while (order.numel() > 0)
{
long i;
@@ -133,13 +133,13 @@ private static Tensor Nms(Tensor boxes, Tensor scores, double iouThreshold = 0.5
keep.Add(i);
}
- var xx1 = x1[order[1..]].clamp(min: x1[i]); // [N - 1,]
- var yy1 = y1[order[1..]].clamp(min: y1[i]);
- var xx2 = x2[order[1..]].clamp(max: x2[i]);
- var yy2 = y2[order[1..]].clamp(max: y2[i]);
+ var xx1 = x1[order[RangeUtil.ToTensorIndex(1..)]].clamp(min: x1[i]); // [N - 1,]
+ var yy1 = y1[order[RangeUtil.ToTensorIndex(1..)]].clamp(min: y1[i]);
+ var xx2 = x2[order[RangeUtil.ToTensorIndex(1..)]].clamp(max: x2[i]);
+ var yy2 = y2[order[RangeUtil.ToTensorIndex(1..)]].clamp(max: y2[i]);
var inter = (xx2 - xx1).clamp(min: 0) * (yy2 - yy1).clamp(min: 0); // [N - 1,]
- var iou = inter / (areas[i] + areas[order[1..]] - inter); // [N-1, ]
+ var iou = inter / (areas[i] + areas[order[RangeUtil.ToTensorIndex(1..)]] - inter); // [N-1, ]
var idx = (iou <= iouThreshold).nonzero().squeeze(); // idx: [N - 1,] and order:[N,]
if (idx.numel() == 0)
{
@@ -167,15 +167,15 @@ private static Tensor TransformBbox(Tensor boxes, Tensor deltas)
var mean = torch.from_array(new double[] { 0, 0, 0, 0 }).to_type(ScalarType.Float32).to(boxes.device);
var std = torch.from_array(new double[] { 0.1, 0.1, 0.2, 0.2 }).to_type(ScalarType.Float32).to(boxes.device);
- var widths = boxes[.., .., 2] - boxes[.., .., 0];
- var heights = boxes[.., .., 3] - boxes[.., .., 1];
- var ctrX = boxes[.., .., 0] + (0.5 * widths);
- var ctrY = boxes[.., .., 1] + (0.5 * heights);
+ var widths = boxes[RangeUtil.ToTensorIndex(..), RangeUtil.ToTensorIndex(..), 2] - boxes[RangeUtil.ToTensorIndex(..), RangeUtil.ToTensorIndex(..), 0];
+ var heights = boxes[RangeUtil.ToTensorIndex(..), RangeUtil.ToTensorIndex(..), 3] - boxes[RangeUtil.ToTensorIndex(..), RangeUtil.ToTensorIndex(..), 1];
+ var ctrX = boxes[RangeUtil.ToTensorIndex(..), RangeUtil.ToTensorIndex(..), 0] + (0.5 * widths);
+ var ctrY = boxes[RangeUtil.ToTensorIndex(..), RangeUtil.ToTensorIndex(..), 1] + (0.5 * heights);
- var dx = (deltas[.., .., 0] * std[0]) + mean[0];
- var dy = (deltas[.., .., 1] * std[1]) + mean[1];
- var dw = (deltas[.., .., 2] * std[2]) + mean[2];
- var dh = (deltas[.., .., 3] * std[3]) + mean[3];
+ var dx = (deltas[RangeUtil.ToTensorIndex(..), RangeUtil.ToTensorIndex(..), 0] * std[0]) + mean[0];
+ var dy = (deltas[RangeUtil.ToTensorIndex(..), RangeUtil.ToTensorIndex(..), 1] * std[1]) + mean[1];
+ var dw = (deltas[RangeUtil.ToTensorIndex(..), RangeUtil.ToTensorIndex(..), 2] * std[2]) + mean[2];
+ var dh = (deltas[RangeUtil.ToTensorIndex(..), RangeUtil.ToTensorIndex(..), 3] * std[3]) + mean[3];
var predCtrX = ctrX + (dx * widths);
var predCtrY = ctrY + (dy * heights);
@@ -210,11 +210,11 @@ private static Tensor ClipBoxes(Tensor boxes, Tensor img)
var height = img.shape[2];
var width = img.shape[3];
- var clippedBoxesX0 = torch.clamp(boxes[.., .., 0], min: 0);
- var clippedBoxesY0 = torch.clamp(boxes[.., .., 1], min: 0);
+ var clippedBoxesX0 = torch.clamp(boxes[RangeUtil.ToTensorIndex(..), RangeUtil.ToTensorIndex(..), 0], min: 0);
+ var clippedBoxesY0 = torch.clamp(boxes[RangeUtil.ToTensorIndex(..), RangeUtil.ToTensorIndex(..), 1], min: 0);
- var clippedBoxesX1 = torch.clamp(boxes[.., .., 2], max: width);
- var clippedBoxesY1 = torch.clamp(boxes[.., .., 3], max: height);
+ var clippedBoxesX1 = torch.clamp(boxes[RangeUtil.ToTensorIndex(..), RangeUtil.ToTensorIndex(..), 2], max: width);
+ var clippedBoxesY1 = torch.clamp(boxes[RangeUtil.ToTensorIndex(..), RangeUtil.ToTensorIndex(..), 3], max: height);
var clippedBoxes = torch.stack(
 new List<Tensor> { clippedBoxesX0, clippedBoxesY0, clippedBoxesX1, clippedBoxesY1 },
diff --git a/src/Microsoft.ML.TorchSharp/Utils/Index.cs b/src/Microsoft.ML.TorchSharp/Utils/Index.cs
deleted file mode 100644
index 20f59a2e50..0000000000
--- a/src/Microsoft.ML.TorchSharp/Utils/Index.cs
+++ /dev/null
@@ -1,145 +0,0 @@
-// Licensed to the .NET Foundation under one or more agreements.
-// The .NET Foundation licenses this file to you under the MIT license.
-// See the LICENSE file in the project root for more information.
-
-using System.Runtime.CompilerServices;
-
-namespace System
-{
- /// Represent a type can be used to index a collection either from the start or the end.
- ///
- /// Index is used by the C# compiler to support the new index syntax
- ///
- /// int[] someArray = new int[5] { 1, 2, 3, 4, 5 } ;
- /// int lastElement = someArray[^1]; // lastElement = 5
- ///
- ///
- internal readonly struct Index : IEquatable<Index>
- {
- private readonly int _value;
-
- /// Construct an Index using a value and indicating if the index is from the start or from the end.
- /// The index value. it has to be zero or positive number.
- /// Indicating if the index is from the start or from the end.
- ///
- /// If the Index constructed from the end, index value 1 means pointing at the last element and index value 0 means pointing at beyond last element.
- ///
- [MethodImpl(MethodImplOptions.AggressiveInlining)]
- public Index(int value, bool fromEnd = false)
- {
- if (value < 0)
- {
- throw new ArgumentOutOfRangeException(nameof(value), "Non-negative number required.");
- }
-
- if (fromEnd)
- _value = ~value;
- else
- _value = value;
- }
-
- // The following private constructors mainly created for perf reason to avoid the checks
- private Index(int value)
- {
- _value = value;
- }
-
- /// Create an Index pointing at first element.
- public static Index Start => new Index(0);
-
- /// Create an Index pointing at beyond last element.
- public static Index End => new Index(~0);
-
- /// Create an Index from the start at the position indicated by the value.
- /// The index value from the start.
- [MethodImpl(MethodImplOptions.AggressiveInlining)]
- public static Index FromStart(int value)
- {
- if (value < 0)
- {
- throw new ArgumentOutOfRangeException(nameof(value), "Non-negative number required.");
- }
-
- return new Index(value);
- }
-
- /// Create an Index from the end at the position indicated by the value.
- /// The index value from the end.
- [MethodImpl(MethodImplOptions.AggressiveInlining)]
- public static Index FromEnd(int value)
- {
- if (value < 0)
- {
- throw new ArgumentOutOfRangeException(nameof(value), "Non-negative number required.");
- }
-
- return new Index(~value);
- }
-
- /// Returns the index value.
- public int Value
- {
- get
- {
- if (_value < 0)
- return ~_value;
- else
- return _value;
- }
- }
-
- /// Indicates whether the index is from the start or the end.
- public bool IsFromEnd => _value < 0;
-
- /// Calculate the offset from the start using the giving collection length.
- /// The length of the collection that the Index will be used with. length has to be a positive value
- ///
- /// For performance reason, we don't validate the input length parameter and the returned offset value against negative values.
- /// we don't validate either the returned offset is greater than the input length.
- /// It is expected Index will be used with collections which always have non negative length/count. If the returned offset is negative and
- /// then used to index a collection will get out of range exception which will be same affect as the validation.
- ///
- [MethodImpl(MethodImplOptions.AggressiveInlining)]
- public int GetOffset(int length)
- {
- var offset = _value;
- if (IsFromEnd)
- {
- // offset = length - (~value)
- // offset = length + (~(~value) + 1)
- // offset = length + value + 1
-
- offset += length + 1;
- }
- return offset;
- }
-
- /// Indicates whether the current Index object is equal to another object of the same type.
- /// An object to compare with this object
- public override bool Equals(object value) => value is Index && _value == ((Index)value)._value;
-
- /// Indicates whether the current Index object is equal to another Index object.
- /// An object to compare with this object
- public bool Equals(Index other) => _value == other._value;
-
- /// Returns the hash code for this instance.
- public override int GetHashCode() => _value;
-
- /// Converts integer number to an Index.
- public static implicit operator Index(int value) => FromStart(value);
-
- /// Converts the value of the current Index object to its equivalent string representation.
- public override string ToString()
- {
- if (IsFromEnd)
- return ToStringFromEnd();
-
- return ((uint)Value).ToString();
- }
-
- private string ToStringFromEnd()
- {
- return '^' + Value.ToString();
- }
- }
-}
diff --git a/src/Microsoft.ML.TorchSharp/Utils/Range.cs b/src/Microsoft.ML.TorchSharp/Utils/Range.cs
deleted file mode 100644
index b372aed591..0000000000
--- a/src/Microsoft.ML.TorchSharp/Utils/Range.cs
+++ /dev/null
@@ -1,141 +0,0 @@
-// Licensed to the .NET Foundation under one or more agreements.
-// The .NET Foundation licenses this file to you under the MIT license.
-// See the LICENSE file in the project root for more information.
-
-using System.Diagnostics;
-using System.Runtime.CompilerServices;
-using System.Runtime.InteropServices;
-using Microsoft.ML.TorchSharp.Utils;
-using static TorchSharp.torch;
-
-namespace System
-{
- /// Represent a range has start and end indexes.
- ///
- /// Range is used by the C# compiler to support the range syntax.
- ///
- /// int[] someArray = new int[5] { 1, 2, 3, 4, 5 };
- /// int[] subArray1 = someArray[0..2]; // { 1, 2 }
- /// int[] subArray2 = someArray[1..^0]; // { 2, 3, 4, 5 }
- ///
- ///
- internal readonly struct Range : IEquatable<Range>
- {
- /// Represent the inclusive start index of the Range.
- public Index Start { get; }
-
- /// Represent the exclusive end index of the Range.
- public Index End { get; }
-
- /// Construct a Range object using the start and end indexes.
- /// Represent the inclusive start index of the range.
- /// Represent the exclusive end index of the range.
- public Range(Index start, Index end)
- {
- Start = start;
- End = end;
- }
-
- /// Indicates whether the current Range object is equal to another object of the same type.
- /// An object to compare with this object
- public override bool Equals(object value) =>
- value is Range r &&
- r.Start.Equals(Start) &&
- r.End.Equals(End);
-
- /// Indicates whether the current Range object is equal to another Range object.
- /// An object to compare with this object
- public bool Equals(Range other) => other.Start.Equals(Start) && other.End.Equals(End);
-
- /// Returns the hash code for this instance.
- public override int GetHashCode()
- {
-#if (!NETSTANDARD2_0 && !NETFRAMEWORK)
- return HashCode.Combine(Start.GetHashCode(), End.GetHashCode());
-#else
- return HashHelpers.Combine(Start.GetHashCode(), End.GetHashCode());
-#endif
- }
-
- /// Converts the value of the current Range object to its equivalent string representation.
- public override string ToString()
- {
-#if (!NETSTANDARD2_0 && !NETFRAMEWORK)
- Span<char> span = stackalloc char[2 + (2 * 11)]; // 2 for "..", then for each index 1 for '^' and 10 for longest possible uint
- int pos = 0;
-
- if (Start.IsFromEnd)
- {
- span[0] = '^';
- pos = 1;
- }
- bool formatted = ((uint)Start.Value).TryFormat(span.Slice(pos), out int charsWritten);
- Debug.Assert(formatted);
- pos += charsWritten;
-
- span[pos++] = '.';
- span[pos++] = '.';
-
- if (End.IsFromEnd)
- {
- span[pos++] = '^';
- }
- formatted = ((uint)End.Value).TryFormat(span.Slice(pos), out charsWritten);
- Debug.Assert(formatted);
- pos += charsWritten;
-
- return new string(span.Slice(0, pos));
-#else
- return Start.ToString() + ".." + End.ToString();
-#endif
- }
-
- /// Create a Range object starting from start index to the end of the collection.
- public static Range StartAt(Index start) => new Range(start, Index.End);
-
- /// Create a Range object starting from first element in the collection to the end Index.
- public static Range EndAt(Index end) => new Range(Index.Start, end);
-
- /// Create a Range object starting from first element to the end.
- public static Range All => new Range(Index.Start, Index.End);
-
- /// Calculate the start offset and length of range object using a collection length.
- /// The length of the collection that the range will be used with. length has to be a positive value.
- ///
- /// For performance reason, we don't validate the input length parameter against negative values.
- /// It is expected Range will be used with collections which always have non negative length/count.
- /// We validate the range is inside the length scope though.
- ///
- [MethodImpl(MethodImplOptions.AggressiveInlining)]
- public (int Offset, int Length) GetOffsetAndLength(int length)
- {
- int start;
- var startIndex = Start;
- if (startIndex.IsFromEnd)
- start = length - startIndex.Value;
- else
- start = startIndex.Value;
-
- int end;
- var endIndex = End;
- if (endIndex.IsFromEnd)
- end = length - endIndex.Value;
- else
- end = endIndex.Value;
-
- if ((uint)end > (uint)length || (uint)start > (uint)end)
- {
- throw new ArgumentOutOfRangeException(nameof(length));
- }
-
- return (start, end - start);
- }
-
- public static implicit operator TensorIndex(Range range)
- {
- long? start = !range.Start.IsFromEnd ? range.Start.Value : -1 * range.Start.Value;
- var stop = !range.End.IsFromEnd ? new long?(range.End.Value) : range.End.Value == 0 ? null : new long?(-1 * range.End.Value);
- return TensorIndex.Slice(start, stop);
- }
- }
-}
diff --git a/src/Microsoft.ML.TorchSharp/Utils/RangeUtil.cs b/src/Microsoft.ML.TorchSharp/Utils/RangeUtil.cs
new file mode 100644
index 0000000000..50f10eb431
--- /dev/null
+++ b/src/Microsoft.ML.TorchSharp/Utils/RangeUtil.cs
@@ -0,0 +1,19 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System;
+using static TorchSharp.torch;
+
+namespace Microsoft.ML.TorchSharp
+{
+ internal static class RangeUtil
+ {
+ public static TensorIndex ToTensorIndex(this Range range)
+ {
+ long? start = !range.Start.IsFromEnd ? range.Start.Value : -1 * range.Start.Value;
+ var stop = !range.End.IsFromEnd ? new long?(range.End.Value) : range.End.Value == 0 ? null : new long?(-1 * range.End.Value);
+ return TensorIndex.Slice(start, stop);
+ }
+ }
+}
diff --git a/test/Microsoft.ML.AutoML.Tests/Microsoft.ML.AutoML.Tests.csproj b/test/Microsoft.ML.AutoML.Tests/Microsoft.ML.AutoML.Tests.csproj
index 8c65cf0621..149962617d 100644
--- a/test/Microsoft.ML.AutoML.Tests/Microsoft.ML.AutoML.Tests.csproj
+++ b/test/Microsoft.ML.AutoML.Tests/Microsoft.ML.AutoML.Tests.csproj
@@ -5,6 +5,10 @@
None
+
+ true
+
+
diff --git a/test/Microsoft.ML.CodeGenerator.Tests/Microsoft.ML.CodeGenerator.Tests.csproj b/test/Microsoft.ML.CodeGenerator.Tests/Microsoft.ML.CodeGenerator.Tests.csproj
index 4bff917a66..af3f6b1d13 100644
--- a/test/Microsoft.ML.CodeGenerator.Tests/Microsoft.ML.CodeGenerator.Tests.csproj
+++ b/test/Microsoft.ML.CodeGenerator.Tests/Microsoft.ML.CodeGenerator.Tests.csproj
@@ -5,6 +5,10 @@
None
+
+ true
+
+
diff --git a/test/Microsoft.ML.Fairlearn.Tests/Microsoft.ML.Fairlearn.Tests.csproj b/test/Microsoft.ML.Fairlearn.Tests/Microsoft.ML.Fairlearn.Tests.csproj
index 09faf80224..ab5b0aba34 100644
--- a/test/Microsoft.ML.Fairlearn.Tests/Microsoft.ML.Fairlearn.Tests.csproj
+++ b/test/Microsoft.ML.Fairlearn.Tests/Microsoft.ML.Fairlearn.Tests.csproj
@@ -5,6 +5,10 @@
$(NoWarn);MSML_ParameterLocalVarName;MSML_PrivateFieldName;MSML_ExtendBaseTestClass;MSML_GeneralName
+
+ true
+
+
diff --git a/test/Microsoft.ML.GenAI.Core.Tests/Microsoft.ML.GenAI.Core.Tests.csproj b/test/Microsoft.ML.GenAI.Core.Tests/Microsoft.ML.GenAI.Core.Tests.csproj
index 2ded80987a..f07f80089e 100644
--- a/test/Microsoft.ML.GenAI.Core.Tests/Microsoft.ML.GenAI.Core.Tests.csproj
+++ b/test/Microsoft.ML.GenAI.Core.Tests/Microsoft.ML.GenAI.Core.Tests.csproj
@@ -6,6 +6,7 @@
$(NoWarn);MSML_ExtendBaseTestClass
enable
true
+ true
diff --git a/test/Microsoft.ML.GenAI.LLaMA.Tests/Microsoft.ML.GenAI.LLaMA.Tests.csproj b/test/Microsoft.ML.GenAI.LLaMA.Tests/Microsoft.ML.GenAI.LLaMA.Tests.csproj
index d135f09bbb..62d0fed2fd 100644
--- a/test/Microsoft.ML.GenAI.LLaMA.Tests/Microsoft.ML.GenAI.LLaMA.Tests.csproj
+++ b/test/Microsoft.ML.GenAI.LLaMA.Tests/Microsoft.ML.GenAI.LLaMA.Tests.csproj
@@ -8,6 +8,10 @@
true
+
+ true
+
+
diff --git a/test/Microsoft.ML.GenAI.Mistral.Tests/Microsoft.ML.GenAI.Mistral.Tests.csproj b/test/Microsoft.ML.GenAI.Mistral.Tests/Microsoft.ML.GenAI.Mistral.Tests.csproj
index 4715947431..6852856a4e 100644
--- a/test/Microsoft.ML.GenAI.Mistral.Tests/Microsoft.ML.GenAI.Mistral.Tests.csproj
+++ b/test/Microsoft.ML.GenAI.Mistral.Tests/Microsoft.ML.GenAI.Mistral.Tests.csproj
@@ -8,6 +8,10 @@
true
+
+ true
+
+
diff --git a/test/Microsoft.ML.GenAI.Phi.Tests/Microsoft.ML.GenAI.Phi.Tests.csproj b/test/Microsoft.ML.GenAI.Phi.Tests/Microsoft.ML.GenAI.Phi.Tests.csproj
index dec8dbbb25..d86f06c8a0 100644
--- a/test/Microsoft.ML.GenAI.Phi.Tests/Microsoft.ML.GenAI.Phi.Tests.csproj
+++ b/test/Microsoft.ML.GenAI.Phi.Tests/Microsoft.ML.GenAI.Phi.Tests.csproj
@@ -6,6 +6,7 @@
$(NoWarn);MSML_ExtendBaseTestClass
enable
true
+ true
diff --git a/test/Microsoft.ML.Tokenizers.Data.Tests/Microsoft.ML.Tokenizers.Data.Tests.csproj b/test/Microsoft.ML.Tokenizers.Data.Tests/Microsoft.ML.Tokenizers.Data.Tests.csproj
index fe4dce9c2e..0bb5927412 100644
--- a/test/Microsoft.ML.Tokenizers.Data.Tests/Microsoft.ML.Tokenizers.Data.Tests.csproj
+++ b/test/Microsoft.ML.Tokenizers.Data.Tests/Microsoft.ML.Tokenizers.Data.Tests.csproj
@@ -7,6 +7,10 @@
enable
+
+ true
+
+
diff --git a/test/Microsoft.ML.Tokenizers.Tests/BpeTests.cs b/test/Microsoft.ML.Tokenizers.Tests/BpeTests.cs
index 1fbb56128f..6fb5619660 100644
--- a/test/Microsoft.ML.Tokenizers.Tests/BpeTests.cs
+++ b/test/Microsoft.ML.Tokenizers.Tests/BpeTests.cs
@@ -251,7 +251,7 @@ public void SimpleTestWithUnknownToken(
try
{
- BpeTokenizer bpe = BpeTokenizer.Create(vocabFile: vocabFile, mergesFile: mergesFile, preTokenizer: WhiteSpacePreTokenizer.Instance, normalizer: null, unknownToken: unknownToken,
+ BpeTokenizer bpe = BpeTokenizer.Create(vocabFile: vocabFile, mergesFile: mergesFile, preTokenizer: PreTokenizer.CreateWhiteSpace(), normalizer: null, unknownToken: unknownToken,
continuingSubwordPrefix: continuingSubwordPrefix, endOfWordSuffix: endOfWordSuffix, fuseUnknownTokens: fuseUnknownToken);
Tokenizer tokenizer = bpe;
 IReadOnlyList<EncodedToken> encoding = tokenizer.EncodeToTokens(sentence, out _);
@@ -274,7 +274,7 @@ public void SimpleTestWithUnknownToken(
for (int i = 0; i < encoding.Count; i++)
{
Assert.Equal(expectedTokens[i], encoding[i].Value);
- Assert.Equal(offsets[i], encoding[i].Offset);
+ Assert.Equal(offsets[i], (encoding[i].Offset.Start.Value, encoding[i].Offset.End.Value - encoding[i].Offset.Start.Value));
Assert.Equal(ids[i], encoding[i].Id);
Assert.Equal(ids[i], idsList[i]);
Assert.Equal(encoding[i].Value, reverseVocabulary[encodingIds[i]]);
@@ -430,11 +430,11 @@ public void TestBpeTokenizer(string text, string[] expectedTokens, (int Index, i
 IReadOnlyList<EncodedToken> encoding1 = tokenizer.EncodeToTokens(text.AsSpan(), out _);
Assert.Equal(expectedTokens, encoding.Select(t => t.Value).ToArray());
- Assert.Equal(expectedOffsets, encoding.Select(t => t.Offset).ToArray());
+ Assert.Equal(expectedOffsets, encoding.Select(t => (t.Offset.Start.Value, t.Offset.End.Value - t.Offset.Start.Value)).ToArray());
Assert.Equal(expectedIds, encoding.Select(t => t.Id).ToArray());
Assert.Equal(expectedTokens, encoding1.Select(t => t.Value).ToArray());
- Assert.Equal(expectedOffsets, encoding1.Select(t => t.Offset).ToArray());
+ Assert.Equal(expectedOffsets, encoding1.Select(t => (t.Offset.Start.Value, t.Offset.End.Value - t.Offset.Start.Value)).ToArray());
Assert.Equal(expectedIds, encoding1.Select(t => t.Id).ToArray());
Assert.Equal(expectedIds, tokenizer.EncodeToIds(text));
@@ -472,6 +472,62 @@ public void TestBpeTokenizer(string text, string[] expectedTokens, (int Index, i
Assert.Equal(3, tokenCount);
}
+ [Fact]
+ public void TestWithAddedTokens()
+ {
+ // Picked from https://huggingface.co/HuggingFaceTB/SmolLM-135M-Instruct/raw/main/tokenizer.json
+ IReadOnlyDictionary<string, int> addedTokens = new Dictionary<string, int>()
+ {
+ {"<|endoftext|>", 0 },
+ {"<|im_start|>", 1 },
+ {"<|im_end|>", 2 },
+ {"", 3 },
+ {"", 4 },
+ {"", 5 },
+ {"", 6 },
+ {"", 7 },
+ {"", 8 },
+ {"", 9 },
+ {"", 10 },
+ {"", 11 },
+ {"", 12 },
+ {"", 13 },
+ {"", 14 },
+ {"", 15 },
+ {"", 16 },
+ };
+
+ using Stream vocabStream = File.OpenRead(Path.Combine(@"Gpt-2", "vocab.json"));
+ using Stream mergesStream = File.OpenRead(Path.Combine(@"Gpt-2", "merges.txt"));
+
+ var bpeTokenizer = BpeTokenizer.Create(vocabStream, mergesStream, PreTokenizer.CreateWhiteSpace(addedTokens), normalizer: null, addedTokens: addedTokens, unknownToken: "<|endoftext|>");
+
+ string input = "Hello, y'all! How are you 😁 ?<|endoftext|>";
+
+ IReadOnlyList<EncodedToken> tokens = bpeTokenizer.EncodeToTokens(input, out _);
+
+ EncodedToken[] expectedTokens = [
+ new EncodedToken(15496, "Hello", new Range(0, 5)),
+ new EncodedToken(11, ",", new Range(5, 6)),
+ new EncodedToken(88, "y", new Range(7, 8)),
+ new EncodedToken(6, "'", new Range(8, 9)),
+ new EncodedToken(439, "all", new Range(9, 12)),
+ new EncodedToken(0, "!", new Range(12, 13)),
+ new EncodedToken(9, "", new Range(14, 29)),
+ new EncodedToken(2437, "How", new Range(29, 32)),
+ new EncodedToken(533, "are", new Range(33, 36)),
+ new EncodedToken(5832, "you", new Range(37, 40)),
+ new EncodedToken(50256, "<|endoftext|>", new Range(41, 43)),
+ new EncodedToken(30, "?", new Range(44, 45)),
+ new EncodedToken(0, "<|endoftext|>", new Range(45, 58))
+ ];
+
+ Assert.Equal(expectedTokens, tokens);
+
+ IReadOnlyList<int> ids = bpeTokenizer.EncodeToIds(input);
+ Assert.Equal(expectedTokens.Select(t => t.Id).ToArray(), ids);
+ }
+
private static string WriteToMergeFile((string, string)[] mergeEntries)
{
string fileName = Utils.CreateTemporaryFile("txt");
@@ -500,7 +556,7 @@ internal static BpeTokenizer CreateEmptyBpe(PreTokenizer? preTokenizer = null, N
emptyVocabStream.Position = 0;
return BpeTokenizer.Create(
- vocabStream: emptyVocabStream, mergesStream: null, preTokenizer: preTokenizer ?? WhiteSpacePreTokenizer.Instance, normalizer: normalizer, unknownToken: "Ukn");
+ vocabStream: emptyVocabStream, mergesStream: null, preTokenizer: preTokenizer ?? PreTokenizer.CreateWhiteSpace(), normalizer: normalizer, unknownToken: "Ukn");
}
}
}
diff --git a/test/Microsoft.ML.Tokenizers.Tests/CodeGenTests.cs b/test/Microsoft.ML.Tokenizers.Tests/CodeGenTests.cs
index a4273f040c..4965ce064a 100644
--- a/test/Microsoft.ML.Tokenizers.Tests/CodeGenTests.cs
+++ b/test/Microsoft.ML.Tokenizers.Tests/CodeGenTests.cs
@@ -235,13 +235,13 @@ private void ValidateEncoding(IReadOnlyList encoding, bool addPref
{
Assert.Equal(expectedIdsWithSpace, encoding.Select(t => t.Id).ToArray());
Assert.Equal(expectedTokensWithSpace, encoding.Select(t => t.Value).ToArray());
- Assert.Equal(expectedOffsetsWithSpace, encoding.Select(t => t.Offset).ToArray());
+ Assert.Equal(expectedOffsetsWithSpace, encoding.Select(t => (t.Offset.Start.Value, t.Offset.End.Value - t.Offset.Start.Value)).ToArray());
}
else
{
Assert.Equal(expectedIds, encoding.Select(t => t.Id).ToArray());
Assert.Equal(expectedTokens, encoding.Select(t => t.Value).ToArray());
- Assert.Equal(expectedOffsets, encoding.Select(t => t.Offset).ToArray());
+ Assert.Equal(expectedOffsets, encoding.Select(t => (t.Offset.Start.Value, t.Offset.End.Value - t.Offset.Start.Value)).ToArray());
}
}
@@ -555,22 +555,22 @@ public void TestBegginingAndEndOfSentenceEncoding(
tokensList.Insert(0, codeGenTokenizer.BeginningOfSentenceToken!);
Assert.Equal(idList, encoding.Select(t => t.Id).ToArray());
Assert.Equal(tokensList, encoding.Select(t => t.Value).ToArray());
- Assert.Equal((0, 0), encoding[0].Offset);
+ Assert.Equal((0, 0), (encoding[0].Offset.Start.Value, encoding[0].Offset.End.Value));
encoding = codeGenTokenizer.EncodeToTokens(text.AsSpan(), out _);
Assert.Equal(idList, encoding.Select(t => t.Id).ToArray());
Assert.Equal(tokensList, encoding.Select(t => t.Value).ToArray());
- Assert.Equal((0, 0), encoding[0].Offset);
+ Assert.Equal((0, 0), (encoding[0].Offset.Start.Value, encoding[0].Offset.End.Value));
encoding = codeGenTokenizer.EncodeToTokens(text, addPrefixSpace: false, addBeginningOfSentence: true, addEndOfSentence: false, out _);
Assert.Equal(idList, encoding.Select(t => t.Id).ToArray());
Assert.Equal(tokensList, encoding.Select(t => t.Value).ToArray());
- Assert.Equal((0, 0), encoding[0].Offset);
+ Assert.Equal((0, 0), (encoding[0].Offset.Start.Value, encoding[0].Offset.End.Value));
encoding = codeGenTokenizer.EncodeToTokens(text.AsSpan(), addPrefixSpace: false, addBeginningOfSentence: true, addEndOfSentence: false, out _);
Assert.Equal(idList, encoding.Select(t => t.Id).ToArray());
Assert.Equal(tokensList, encoding.Select(t => t.Value).ToArray());
- Assert.Equal((0, 0), encoding[0].Offset);
+ Assert.Equal((0, 0), (encoding[0].Offset.Start.Value, encoding[0].Offset.End.Value));
 idList = new List<int>(expectedIdsWithSpace);
idList.Insert(0, codeGenTokenizer.BeginningOfSentenceId!.Value);
@@ -579,32 +579,32 @@ public void TestBegginingAndEndOfSentenceEncoding(
encoding = codeGenTokenizer.EncodeToTokens(text, addPrefixSpace: true, addBeginningOfSentence: true, addEndOfSentence: false, out _);
Assert.Equal(idList, encoding.Select(t => t.Id).ToArray());
Assert.Equal(tokensList, encoding.Select(t => t.Value).ToArray());
- Assert.Equal((0, 0), encoding[0].Offset);
+ Assert.Equal((0, 0), (encoding[0].Offset.Start.Value, encoding[0].Offset.End.Value));
encoding = codeGenTokenizer.EncodeToTokens(text.AsSpan(), addPrefixSpace: true, addBeginningOfSentence: true, addEndOfSentence: false, out _);
Assert.Equal(idList, encoding.Select(t => t.Id).ToArray());
Assert.Equal(tokensList, encoding.Select(t => t.Value).ToArray());
- Assert.Equal((0, 0), encoding[0].Offset);
+ Assert.Equal((0, 0), (encoding[0].Offset.Start.Value, encoding[0].Offset.End.Value));
encoding = codeGenTokenizer.EncodeToTokens(text, addPrefixSpace: false, addBeginningOfSentence: false, addEndOfSentence: false, out _);
Assert.Equal(expectedIds, encoding.Select(t => t.Id).ToArray());
Assert.Equal(expectedTokens, encoding.Select(t => t.Value).ToArray());
- Assert.True(encoding[0].Offset != (0, 0) || encoding[1].Offset != (0, 0));
+ Assert.True(!encoding[0].Offset.Equals(new Range(0, 0)) || !encoding[1].Offset.Equals(new Range(0, 0)));
encoding = codeGenTokenizer.EncodeToTokens(text.AsSpan(), addPrefixSpace: false, addBeginningOfSentence: false, addEndOfSentence: false, out _);
Assert.Equal(expectedIds, encoding.Select(t => t.Id).ToArray());
Assert.Equal(expectedTokens, encoding.Select(t => t.Value).ToArray());
- Assert.True(encoding[0].Offset != (0, 0) || encoding[1].Offset != (0, 0));
+ Assert.True(!encoding[0].Offset.Equals(new Range(0, 0)) || !encoding[1].Offset.Equals(new Range(0, 0)));
encoding = codeGenTokenizer.EncodeToTokens(text, addPrefixSpace: true, addBeginningOfSentence: false, addEndOfSentence: false, out _);
Assert.Equal(expectedIdsWithSpace, encoding.Select(t => t.Id).ToArray());
Assert.Equal(expectedTokensWithSpace, encoding.Select(t => t.Value).ToArray());
- Assert.True(encoding[0].Offset != (0, 0) || encoding[1].Offset != (0, 0));
+ Assert.True(!encoding[0].Offset.Equals(new Range(0, 0)) || !encoding[1].Offset.Equals(new Range(0, 0)));
encoding = codeGenTokenizer.EncodeToTokens(text.AsSpan(), addPrefixSpace: true, addBeginningOfSentence: false, addEndOfSentence: false, out _);
Assert.Equal(expectedIdsWithSpace, encoding.Select(t => t.Id).ToArray());
Assert.Equal(expectedTokensWithSpace, encoding.Select(t => t.Value).ToArray());
- Assert.True(encoding[0].Offset != (0, 0) || encoding[1].Offset != (0, 0));
+ Assert.True(!encoding[0].Offset.Equals(new Range(0, 0)) || !encoding[1].Offset.Equals(new Range(0, 0)));
 IReadOnlyList<int> ids = codeGenTokenizer.EncodeToIds(text);
Assert.Equal(codeGenTokenizer.BeginningOfSentenceId.Value, ids[0]);
@@ -688,22 +688,22 @@ public void TestBegginingAndEndOfSentenceEncoding(
tokensList.Add(codeGenTokenizer.EndOfSentenceToken!);
Assert.Equal(idList, encoding.Select(t => t.Id).ToArray());
Assert.Equal(tokensList, encoding.Select(t => t.Value).ToArray());
- Assert.Equal((text.Length, 0), encoding[encoding.Count - 1].Offset);
+ Assert.Equal(new Range(text.Length, text.Length), encoding[encoding.Count - 1].Offset);
encoding = codeGenTokenizer.EncodeToTokens(text.AsSpan(), out _);
Assert.Equal(idList, encoding.Select(t => t.Id).ToArray());
Assert.Equal(tokensList, encoding.Select(t => t.Value).ToArray());
- Assert.Equal((text.Length, 0), encoding[encoding.Count - 1].Offset);
+ Assert.Equal(new Range(text.Length, text.Length), encoding[encoding.Count - 1].Offset);
encoding = codeGenTokenizer.EncodeToTokens(text, addPrefixSpace: false, addBeginningOfSentence: false, addEndOfSentence: true, out _);
Assert.Equal(idList, encoding.Select(t => t.Id).ToArray());
Assert.Equal(tokensList, encoding.Select(t => t.Value).ToArray());
- Assert.Equal((text.Length, 0), encoding[encoding.Count - 1].Offset);
+ Assert.Equal(new Range(text.Length, text.Length), encoding[encoding.Count - 1].Offset);
encoding = codeGenTokenizer.EncodeToTokens(text.AsSpan(), addPrefixSpace: false, addBeginningOfSentence: false, addEndOfSentence: true, out _);
Assert.Equal(idList, encoding.Select(t => t.Id).ToArray());
Assert.Equal(tokensList, encoding.Select(t => t.Value).ToArray());
- Assert.Equal((text.Length, 0), encoding[encoding.Count - 1].Offset);
+ Assert.Equal(new Range(text.Length, text.Length), encoding[encoding.Count - 1].Offset);
 idList = new List<int>(expectedIdsWithSpace);
idList.Add(codeGenTokenizer.EndOfSentenceId!.Value);
@@ -712,32 +712,32 @@ public void TestBegginingAndEndOfSentenceEncoding(
encoding = codeGenTokenizer.EncodeToTokens(text, addPrefixSpace: true, addBeginningOfSentence: false, addEndOfSentence: true, out _);
Assert.Equal(idList, encoding.Select(t => t.Id).ToArray());
Assert.Equal(tokensList, encoding.Select(t => t.Value).ToArray());
- Assert.Equal((text.Length, 0), encoding[encoding.Count - 1].Offset);
+ Assert.Equal(new Range(text.Length, text.Length), encoding[encoding.Count - 1].Offset);
encoding = codeGenTokenizer.EncodeToTokens(text.AsSpan(), addPrefixSpace: true, addBeginningOfSentence: false, addEndOfSentence: true, out _);
Assert.Equal(idList, encoding.Select(t => t.Id).ToArray());
Assert.Equal(tokensList, encoding.Select(t => t.Value).ToArray());
- Assert.Equal((text.Length, 0), encoding[encoding.Count - 1].Offset);
+ Assert.Equal(new Range(text.Length, text.Length), encoding[encoding.Count - 1].Offset);
encoding = codeGenTokenizer.EncodeToTokens(text, addPrefixSpace: false, addBeginningOfSentence: false, addEndOfSentence: false, out _);
Assert.Equal(expectedIds, encoding.Select(t => t.Id).ToArray());
Assert.Equal(expectedTokens, encoding.Select(t => t.Value).ToArray());
- Assert.NotEqual((text.Length, 0), encoding[encoding.Count - 1].Offset);
+ Assert.NotEqual(new Range(text.Length, text.Length), encoding[encoding.Count - 1].Offset);
encoding = codeGenTokenizer.EncodeToTokens(text.AsSpan(), addPrefixSpace: false, addBeginningOfSentence: false, addEndOfSentence: false, out _);
Assert.Equal(expectedIds, encoding.Select(t => t.Id).ToArray());
Assert.Equal(expectedTokens, encoding.Select(t => t.Value).ToArray());
- Assert.NotEqual((text.Length, 0), encoding[encoding.Count - 1].Offset);
+ Assert.NotEqual(new Range(text.Length, text.Length), encoding[encoding.Count - 1].Offset);
encoding = codeGenTokenizer.EncodeToTokens(text, addPrefixSpace: true, addBeginningOfSentence: false, addEndOfSentence: false, out _);
Assert.Equal(expectedIdsWithSpace, encoding.Select(t => t.Id).ToArray());
Assert.Equal(expectedTokensWithSpace, encoding.Select(t => t.Value).ToArray());
- Assert.NotEqual((text.Length, 0), encoding[encoding.Count - 1].Offset);
+ Assert.NotEqual(new Range(text.Length, text.Length), encoding[encoding.Count - 1].Offset);
encoding = codeGenTokenizer.EncodeToTokens(text.AsSpan(), addPrefixSpace: true, addBeginningOfSentence: false, addEndOfSentence: false, out _);
Assert.Equal(expectedIdsWithSpace, encoding.Select(t => t.Id).ToArray());
Assert.Equal(expectedTokensWithSpace, encoding.Select(t => t.Value).ToArray());
- Assert.NotEqual((text.Length, 0), encoding[encoding.Count - 1].Offset);
+ Assert.NotEqual(new Range(text.Length, text.Length), encoding[encoding.Count - 1].Offset);
ids = codeGenTokenizer.EncodeToIds(text);
Assert.Equal(codeGenTokenizer.EndOfSentenceId.Value, ids[ids.Count - 1]);
@@ -823,26 +823,26 @@ public void TestBegginingAndEndOfSentenceEncoding(
tokensList.Add(codeGenTokenizer.EndOfSentenceToken!);
Assert.Equal(idList, encoding.Select(t => t.Id).ToArray());
Assert.Equal(tokensList, encoding.Select(t => t.Value).ToArray());
- Assert.Equal((0, 0), encoding[0].Offset);
- Assert.Equal((text.Length, 0), encoding[encoding.Count - 1].Offset);
+ Assert.Equal(new Range(0, 0), encoding[0].Offset);
+ Assert.Equal(new Range(text.Length, text.Length), encoding[encoding.Count - 1].Offset);
encoding = codeGenTokenizer.EncodeToTokens(text.AsSpan(), out _);
Assert.Equal(idList, encoding.Select(t => t.Id).ToArray());
Assert.Equal(tokensList, encoding.Select(t => t.Value).ToArray());
- Assert.Equal((0, 0), encoding[0].Offset);
- Assert.Equal((text.Length, 0), encoding[encoding.Count - 1].Offset);
+ Assert.Equal(new Range(0, 0), encoding[0].Offset);
+ Assert.Equal(new Range(text.Length, text.Length), encoding[encoding.Count - 1].Offset);
encoding = codeGenTokenizer.EncodeToTokens(text, addPrefixSpace: false, addBeginningOfSentence: true, addEndOfSentence: true, out _);
Assert.Equal(idList, encoding.Select(t => t.Id).ToArray());
Assert.Equal(tokensList, encoding.Select(t => t.Value).ToArray());
- Assert.Equal((0, 0), encoding[0].Offset);
- Assert.Equal((text.Length, 0), encoding[encoding.Count - 1].Offset);
+ Assert.Equal(new Range(0, 0), encoding[0].Offset);
+ Assert.Equal(new Range(text.Length, text.Length), encoding[encoding.Count - 1].Offset);
encoding = codeGenTokenizer.EncodeToTokens(text.AsSpan(), addPrefixSpace: false, addBeginningOfSentence: true, addEndOfSentence: true, out _);
Assert.Equal(idList, encoding.Select(t => t.Id).ToArray());
Assert.Equal(tokensList, encoding.Select(t => t.Value).ToArray());
- Assert.Equal((0, 0), encoding[0].Offset);
- Assert.Equal((text.Length, 0), encoding[encoding.Count - 1].Offset);
+ Assert.Equal(new Range(0, 0), encoding[0].Offset);
+ Assert.Equal(new Range(text.Length, text.Length), encoding[encoding.Count - 1].Offset);
idList = new List(expectedIdsWithSpace);
idList.Insert(0, codeGenTokenizer.BeginningOfSentenceId!.Value);
@@ -853,38 +853,38 @@ public void TestBegginingAndEndOfSentenceEncoding(
encoding = codeGenTokenizer.EncodeToTokens(text, addPrefixSpace: true, addBeginningOfSentence: true, addEndOfSentence: true, out _);
Assert.Equal(idList, encoding.Select(t => t.Id).ToArray());
Assert.Equal(tokensList, encoding.Select(t => t.Value).ToArray());
- Assert.Equal((0, 0), encoding[0].Offset);
- Assert.Equal((text.Length, 0), encoding[encoding.Count - 1].Offset);
+ Assert.Equal(new Range(0, 0), encoding[0].Offset);
+ Assert.Equal(new Range(text.Length, text.Length), encoding[encoding.Count - 1].Offset);
encoding = codeGenTokenizer.EncodeToTokens(text.AsSpan(), addPrefixSpace: true, addBeginningOfSentence: true, addEndOfSentence: true, out _);
Assert.Equal(idList, encoding.Select(t => t.Id).ToArray());
Assert.Equal(tokensList, encoding.Select(t => t.Value).ToArray());
- Assert.Equal((0, 0), encoding[0].Offset);
- Assert.Equal((text.Length, 0), encoding[encoding.Count - 1].Offset);
+ Assert.Equal(new Range(0, 0), encoding[0].Offset);
+ Assert.Equal(new Range(text.Length, text.Length), encoding[encoding.Count - 1].Offset);
encoding = codeGenTokenizer.EncodeToTokens(text, addPrefixSpace: false, addBeginningOfSentence: false, addEndOfSentence: false, out _);
Assert.Equal(expectedIds, encoding.Select(t => t.Id).ToArray());
Assert.Equal(expectedTokens, encoding.Select(t => t.Value).ToArray());
- Assert.True(encoding[0].Offset != (0, 0) || encoding[1].Offset != (0, 0));
- Assert.NotEqual((text.Length, 0), encoding[encoding.Count - 1].Offset);
+ Assert.True(!encoding[0].Offset.Equals(new Range(0, 0)) || !encoding[1].Offset.Equals(new Range(0, 0)));
+ Assert.NotEqual(new Range(text.Length, text.Length), encoding[encoding.Count - 1].Offset);
encoding = codeGenTokenizer.EncodeToTokens(text.AsSpan(), addPrefixSpace: false, addBeginningOfSentence: false, addEndOfSentence: false, out _);
Assert.Equal(expectedIds, encoding.Select(t => t.Id).ToArray());
Assert.Equal(expectedTokens, encoding.Select(t => t.Value).ToArray());
- Assert.True(encoding[0].Offset != (0, 0) || encoding[1].Offset != (0, 0));
- Assert.NotEqual((text.Length, 0), encoding[encoding.Count - 1].Offset);
+ Assert.True(!encoding[0].Offset.Equals(new Range(0, 0)) || !encoding[1].Offset.Equals(new Range(0, 0)));
+ Assert.NotEqual(new Range(text.Length, text.Length), encoding[encoding.Count - 1].Offset);
encoding = codeGenTokenizer.EncodeToTokens(text, addPrefixSpace: true, addBeginningOfSentence: false, addEndOfSentence: false, out _);
Assert.Equal(expectedIdsWithSpace, encoding.Select(t => t.Id).ToArray());
Assert.Equal(expectedTokensWithSpace, encoding.Select(t => t.Value).ToArray());
- Assert.True(encoding[0].Offset != (0, 0) || encoding[1].Offset != (0, 0));
- Assert.NotEqual((text.Length, 0), encoding[encoding.Count - 1].Offset);
+ Assert.True(!encoding[0].Offset.Equals(new Range(0, 0)) || !encoding[1].Offset.Equals(new Range(0, 0)));
+ Assert.NotEqual(new Range(text.Length, text.Length), encoding[encoding.Count - 1].Offset);
encoding = codeGenTokenizer.EncodeToTokens(text.AsSpan(), addPrefixSpace: true, addBeginningOfSentence: false, addEndOfSentence: false, out _);
Assert.Equal(expectedIdsWithSpace, encoding.Select(t => t.Id).ToArray());
Assert.Equal(expectedTokensWithSpace, encoding.Select(t => t.Value).ToArray());
- Assert.True(encoding[0].Offset != (0, 0) || encoding[1].Offset != (0, 0));
- Assert.NotEqual((text.Length, 0), encoding[encoding.Count - 1].Offset);
+ Assert.True(!encoding[0].Offset.Equals(new Range(0, 0)) || !encoding[1].Offset.Equals(new Range(0, 0)));
+ Assert.NotEqual(new Range(text.Length, text.Length), encoding[encoding.Count - 1].Offset);
ids = codeGenTokenizer.EncodeToIds(text);
Assert.Equal(codeGenTokenizer.BeginningOfSentenceId.Value, ids[0]);
diff --git a/test/Microsoft.ML.Tokenizers.Tests/EnglishRobertaTests.cs b/test/Microsoft.ML.Tokenizers.Tests/EnglishRobertaTests.cs
index 9014d208e1..56dec4f144 100644
--- a/test/Microsoft.ML.Tokenizers.Tests/EnglishRobertaTests.cs
+++ b/test/Microsoft.ML.Tokenizers.Tests/EnglishRobertaTests.cs
@@ -182,11 +182,11 @@ public void TestTokenizerEncoding(string text, string[] expectedTokens, (int Ind
IReadOnlyList encoding1 = tokenizer.EncodeToTokens(text.AsSpan(), out _);
Assert.Equal(expectedTokens, encoding.Select(t => t.Value).ToArray());
- Assert.Equal(expectedOffsets, encoding.Select(t => t.Offset).ToArray());
+ Assert.Equal(expectedOffsets, encoding.Select(t => (t.Offset.Start.Value, t.Offset.End.Value - t.Offset.Start.Value)).ToArray());
Assert.Equal(expectedIds, encoding.Select(t => t.Id).ToArray());
Assert.Equal(expectedTokens, encoding1.Select(t => t.Value).ToArray());
- Assert.Equal(expectedOffsets, encoding1.Select(t => t.Offset).ToArray());
+ Assert.Equal(expectedOffsets, encoding1.Select(t => (t.Offset.Start.Value, t.Offset.End.Value - t.Offset.Start.Value)).ToArray());
Assert.Equal(expectedIds, encoding1.Select(t => t.Id).ToArray());
Assert.Equal(expectedIds, tokenizer.EncodeToIds(text));
@@ -264,7 +264,7 @@ private void TestTokenizer(Tokenizer tokenizer, CallingOrder callingOrder = Call
}
int[] encodingIds = encoding.Select(t => t.Id).ToArray();
- (int, int)[] offsets = encoding.Select(t => t.Offset).ToArray();
+ (int, int)[] offsets = encoding.Select(t => (t.Offset.Start.Value, t.Offset.End.Value - t.Offset.Start.Value)).ToArray();
string[] tokens = encoding.Select(t => t.Value).ToArray();
Assert.Equal(p[1], encodingIds);
diff --git a/test/Microsoft.ML.Tokenizers.Tests/LlamaTests.cs b/test/Microsoft.ML.Tokenizers.Tests/LlamaTests.cs
index 6d7178ac2d..7bd41bda45 100644
--- a/test/Microsoft.ML.Tokenizers.Tests/LlamaTests.cs
+++ b/test/Microsoft.ML.Tokenizers.Tests/LlamaTests.cs
@@ -66,7 +66,7 @@ private static Tokenizer CreateLPhi3Tokenizer(bool treatWhitespaceAsSuffix = fal
if (treatWhitespaceAsSuffix)
{
- PropertyInfo? propertyInfo = typeof(SentencePieceBpeTokenizer).GetProperty("TreatWhitespaceAsSuffix", BindingFlags.Instance | BindingFlags.NonPublic | BindingFlags.Public);
+ PropertyInfo? propertyInfo = typeof(SentencePieceTokenizer).GetProperty("TreatWhitespaceAsSuffix", BindingFlags.Instance | BindingFlags.NonPublic | BindingFlags.Public);
if (propertyInfo != null)
{
propertyInfo.SetValue(tokenizer, true);
@@ -244,7 +244,7 @@ public void TestLlamaTokenizer(Tokenizer tokenizer, string input, int[] ids, str
IReadOnlyList result = llamaTokenizer.EncodeToTokens(input, out _);
Assert.Equal(ids, result.Select(t => t.Id).ToArray());
Assert.Equal(tokens, result.Select(t => t.Value).ToArray());
- Assert.Equal(offsets, result.Select(t => t.Offset).ToArray());
+ Assert.Equal(offsets, result.Select(t => (t.Offset.Start.Value, t.Offset.End.Value - t.Offset.Start.Value)).ToArray());
Assert.Equal(input, llamaTokenizer.Decode(ids));
TestDecodingWithSpan(bpe, ids, input);
Assert.Equal(ids, llamaTokenizer.EncodeToIds(input));
@@ -501,14 +501,14 @@ public void TestTokenizerEncoding(string text, string normalizedText, string[] e
IReadOnlyList encoding1 = tokenizer.EncodeToTokens(text.AsSpan(), out _);
Assert.Equal(expectedTokens, encoding.Select(t => t.Value).ToArray());
- Assert.Equal(expectedOffsets, encoding.Select(t => t.Offset).ToArray());
+ Assert.Equal(expectedOffsets, encoding.Select(t => (t.Offset.Start.Value, t.Offset.End.Value - t.Offset.Start.Value)).ToArray());
Assert.Equal(expectedIds, encoding.Select(t => t.Id).ToArray());
Assert.Equal(expectedTokens, encoding1.Select(t => t.Value).ToArray());
- Assert.Equal(expectedOffsets, encoding1.Select(t => t.Offset).ToArray());
+ Assert.Equal(expectedOffsets, encoding1.Select(t => (t.Offset.Start.Value, t.Offset.End.Value - t.Offset.Start.Value)).ToArray());
Assert.Equal(expectedIds, encoding1.Select(t => t.Id).ToArray());
- SentencePieceBpeTokenizer sentencePieceBpe = (tokenizer as SentencePieceBpeTokenizer)!;
+ SentencePieceTokenizer sentencePieceBpe = (tokenizer as SentencePieceTokenizer)!;
foreach (bool considerNormalization in new[] { true, false })
foreach (bool addBeginningOfSentence in new[] { true, false })
foreach (bool addEndOfSentence in new[] { true, false })
@@ -539,7 +539,7 @@ public void TestTokenizerEncoding(string text, string normalizedText, string[] e
expectedIds1 = addEndOfSentence ? expectedIds1.Concat(new[] { sentencePieceBpe.EndOfSentenceId }).ToArray() : expectedIds1;
Assert.Equal(expectedTokens1, encoding.Select(t => t.Value).ToArray());
- Assert.Equal(expectedOffsets1, encoding.Select(t => t.Offset).ToArray());
+ Assert.Equal(expectedOffsets1, encoding.Select(t => (t.Offset.Start.Value, t.Offset.End.Value - t.Offset.Start.Value)).ToArray());
Assert.Equal(expectedIds1, encoding.Select(t => t.Id).ToArray());
}
}
@@ -562,7 +562,7 @@ public void TestTokenizerEncodingToIds(string text, string normalizedText, strin
Assert.Equal(normalizedText, normalizedString);
Assert.Equal(normalizedText.Length, length);
- SentencePieceBpeTokenizer sentencePieceBpe = (tokenizer as SentencePieceBpeTokenizer)!;
+ SentencePieceTokenizer sentencePieceBpe = (tokenizer as SentencePieceTokenizer)!;
foreach (bool considerNormalization in new[] { true, false })
foreach (bool addBeginningOfSentence in new[] { true, false })
foreach (bool addEndOfSentence in new[] { true, false })
diff --git a/test/Microsoft.ML.Tokenizers.Tests/Microsoft.ML.Tokenizers.Tests.csproj b/test/Microsoft.ML.Tokenizers.Tests/Microsoft.ML.Tokenizers.Tests.csproj
index b4a386bc40..e0d08c93aa 100644
--- a/test/Microsoft.ML.Tokenizers.Tests/Microsoft.ML.Tokenizers.Tests.csproj
+++ b/test/Microsoft.ML.Tokenizers.Tests/Microsoft.ML.Tokenizers.Tests.csproj
@@ -10,6 +10,10 @@
+
+ true
+
+
diff --git a/test/Microsoft.ML.Tokenizers.Tests/PreTokenizerTests.cs b/test/Microsoft.ML.Tokenizers.Tests/PreTokenizerTests.cs
index f048a6a209..3d77179dfd 100644
--- a/test/Microsoft.ML.Tokenizers.Tests/PreTokenizerTests.cs
+++ b/test/Microsoft.ML.Tokenizers.Tests/PreTokenizerTests.cs
@@ -18,14 +18,14 @@ public static IEnumerable