diff --git a/src/EncoderProvider.php b/src/EncoderProvider.php index 76b709a..b88348c 100644 --- a/src/EncoderProvider.php +++ b/src/EncoderProvider.php @@ -22,22 +22,27 @@ final class EncoderProvider implements ResetInterface public const ENCODINGS = [ 'r50k_base' => [ 'vocab' => 'https://openaipublic.blob.core.windows.net/encodings/r50k_base.tiktoken', + 'hash' => '306cd27f03c1a714eca7108e03d66b7dc042abe8c258b44c199a7ed9838dd930', 'pat' => '/\'s|\'t|\'re|\'ve|\'m|\'ll|\'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+/u', ], 'p50k_base' => [ 'vocab' => 'https://openaipublic.blob.core.windows.net/encodings/p50k_base.tiktoken', + 'hash' => '94b5ca7dff4d00767bc256fdd1b27e5b17361d7b8a5f968547f9f23eb70d2069', 'pat' => '/\'s|\'t|\'re|\'ve|\'m|\'ll|\'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+/u', ], 'p50k_edit' => [ 'vocab' => 'https://openaipublic.blob.core.windows.net/encodings/p50k_base.tiktoken', + 'hash' => '94b5ca7dff4d00767bc256fdd1b27e5b17361d7b8a5f968547f9f23eb70d2069', 'pat' => '/\'s|\'t|\'re|\'ve|\'m|\'ll|\'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+/u', ], 'cl100k_base' => [ 'vocab' => 'https://openaipublic.blob.core.windows.net/encodings/cl100k_base.tiktoken', + 'hash' => '223921b76ee99bde995b7ff738513eef100fb51d18c93597a113bcffe865b2a7', 'pat' => '/(?i:\'s|\'t|\'re|\'ve|\'m|\'ll|\'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+/u', ], 'o200k_base' => [ 'vocab' => 'https://openaipublic.blob.core.windows.net/encodings/o200k_base.tiktoken', + 'hash' => '446a9538cb6c348e3516120d7c08b09f57c36495e2acfffe59a5bf8b0cfb1a2d', 'pat' => '/[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]*[\p{Ll}\p{Lm}\p{Lo}\p{M}]+(?i:\'s|\'t|\'re|\'ve|\'m|\'ll|\'d)?|[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]+[\p{Ll}\p{Lm}\p{Lo}\p{M}]*(?i:\'s|\'t|\'re|\'ve|\'m|\'ll|\'d)?|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n\/]*|\s*[\r\n]+|\s+(?!\S)|\s+/u', ], ]; diff --git a/src/Vocab/Loader/DefaultVocabLoader.php b/src/Vocab/Loader/DefaultVocabLoader.php index 41bb796..cd007ae 100644 --- a/src/Vocab/Loader/DefaultVocabLoader.php +++ b/src/Vocab/Loader/DefaultVocabLoader.php @@ -12,10 +12,11 @@ use function fclose; use function file_exists; use function fopen; +use function hash_equals; +use function hash_file; use function is_dir; use function is_writable; use function mkdir; -use function preg_match; use function sha1; use function sprintf; use function stream_copy_to_stream; @@ -28,16 +29,12 @@ public function __construct(private string|null $cacheDir = null) { } - public function load(string $uri): Vocab + public function load(string $uri, string|null $checksum = null): Vocab { - if ($this->cacheDir !== null && preg_match('@^https?://@i', $uri)) { - $cacheFile = $this->cacheDir . DIRECTORY_SEPARATOR . sha1($uri); - } else { - $cacheFile = null; - } + $cacheFile = $this->cacheDir !== null ? $this->cacheDir . DIRECTORY_SEPARATOR . sha1($uri) : null; if ($cacheFile !== null) { - if (file_exists($cacheFile)) { + if (file_exists($cacheFile) && $this->checkHash($cacheFile, $checksum)) { return Vocab::fromFile($cacheFile); } @@ -83,4 +80,19 @@ public function load(string $uri): Vocab fclose($stream); } } + + private function checkHash(string $filename, string|null $expectedHash): bool + { + if ($expectedHash === null) { + return true; + } + + $hash = hash_file('sha256', $filename); + + if ($hash === false) { + return false; + } + + return hash_equals($hash, $expectedHash); + } } diff --git a/src/Vocab/VocabLoader.php b/src/Vocab/VocabLoader.php index a308fbd..fd4156f 100644 --- a/src/Vocab/VocabLoader.php +++ b/src/Vocab/VocabLoader.php @@ -7,5 +7,5 @@ interface VocabLoader { /** @param non-empty-string $uri */ - public function load(string $uri): Vocab; + public function load(string $uri, string|null $checksum = null): Vocab; } diff --git a/tests/Vocab/Loader/DefaultVocabLoaderTest.php b/tests/Vocab/Loader/DefaultVocabLoaderTest.php new file mode 100644 index 0000000..da60faa --- /dev/null +++ b/tests/Vocab/Loader/DefaultVocabLoaderTest.php @@ -0,0 +1,84 @@ +cacheDir); + + $vocabUrl = 'http://localhost/cl100k_base.tiktoken'; + $cacheFile = $this->cacheDir . '/' . hash('sha1', $vocabUrl); + + copy(dirname(__DIR__, 2) . '/Fixtures/cl100k_base.tiktoken', $cacheFile); + self::assertFileEquals(dirname(__DIR__, 2) . '/Fixtures/cl100k_base.tiktoken', $cacheFile); + + $vocab = $loader->load($vocabUrl, '223921b76ee99bde995b7ff738513eef100fb51d18c93597a113bcffe865b2a7'); + + self::assertSame(100256, $vocab->count()); + } + + public function testInvalidateCacheWhenChecksumMismatch(): void + { + $loader = new DefaultVocabLoader($this->cacheDir); + + $vocabUrl = dirname(__DIR__, 2) . '/Fixtures/p50k_base.tiktoken'; + $cacheFile = $this->cacheDir . '/' . hash('sha1', $vocabUrl); + + file_put_contents($cacheFile, 'outdated content'); + self::assertFileExists($cacheFile); + + $vocab = $loader->load($vocabUrl, '94b5ca7dff4d00767bc256fdd1b27e5b17361d7b8a5f968547f9f23eb70d2069'); + + self::assertSame(50280, $vocab->count()); + + self::assertFileExists($cacheFile); + self::assertFileEquals($vocabUrl, $cacheFile); + } + + protected function setUp(): void + { + $this->cacheDir = sys_get_temp_dir() . '/tiktoken-test'; + + self::removeDir($this->cacheDir); + } + + protected function tearDown(): void + { + self::removeDir($this->cacheDir); + } + + private static function removeDir(string $path): void + { + $iterator = new RecursiveIteratorIterator( + new RecursiveDirectoryIterator($path, RecursiveDirectoryIterator::SKIP_DOTS), + RecursiveIteratorIterator::CHILD_FIRST, + ); + + foreach ($iterator as $entry) { + if ($entry->isFile()) { + unlink($entry->getPathname()); + } else { + rmdir($entry->getPathname()); + } + } + } +}