Skip to content

Commit

Permalink
much faster set-intersection based version
Browse files Browse the repository at this point in the history
  • Loading branch information
hoosierEE committed Nov 1, 2023
1 parent 99aad3b commit f8da6c5
Showing 1 changed file with 7 additions and 11 deletions.
18 changes: 7 additions & 11 deletions docs/tutorials/word2vec.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -728,13 +728,12 @@
" window_size=window_size,\n",
" negative_samples=0)\n",
"\n",
" # Generate positive context windows for each target word in the sequence.\n",
" window = defaultdict(list)\n",
" for i in range(window_size, len(sequence)-window_size):\n",
" window[sequence[i]].append(sequence[i-window_size:1+i+window_size])\n",
"\n",
" # Iterate over each positive skip-gram pair to produce training examples\n",
" # with a positive context word and negative samples.\n",
" window = defaultdict(set)\n",
" for target_word, context_word in positive_skip_grams:\n",
" window[target_word].add(context_word)\n",
"\n",
" for target_word, context_word in positive_skip_grams:\n",
" context_class = tf.expand_dims(\n",
" tf.constant([context_word], dtype=\"int64\"), 1)\n",
Expand All @@ -747,12 +746,9 @@
" seed=seed,\n",
" name=\"negative_sampling\")\n",
"\n",
" # Discard iteration if negative samples overlap with positive context.\n",
" for target in window[target_word]:\n",
" if not any(t in target for t in negative_sampling_candidates):\n",
" break # All candidates are true negatives: use this skip_gram.\n",
" else:\n",
" continue # Discard this skip_gram.\n",
" # Discard this negative sample if it intersects with the positive context.\n",
" if window[target_word].intersection(negative_sampling_candidates.numpy()):\n",
" continue\n",
"\n",
" # Build context and label vectors (for one target word).\n",
" context = tf.concat([tf.squeeze(context_class,1), negative_sampling_candidates], 0)\n",
Expand Down

0 comments on commit f8da6c5

Please sign in to comment.