Skip to content

Commit

Permalink
Improve performance to encode BPE
Browse files Browse the repository at this point in the history
  • Loading branch information
yethee committed Apr 30, 2024
1 parent b9679f8 commit 3938fbe
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 47 deletions.
59 changes: 26 additions & 33 deletions src/Encoder.php
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,14 @@

use function array_map;
use function array_merge;
use function array_slice;
use function array_values;
use function assert;
use function count;
use function implode;
use function preg_last_error_msg;
use function preg_match_all;
use function range;
use function sprintf;
use function strlen;
use function substr;

use const PHP_INT_MAX;

Expand Down Expand Up @@ -57,16 +56,15 @@ public function encode(string $text): array
continue;
}

$piece = EncodeUtil::toBytes($match);
$rank = $this->vocab->tryGetRank($piece);
$rank = $this->vocab->tryGetRank($match);

if ($rank !== null) {
$tokens[] = $rank;

continue;
}

foreach ($this->mergeBytePairs($piece) as $rank) {
foreach ($this->mergeBytePairs($match) as $rank) {
$tokens[] = $rank;
}
}
Expand Down Expand Up @@ -101,15 +99,15 @@ public function encodeInChunks(string $text, int $maxTokensPerChunk): array
continue;
}

$tokenBytes = EncodeUtil::toBytes($match);
$mergedBytePairs = $this->mergeBytePairs($tokenBytes);
$rank = $this->vocab->tryGetRank($match);
$tokens = $rank !== null ? [$rank] : $this->mergeBytePairs($match);

if (count($tokensInCurrentChunk) + count($mergedBytePairs) > $maxTokensPerChunk) {
if (count($tokensInCurrentChunk) + count($tokens) > $maxTokensPerChunk) {
$chunks[] = $tokensInCurrentChunk;
$tokensInCurrentChunk = [];
}

$tokensInCurrentChunk = array_merge($tokensInCurrentChunk, $mergedBytePairs);
$tokensInCurrentChunk = array_merge($tokensInCurrentChunk, $tokens);
}

if (count($tokensInCurrentChunk) > 0) {
Expand All @@ -130,38 +128,33 @@ public function decode(array $tokens): string
}

/**
* @psalm-param NonEmptyByteVector $bytes
* @param non-empty-string $piece
*
* @return list<int>
*/
private function mergeBytePairs(array $bytes): array
private function mergeBytePairs(string $piece): array
{
/** @var list<array{int, int}> $parts */
$parts = array_map(
function (int $i) use ($bytes): array {
if ($i + 1 < count($bytes)) {
$piece = array_slice($bytes, $i, 2);
assert(count($piece) === 2);

return [$i, $this->vocab->tryGetRank($piece) ?? PHP_INT_MAX];
}
$parts = [];

for ($i = 0; $i <= strlen($piece); $i++) {
$parts[] = [$i, PHP_INT_MAX];
}

return [$i, PHP_INT_MAX];
},
range(0, count($bytes)),
);
$getRank = function (array $parts, int $startIndex) use ($bytes): int {
if ($startIndex + 2 >= count($parts)) {
$getRank = function (array $parts, int $startIndex, int $skip = 0) use (&$piece): int {
if (($startIndex + $skip + 2) >= count($parts)) {
return PHP_INT_MAX;
}

$offset = $parts[$startIndex][0];
$piece = array_slice($bytes, $offset, $parts[$startIndex + 2][0] - $offset);
assert(count($piece) > 0);
$length = $parts[$startIndex + $skip + 2][0] - $offset;

return $this->vocab->tryGetRank($piece) ?? PHP_INT_MAX;
return $this->vocab->tryGetRank(substr($piece, $offset, $length)) ?? PHP_INT_MAX;
};

for ($i = 0; $i < count($parts) - 2; $i++) {
$parts[$i][1] = $getRank($parts, $i);
}

while (count($parts) > 1) {
$minRank = PHP_INT_MAX;
$partIndex = 0;
Expand Down Expand Up @@ -196,10 +189,10 @@ function (int $i) use ($bytes): array {
$res = [];

for ($i = 0; $i < $stop; $i++) {
$piece = array_slice($bytes, $parts[$i][0], $parts[$i + 1][0] - $parts[$i][0]);
assert(count($piece) > 0);
$offset = $parts[$i][0];
$length = $parts[$i + 1][0] - $offset;

$res[] = $this->vocab->getRank($piece);
$res[] = $this->vocab->getRank(substr($piece, $offset, $length));
}

return $res;
Expand Down
25 changes: 14 additions & 11 deletions src/Vocab/Vocab.php
Original file line number Diff line number Diff line change
Expand Up @@ -104,22 +104,25 @@ public static function fromStream($stream): self
return new self($map);
}

/** @psalm-param NonEmptyByteVector $bytes */
public function tryGetRank(array $bytes): int|null
public function tryGetRank(string $binary): int|null
{
return $this->tokenToRankMap[EncodeUtil::fromBytes($bytes)] ?? null;
if ($binary === '') {
throw new InvalidArgumentException('Argument $binary cannot be an empty string');
}

return $this->tokenToRankMap[$binary] ?? null;
}

/**
* @psalm-param NonEmptyByteVector $bytes
*
* @throws OutOfBoundsException
*/
public function getRank(array $bytes): int
/** @throws OutOfBoundsException */
public function getRank(string $binary): int
{
return $this->tokenToRankMap[EncodeUtil::fromBytes($bytes)] ?? throw new OutOfBoundsException(sprintf(
if ($binary === '') {
throw new InvalidArgumentException('Argument $binary cannot be an empty string');
}

return $this->tokenToRankMap[$binary] ?? throw new OutOfBoundsException(sprintf(
'No rank for bytes vector: [%s]',
implode(', ', $bytes),
implode(', ', EncodeUtil::toBytes($binary)),
));
}

Expand Down
7 changes: 4 additions & 3 deletions tests/Vocab/VocabTest.php
Original file line number Diff line number Diff line change
Expand Up @@ -5,19 +5,20 @@
namespace Yethee\Tiktoken\Tests\Vocab;

use PHPUnit\Framework\TestCase;
use Yethee\Tiktoken\Util\EncodeUtil;
use Yethee\Tiktoken\Vocab\Vocab;

use function chr;

final class VocabTest extends TestCase
{
public function testLoadFromFile(): void
{
$vocab = Vocab::fromFile(__DIR__ . '/Fixtures/test.tiktoken');

self::assertCount(47, $vocab);
self::assertSame(285, $vocab->getRank(EncodeUtil::toBytes('is')));
self::assertSame(285, $vocab->getRank('is'));
self::assertSame('is', $vocab->getToken(285));
self::assertSame(18, $vocab->getRank([51]));
self::assertSame(18, $vocab->getRank(chr(51)));
self::assertSame('3', $vocab->getToken(18));
}
}

0 comments on commit 3938fbe

Please sign in to comment.