From 3938fbe36c0f3e0d748f330efe7d80a74918b0b2 Mon Sep 17 00:00:00 2001 From: yethee Date: Tue, 30 Apr 2024 16:43:48 +0300 Subject: [PATCH] Improve performance to encode BPE --- src/Encoder.php | 59 +++++++++++++++++---------------------- src/Vocab/Vocab.php | 25 +++++++++-------- tests/Vocab/VocabTest.php | 7 +++-- 3 files changed, 44 insertions(+), 47 deletions(-) diff --git a/src/Encoder.php b/src/Encoder.php index 839bfae..a7d119a 100644 --- a/src/Encoder.php +++ b/src/Encoder.php @@ -11,15 +11,14 @@ use function array_map; use function array_merge; -use function array_slice; use function array_values; -use function assert; use function count; use function implode; use function preg_last_error_msg; use function preg_match_all; -use function range; use function sprintf; +use function strlen; +use function substr; use const PHP_INT_MAX; @@ -57,8 +56,7 @@ public function encode(string $text): array continue; } - $piece = EncodeUtil::toBytes($match); - $rank = $this->vocab->tryGetRank($piece); + $rank = $this->vocab->tryGetRank($match); if ($rank !== null) { $tokens[] = $rank; @@ -66,7 +64,7 @@ public function encode(string $text): array continue; } - foreach ($this->mergeBytePairs($piece) as $rank) { + foreach ($this->mergeBytePairs($match) as $rank) { $tokens[] = $rank; } } @@ -101,15 +99,15 @@ public function encodeInChunks(string $text, int $maxTokensPerChunk): array continue; } - $tokenBytes = EncodeUtil::toBytes($match); - $mergedBytePairs = $this->mergeBytePairs($tokenBytes); + $rank = $this->vocab->tryGetRank($match); + $tokens = $rank !== null ? [$rank] : $this->mergeBytePairs($match); - if (count($tokensInCurrentChunk) + count($mergedBytePairs) > $maxTokensPerChunk) { + if (count($tokensInCurrentChunk) + count($tokens) > $maxTokensPerChunk) { $chunks[] = $tokensInCurrentChunk; $tokensInCurrentChunk = []; } - $tokensInCurrentChunk = array_merge($tokensInCurrentChunk, $mergedBytePairs); + $tokensInCurrentChunk = array_merge($tokensInCurrentChunk, $tokens); } if (count($tokensInCurrentChunk) > 0) { @@ -130,38 +128,33 @@ public function decode(array $tokens): string } /** - * @psalm-param NonEmptyByteVector $bytes + * @param non-empty-string $piece * * @return list */ - private function mergeBytePairs(array $bytes): array + private function mergeBytePairs(string $piece): array { - /** @var list $parts */ - $parts = array_map( - function (int $i) use ($bytes): array { - if ($i + 1 < count($bytes)) { - $piece = array_slice($bytes, $i, 2); - assert(count($piece) === 2); - - return [$i, $this->vocab->tryGetRank($piece) ?? PHP_INT_MAX]; - } + $parts = []; + + for ($i = 0; $i <= strlen($piece); $i++) { + $parts[] = [$i, PHP_INT_MAX]; + } - return [$i, PHP_INT_MAX]; - }, - range(0, count($bytes)), - ); - $getRank = function (array $parts, int $startIndex) use ($bytes): int { - if ($startIndex + 2 >= count($parts)) { + $getRank = function (array $parts, int $startIndex, int $skip = 0) use (&$piece): int { + if (($startIndex + $skip + 2) >= count($parts)) { return PHP_INT_MAX; } $offset = $parts[$startIndex][0]; - $piece = array_slice($bytes, $offset, $parts[$startIndex + 2][0] - $offset); - assert(count($piece) > 0); + $length = $parts[$startIndex + $skip + 2][0] - $offset; - return $this->vocab->tryGetRank($piece) ?? PHP_INT_MAX; + return $this->vocab->tryGetRank(substr($piece, $offset, $length)) ?? PHP_INT_MAX; }; + for ($i = 0; $i < count($parts) - 2; $i++) { + $parts[$i][1] = $getRank($parts, $i); + } + while (count($parts) > 1) { $minRank = PHP_INT_MAX; $partIndex = 0; @@ -196,10 +189,10 @@ function (int $i) use ($bytes): array { $res = []; for ($i = 0; $i < $stop; $i++) { - $piece = array_slice($bytes, $parts[$i][0], $parts[$i + 1][0] - $parts[$i][0]); - assert(count($piece) > 0); + $offset = $parts[$i][0]; + $length = $parts[$i + 1][0] - $offset; - $res[] = $this->vocab->getRank($piece); + $res[] = $this->vocab->getRank(substr($piece, $offset, $length)); } return $res; diff --git a/src/Vocab/Vocab.php b/src/Vocab/Vocab.php index fb25f73..9dfb92b 100644 --- a/src/Vocab/Vocab.php +++ b/src/Vocab/Vocab.php @@ -104,22 +104,25 @@ public static function fromStream($stream): self return new self($map); } - /** @psalm-param NonEmptyByteVector $bytes */ - public function tryGetRank(array $bytes): int|null + public function tryGetRank(string $binary): int|null { - return $this->tokenToRankMap[EncodeUtil::fromBytes($bytes)] ?? null; + if ($binary === '') { + throw new InvalidArgumentException('Argument $binary cannot be an empty string'); + } + + return $this->tokenToRankMap[$binary] ?? null; } - /** - * @psalm-param NonEmptyByteVector $bytes - * - * @throws OutOfBoundsException - */ - public function getRank(array $bytes): int + /** @throws OutOfBoundsException */ + public function getRank(string $binary): int { - return $this->tokenToRankMap[EncodeUtil::fromBytes($bytes)] ?? throw new OutOfBoundsException(sprintf( + if ($binary === '') { + throw new InvalidArgumentException('Argument $binary cannot be an empty string'); + } + + return $this->tokenToRankMap[$binary] ?? throw new OutOfBoundsException(sprintf( 'No rank for bytes vector: [%s]', - implode(', ', $bytes), + implode(', ', EncodeUtil::toBytes($binary)), )); } diff --git a/tests/Vocab/VocabTest.php b/tests/Vocab/VocabTest.php index 77431c7..8f8d482 100644 --- a/tests/Vocab/VocabTest.php +++ b/tests/Vocab/VocabTest.php @@ -5,9 +5,10 @@ namespace Yethee\Tiktoken\Tests\Vocab; use PHPUnit\Framework\TestCase; -use Yethee\Tiktoken\Util\EncodeUtil; use Yethee\Tiktoken\Vocab\Vocab; +use function chr; + final class VocabTest extends TestCase { public function testLoadFromFile(): void @@ -15,9 +16,9 @@ public function testLoadFromFile(): void $vocab = Vocab::fromFile(__DIR__ . '/Fixtures/test.tiktoken'); self::assertCount(47, $vocab); - self::assertSame(285, $vocab->getRank(EncodeUtil::toBytes('is'))); + self::assertSame(285, $vocab->getRank('is')); self::assertSame('is', $vocab->getToken(285)); - self::assertSame(18, $vocab->getRank([51])); + self::assertSame(18, $vocab->getRank(chr(51))); self::assertSame('3', $vocab->getToken(18)); } }