diff --git a/psalm-baseline.xml b/psalm-baseline.xml
index d29adab..81b512e 100644
--- a/psalm-baseline.xml
+++ b/psalm-baseline.xml
@@ -3,11 +3,9 @@
-
-
diff --git a/src/Encoder.php b/src/Encoder.php
index 839bfae..a7d119a 100644
--- a/src/Encoder.php
+++ b/src/Encoder.php
@@ -11,15 +11,14 @@
use function array_map;
use function array_merge;
-use function array_slice;
use function array_values;
-use function assert;
use function count;
use function implode;
use function preg_last_error_msg;
use function preg_match_all;
-use function range;
use function sprintf;
+use function strlen;
+use function substr;
use const PHP_INT_MAX;
@@ -57,8 +56,7 @@ public function encode(string $text): array
continue;
}
- $piece = EncodeUtil::toBytes($match);
- $rank = $this->vocab->tryGetRank($piece);
+ $rank = $this->vocab->tryGetRank($match);
if ($rank !== null) {
$tokens[] = $rank;
@@ -66,7 +64,7 @@ public function encode(string $text): array
continue;
}
- foreach ($this->mergeBytePairs($piece) as $rank) {
+ foreach ($this->mergeBytePairs($match) as $rank) {
$tokens[] = $rank;
}
}
@@ -101,15 +99,15 @@ public function encodeInChunks(string $text, int $maxTokensPerChunk): array
continue;
}
- $tokenBytes = EncodeUtil::toBytes($match);
- $mergedBytePairs = $this->mergeBytePairs($tokenBytes);
+ $rank = $this->vocab->tryGetRank($match);
+ $tokens = $rank !== null ? [$rank] : $this->mergeBytePairs($match);
- if (count($tokensInCurrentChunk) + count($mergedBytePairs) > $maxTokensPerChunk) {
+ if (count($tokensInCurrentChunk) + count($tokens) > $maxTokensPerChunk) {
$chunks[] = $tokensInCurrentChunk;
$tokensInCurrentChunk = [];
}
- $tokensInCurrentChunk = array_merge($tokensInCurrentChunk, $mergedBytePairs);
+ $tokensInCurrentChunk = array_merge($tokensInCurrentChunk, $tokens);
}
if (count($tokensInCurrentChunk) > 0) {
@@ -130,38 +128,33 @@ public function decode(array $tokens): string
}
/**
- * @psalm-param NonEmptyByteVector $bytes
+ * @param non-empty-string $piece
*
* @return list<int>
*/
- private function mergeBytePairs(array $bytes): array
+ private function mergeBytePairs(string $piece): array
{
- /** @var list<array{int, int}> $parts */
- $parts = array_map(
- function (int $i) use ($bytes): array {
- if ($i + 1 < count($bytes)) {
- $piece = array_slice($bytes, $i, 2);
- assert(count($piece) === 2);
-
- return [$i, $this->vocab->tryGetRank($piece) ?? PHP_INT_MAX];
- }
+ $parts = [];
+
+ for ($i = 0; $i <= strlen($piece); $i++) {
+ $parts[] = [$i, PHP_INT_MAX];
+ }
- return [$i, PHP_INT_MAX];
- },
- range(0, count($bytes)),
- );
- $getRank = function (array $parts, int $startIndex) use ($bytes): int {
- if ($startIndex + 2 >= count($parts)) {
+ $getRank = function (array $parts, int $startIndex, int $skip = 0) use (&$piece): int {
+ if (($startIndex + $skip + 2) >= count($parts)) {
return PHP_INT_MAX;
}
$offset = $parts[$startIndex][0];
- $piece = array_slice($bytes, $offset, $parts[$startIndex + 2][0] - $offset);
- assert(count($piece) > 0);
+ $length = $parts[$startIndex + $skip + 2][0] - $offset;
- return $this->vocab->tryGetRank($piece) ?? PHP_INT_MAX;
+ return $this->vocab->tryGetRank(substr($piece, $offset, $length)) ?? PHP_INT_MAX;
};
+ for ($i = 0; $i < count($parts) - 2; $i++) {
+ $parts[$i][1] = $getRank($parts, $i);
+ }
+
while (count($parts) > 1) {
$minRank = PHP_INT_MAX;
$partIndex = 0;
@@ -196,10 +189,10 @@ function (int $i) use ($bytes): array {
$res = [];
for ($i = 0; $i < $stop; $i++) {
- $piece = array_slice($bytes, $parts[$i][0], $parts[$i + 1][0] - $parts[$i][0]);
- assert(count($piece) > 0);
+ $offset = $parts[$i][0];
+ $length = $parts[$i + 1][0] - $offset;
- $res[] = $this->vocab->getRank($piece);
+ $res[] = $this->vocab->getRank(substr($piece, $offset, $length));
}
return $res;
diff --git a/src/Util/EncodeUtil.php b/src/Util/EncodeUtil.php
index 5c79ebd..94bbc7f 100644
--- a/src/Util/EncodeUtil.php
+++ b/src/Util/EncodeUtil.php
@@ -7,7 +7,6 @@
use function array_map;
use function bin2hex;
use function hexdec;
-use function pack;
use function str_split;
/** @psalm-type NonEmptyByteVector = non-empty-list<int<0, 255>> */
@@ -22,14 +21,4 @@ public static function toBytes(string $text): array
{
return array_map(hexdec(...), str_split(bin2hex($text), 2));
}
-
- /**
- * @psalm-param NonEmptyByteVector $bytes
- *
- * @return non-empty-string
- */
- public static function fromBytes(array $bytes): string
- {
- return pack('C*', ...$bytes);
- }
}
diff --git a/src/Vocab/Vocab.php b/src/Vocab/Vocab.php
index fb25f73..9dfb92b 100644
--- a/src/Vocab/Vocab.php
+++ b/src/Vocab/Vocab.php
@@ -104,22 +104,25 @@ public static function fromStream($stream): self
return new self($map);
}
- /** @psalm-param NonEmptyByteVector $bytes */
- public function tryGetRank(array $bytes): int|null
+ public function tryGetRank(string $binary): int|null
{
- return $this->tokenToRankMap[EncodeUtil::fromBytes($bytes)] ?? null;
+ if ($binary === '') {
+ throw new InvalidArgumentException('Argument $binary cannot be an empty string');
+ }
+
+ return $this->tokenToRankMap[$binary] ?? null;
}
- /**
- * @psalm-param NonEmptyByteVector $bytes
- *
- * @throws OutOfBoundsException
- */
- public function getRank(array $bytes): int
+ /** @throws OutOfBoundsException */
+ public function getRank(string $binary): int
{
- return $this->tokenToRankMap[EncodeUtil::fromBytes($bytes)] ?? throw new OutOfBoundsException(sprintf(
+ if ($binary === '') {
+ throw new InvalidArgumentException('Argument $binary cannot be an empty string');
+ }
+
+ return $this->tokenToRankMap[$binary] ?? throw new OutOfBoundsException(sprintf(
'No rank for bytes vector: [%s]',
- implode(', ', $bytes),
+ implode(', ', EncodeUtil::toBytes($binary)),
));
}
diff --git a/tests/Vocab/VocabTest.php b/tests/Vocab/VocabTest.php
index 77431c7..8f8d482 100644
--- a/tests/Vocab/VocabTest.php
+++ b/tests/Vocab/VocabTest.php
@@ -5,9 +5,10 @@
namespace Yethee\Tiktoken\Tests\Vocab;
use PHPUnit\Framework\TestCase;
-use Yethee\Tiktoken\Util\EncodeUtil;
use Yethee\Tiktoken\Vocab\Vocab;
+use function chr;
+
final class VocabTest extends TestCase
{
public function testLoadFromFile(): void
@@ -15,9 +16,9 @@ public function testLoadFromFile(): void
$vocab = Vocab::fromFile(__DIR__ . '/Fixtures/test.tiktoken');
self::assertCount(47, $vocab);
- self::assertSame(285, $vocab->getRank(EncodeUtil::toBytes('is')));
+ self::assertSame(285, $vocab->getRank('is'));
self::assertSame('is', $vocab->getToken(285));
- self::assertSame(18, $vocab->getRank([51]));
+ self::assertSame(18, $vocab->getRank(chr(51)));
self::assertSame('3', $vocab->getToken(18));
}
}