Commit
Merge pull request #1229 from hoosierEE/master
negative sampling excludes positive class
cantonios authored Mar 5, 2024
2 parents d8d1544 + f8da6c5 commit 8d6b5e0
Showing 1 changed file with 15 additions and 5 deletions.
20 changes: 15 additions & 5 deletions docs/tutorials/word2vec.ipynb
@@ -171,7 +171,7 @@
"id": "axZvd-hhotVB"
},
"source": [
"where *v* and *v\u003csup\u003e'\u003csup\u003e* are target and context vector representations of words and *W* is vocabulary size. "
"where *v* and *v\u003csup\u003e'\u003csup\u003e* are target and context vector representations of words and *W* is vocabulary size."
]
},
{
@@ -198,7 +198,7 @@
"id": "WTZBPf1RsOsg"
},
"source": [
"The simplified negative sampling objective for a target word is to distinguish the context word from `num_ns` negative samples drawn from noise distribution *P\u003csub\u003en\u003c/sub\u003e(w)* of words. More precisely, an efficient approximation of full softmax over the vocabulary is, for a skip-gram pair, to pose the loss for a target word as a classification problem between the context word and `num_ns` negative samples. "
"The simplified negative sampling objective for a target word is to distinguish the context word from `num_ns` negative samples drawn from noise distribution *P\u003csub\u003en\u003c/sub\u003e(w)* of words. More precisely, an efficient approximation of full softmax over the vocabulary is, for a skip-gram pair, to pose the loss for a target word as a classification problem between the context word and `num_ns` negative samples."
]
},
{
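The classification framing described in this cell can be sketched numerically. The following is a minimal NumPy illustration, not part of the diff; the embedding dimension, `num_ns`, and the random embeddings are assumed values:

```python
import numpy as np

# Sketch: negative-sampling loss for ONE skip-gram pair, posed as a
# (1 + num_ns)-way binary classification between the context word and
# num_ns negative samples. Sizes are illustrative assumptions.
rng = np.random.default_rng(0)
embedding_dim = 8
num_ns = 4  # number of negative samples

target_emb = rng.normal(size=embedding_dim)                   # v  (target word)
candidate_embs = rng.normal(size=(1 + num_ns, embedding_dim)) # v' for [context] + negatives

logits = candidate_embs @ target_emb                     # one score per candidate word
labels = np.array([1] + [0] * num_ns, dtype=np.float64)  # context = 1, negatives = 0

# Sigmoid cross-entropy over the candidates approximates the full softmax.
probs = 1.0 / (1.0 + np.exp(-logits))
loss = -np.mean(labels * np.log(probs) + (1 - labels) * np.log(1 - probs))
print(float(loss))
```

Because only `1 + num_ns` scores are computed per pair, the cost is independent of the vocabulary size *W*.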
@@ -250,7 +250,9 @@
"import numpy as np\n",
"\n",
"import tensorflow as tf\n",
"from tensorflow.keras import layers"
"from tensorflow.keras import layers\n",
"\n",
"from collections import defaultdict"
]
},
{
@@ -454,7 +456,7 @@
"id": "Esqn8WBfZnEK"
},
"source": [
"The `skipgrams` function returns all positive skip-gram pairs by sliding over a given window span. To produce additional skip-gram pairs that would serve as negative samples for training, you need to sample random words from the vocabulary. Use the `tf.random.log_uniform_candidate_sampler` function to sample `num_ns` number of negative samples for a given target word in a window. You can call the function on one skip-grams's target word and pass the context word as true class to exclude it from being sampled.\n"
"The `skipgrams` function returns all positive skip-gram pairs by sliding over a given window span. To produce additional skip-gram pairs that would serve as negative samples for training, you can sample random words from the vocabulary. Use the `tf.random.log_uniform_candidate_sampler` function to sample `num_ns` number of negative samples for a given target word in a window. You can pass words from the positive class but this does not exclude them from the results. For large vocabularies, this is not a problem because the chance of drawing one of the positive classes is small. However for small data you may see overlap between negative and positive samples. Later we will add code to exclude positive samples for slightly improved accuracy at the cost of longer runtime."
]
},
{
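The sampler's behavior described above can be checked directly. A small sketch with a made-up ten-word vocabulary and an arbitrary context index; `true_classes` only informs the sampler's expected-count outputs and does not filter the candidates:

```python
import tensorflow as tf

# Sketch: with a tiny range_max, sampled negatives can collide with the
# positive class. Indices here are hypothetical.
context_class = tf.constant([[2]], dtype="int64")  # assumed context word index
negatives, _, _ = tf.random.log_uniform_candidate_sampler(
    true_classes=context_class,  # used for expected counts, NOT for exclusion
    num_true=1,
    num_sampled=4,
    unique=True,
    range_max=10,  # tiny vocabulary, so collisions are likely
    seed=42)
print(negatives.numpy())  # may contain 2, the positive class
```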
@@ -728,6 +730,10 @@
"\n",
" # Iterate over each positive skip-gram pair to produce training examples\n",
" # with a positive context word and negative samples.\n",
" window = defaultdict(set)\n",
" for target_word, context_word in positive_skip_grams:\n",
" window[target_word].add(context_word)\n",
"\n",
" for target_word, context_word in positive_skip_grams:\n",
" context_class = tf.expand_dims(\n",
" tf.constant([context_word], dtype=\"int64\"), 1)\n",
@@ -740,7 +746,11 @@
" seed=seed,\n",
" name=\"negative_sampling\")\n",
"\n",
" # Build context and label vectors (for one target word)\n",
" # Discard this negative sample if it intersects with the positive context.\n",
" if window[target_word].intersection(negative_sampling_candidates.numpy()):\n",
" continue\n",
"\n",
" # Build context and label vectors (for one target word).\n",
" context = tf.concat([tf.squeeze(context_class,1), negative_sampling_candidates], 0)\n",
" label = tf.constant([1] + [0]*num_ns, dtype=\"int64\")\n",
"\n",
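The exclusion step this commit adds can be isolated as follows. A self-contained sketch on made-up skip-gram pairs: a pair is dropped whenever any sampled negative is a known positive context for that target word.

```python
from collections import defaultdict

# Hypothetical (target, context) pairs standing in for positive_skip_grams.
positive_skip_grams = [(1, 2), (1, 3), (4, 2)]

# Map each target word to the set of its positive context words.
window = defaultdict(set)
for target_word, context_word in positive_skip_grams:
    window[target_word].add(context_word)

def keep_pair(target_word, negative_candidates):
    """Return False if any negative sample is a positive context for target_word."""
    return not window[target_word].intersection(negative_candidates)

print(keep_pair(1, [3, 7]))  # False: 3 is a positive context for target 1
print(keep_pair(4, [5, 6]))  # True: no overlap, keep this training example
```

Skipping a colliding pair entirely (rather than resampling) is the simpler choice, at the cost of discarding some training examples on small datasets.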