Skip to content

Commit

Permalink
Fix truncation issue in classify_review function (#373)
Browse files Browse the repository at this point in the history
  • Loading branch information
rasbt authored Sep 26, 2024
1 parent b56d0b2 commit 7ef5129
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 3 deletions.
4 changes: 3 additions & 1 deletion ch06/01_main-chapter-code/ch06.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -2207,7 +2207,9 @@
"\n",
" # Prepare inputs to the model\n",
" input_ids = tokenizer.encode(text)\n",
" supported_context_length = model.pos_emb.weight.shape[1]\n",
" supported_context_length = model.pos_emb.weight.shape[0]\n",
" # Note: In the book, this was originally written as pos_emb.weight.shape[1] by mistake\n",
" # It didn't break the code but would have caused unnecessary truncation (to 768 instead of 1024)\n",
"\n",
" # Truncate sequences if they too long\n",
" input_ids = input_ids[:min(max_length, supported_context_length)]\n",
Expand Down
2 changes: 1 addition & 1 deletion ch06/01_main-chapter-code/load-finetuned-model.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@
"\n",
" # Prepare inputs to the model\n",
" input_ids = tokenizer.encode(text)\n",
" supported_context_length = model.pos_emb.weight.shape[1]\n",
" supported_context_length = model.pos_emb.weight.shape[0]\n",
"\n",
" # Truncate sequences if they too long\n",
" input_ids = input_ids[:min(max_length, supported_context_length)]\n",
Expand Down
2 changes: 1 addition & 1 deletion ch06/04_user_interface/previous_chapters.py
Original file line number Diff line number Diff line change
Expand Up @@ -353,7 +353,7 @@ def classify_review(text, model, tokenizer, device, max_length=None, pad_token_i

# Prepare inputs to the model
input_ids = tokenizer.encode(text)
supported_context_length = model.pos_emb.weight.shape[1]
supported_context_length = model.pos_emb.weight.shape[0]

# Truncate sequences if they too long
input_ids = input_ids[:min(max_length, supported_context_length)]
Expand Down

0 comments on commit 7ef5129

Please sign in to comment.