diff --git a/docs/tutorials/word2vec.ipynb b/docs/tutorials/word2vec.ipynb index e85a39dfc..30dba82ab 100644 --- a/docs/tutorials/word2vec.ipynb +++ b/docs/tutorials/word2vec.ipynb @@ -728,13 +728,12 @@ " window_size=window_size,\n", " negative_samples=0)\n", "\n", - " # Generate positive context windows for each target word in the sequence.\n", - " window = defaultdict(list)\n", - " for i in range(window_size, len(sequence)-window_size):\n", - " window[sequence[i]].append(sequence[i-window_size:1+i+window_size])\n", - "\n", " # Iterate over each positive skip-gram pair to produce training examples\n", " # with a positive context word and negative samples.\n", + " window = defaultdict(set)\n", + " for target_word, context_word in positive_skip_grams:\n", + " window[target_word].add(context_word)\n", + "\n", " for target_word, context_word in positive_skip_grams:\n", " context_class = tf.expand_dims(\n", " tf.constant([context_word], dtype=\"int64\"), 1)\n", @@ -747,12 +746,9 @@ " seed=seed,\n", " name=\"negative_sampling\")\n", "\n", - " # Discard iteration if negative samples overlap with positive context.\n", - " for target in window[target_word]:\n", - " if not any(t in target for t in negative_sampling_candidates):\n", - " break # All candidates are true negatives: use this skip_gram.\n", - " else:\n", - " continue # Discard this skip_gram.\n", + " # Discard this skip-gram if any negative sample collides with its positive context.\n", + " if window[target_word].intersection(negative_sampling_candidates.numpy()):\n", + " continue\n", "\n", " # Build context and label vectors (for one target word).\n", " context = tf.concat([tf.squeeze(context_class,1), negative_sampling_candidates], 0)\n",