diff --git a/notebooks/Adversarial Debias Training.ipynb b/notebooks/Adversarial Debias Training.ipynb index 64dcc9d..dde4a25 100644 --- a/notebooks/Adversarial Debias Training.ipynb +++ b/notebooks/Adversarial Debias Training.ipynb @@ -32,7 +32,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "100%|██████████| 3000/3000 [00:00<00:00, 3007.87it/s]\n" + "100%|██████████| 3000/3000 [00:00<00:00, 3089.31it/s]\n" ] }, { @@ -40,32 +40,32 @@ "output_type": "stream", "text": [ " tag cat occurences\n", - "0 B-MISC 0 8\n", - "1 I-LOC 1 2563\n", - "2 I-MISC 2 1106\n", - "3 I-ORG 3 1871\n", - "4 I-PER 4 8216\n", - "5 O 5 53719\n", + "0 B-MISC 0 4\n", + "1 I-LOC 1 2510\n", + "2 I-MISC 2 1137\n", + "3 I-ORG 3 1819\n", + "4 I-PER 4 8244\n", + "5 O 5 53099\n", "6 [nerCLS] 6 3000\n", - "7 [nerPAD] 7 293949\n", + "7 [nerPAD] 7 295291\n", "8 [nerSEP] 8 3000\n", - "9 [nerX] 9 16568\n", + "9 [nerX] 9 15896\n", "\n", " tag cat occurences\n", - "0 AFRICAN-AMERICAN 0 3382\n", - "1 EUROPEAN 1 1555\n", + "0 AFRICAN-AMERICAN 0 3405\n", + "1 EUROPEAN 1 1538\n", "2 [raceCLS] 2 3000\n", - "3 [racePAD] 3 293949\n", + "3 [racePAD] 3 295291\n", "4 [raceSEP] 4 3000\n", - "5 [raceX] 5 79114\n", + "5 [raceX] 5 77766\n", "\n", " tag cat occurences\n", - "0 FEMALE 0 2855\n", - "1 MALE 1 2082\n", + "0 FEMALE 0 2838\n", + "1 MALE 1 2105\n", "2 [genderCLS] 2 3000\n", - "3 [genderPAD] 3 293949\n", + "3 [genderPAD] 3 295291\n", "4 [genderSEP] 4 3000\n", - "5 [genderX] 5 79114\n", + "5 [genderX] 5 77766\n", "\n" ] } @@ -121,24 +121,6 @@ "cell_type": "code", "execution_count": 4, "metadata": {}, - "outputs": [], - "source": [ - "def random_batch(data, batch_size=32):\n", - " idx = np.random.randint(len(data[\"nerLabels\"]), size=batch_size)\n", - " return [\n", - " data[\"inputs\"][0][idx], \n", - " data[\"inputs\"][1][idx], \n", - " data[\"inputs\"][2][idx], \n", - " data[\"nerLabels\"][idx],\n", - " data[\"genderLabels\"][idx],\n", - " data[\"raceLabels\"][idx]\n", - " ]" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": {}, "outputs": [ { "name": "stdout", @@ -180,13 +162,20 @@ "config.gpu_options.allow_growth = True\n", "sess = tf.Session(config=config)\n", "\n", - "model = getDebiasedModel(max_length, 1)\n", - "\n", + "model = getDebiasedModel(max_length, 1)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "protect_loss_weight = 0.1\n", "pred_learning_rate = 2**-16\n", "protect_learning_rate = 2**-16\n", - "num_epochs = 5\n", - "\n", - "num_train_samples = len(train_data[\"nerLabels\"])" + "num_epochs = 8\n", + "batch_size = 32" ] }, { @@ -195,7 +184,79 @@ "metadata": {}, "outputs": [], "source": [ - "protect_loss_weight = 0.1" + "def fit(train_data, epochs, batch_size, debias,\n", + " protect_loss_weight = 0.1, \n", + " pred_learning_rate = 2**-16, \n", + " protect_learning_rate = 2**-16):\n", + "\n", + " num_train_samples = len(train_data[\"nerLabels\"])\n", + "\n", + " ids_ph = tf.placeholder(tf.float32, shape=[batch_size, max_length])\n", + " masks_ph = tf.placeholder(tf.float32, shape=[batch_size, max_length])\n", + " sentenceIds_ph = tf.placeholder(tf.float32, shape=[batch_size, max_length])\n", + "\n", + " gender_ph = tf.placeholder(tf.float32, shape=[batch_size, max_length])\n", + " ner_labels_ph = tf.placeholder(tf.float32, shape=[batch_size, max_length])\n", + "\n", + " global_step = tf.Variable(0, trainable=False)\n", + " starter_learning_rate = 0.001\n", + " learning_rate = 
tf.train.exponential_decay(starter_learning_rate, global_step, 1000, 0.96, staircase=True)\n", + "\n", + " protect_vars = [var for var in tf.trainable_variables() if 'gender' in var.name]\n", + " pred_vars = model.layers[3]._trainable_weights + [var for var in tf.trainable_variables() if any(x in var.name for x in [\"pred_dense\",\"ner\"])]\n", + "\n", + " y_pred = model([ids_ph, masks_ph, sentenceIds_ph], training=True)\n", + "\n", + " ner_loss = model_utils.custom_loss(ner_labels_ph, y_pred[\"ner\"])\n", + " gender_loss = model_utils.custom_loss_protected(gender_ph, y_pred[\"gender\"])\n", + "\n", + " protect_opt = tf.train.AdamOptimizer(protect_learning_rate)\n", + " pred_opt = tf.train.AdamOptimizer(pred_learning_rate)\n", + "\n", + " protect_grads = {var: grad for (grad, var) in protect_opt.compute_gradients(gender_loss,var_list=pred_vars)}\n", + " pred_grads = []\n", + "\n", + " tf_normalize = lambda x: x / (tf.norm(x) + np.finfo(np.float32).tiny)\n", + "\n", + " for (grad, var) in pred_opt.compute_gradients(ner_loss, var_list=pred_vars):\n", + " unit_protect = tf_normalize(protect_grads[var])\n", + " # the two lines below can be commented out to train without debiasing\n", + " if debias:\n", + " grad -= tf.reduce_sum(grad * unit_protect) * unit_protect\n", + " grad -= tf.math.scalar_mul(protect_loss_weight, protect_grads[var])\n", + " pred_grads.append((grad, var))\n", + "\n", + " pred_min = pred_opt.apply_gradients(pred_grads, global_step=global_step)\n", + " protect_min = protect_opt.minimize(gender_loss, var_list=[protect_vars], global_step=global_step)\n", + "\n", + " model_utils.initialize_vars(sess)\n", + "\n", + " # Begin training\n", + " for epoch in range(epochs):\n", + " \n", + " shuffled_ids = np.random.choice(num_train_samples, num_train_samples)\n", + "\n", + " for i in range(1, num_train_samples//32 + 1):\n", + " \n", + " batch_ids = shuffled_ids[batch_size*i: batch_size*(i+1)]\n", + "\n", + " batch_feed_dict = {ids_ph: train_data[\"inputs\"][0][batch_ids], \n", + " masks_ph: train_data[\"inputs\"][1][batch_ids],\n", + " sentenceIds_ph: train_data[\"inputs\"][2][batch_ids],\n", + " gender_ph: train_data[\"genderLabels\"][batch_ids],\n", + " ner_labels_ph: train_data[\"nerLabels\"][batch_ids]}\n", + "\n", + " _, _, pred_labels_loss_value, pred_protected_attributes_loss_vale = sess.run([\n", + " pred_min,\n", + " protect_min,\n", + " ner_loss,\n", + " gender_loss\n", + " ], feed_dict=batch_feed_dict)\n", + " \n", + "\n", + " if i % 10 == 0:\n", + " print(\"epoch %d; iter: %d; batch classifier loss: %f; batch adversarial loss: %f\" % (epoch, i, pred_labels_loss_value,\n", + " pred_protected_attributes_loss_vale))\n" ] }, { @@ -204,255 +265,41 @@ "metadata": {}, "outputs": [], "source": [ - "def tf_normalize(x):\n", - " \"\"\"Returns the input vector, normalized.\n", - "\n", - " A small number is added to the norm so that this function does not break when\n", - " dealing with the zero vector (e.g. 
if the weights are zero-initialized).\n", - "\n", - " Args:\n", - " x: the tensor to normalize\n", - " \"\"\"\n", - " return x / (tf.norm(x) + np.finfo(np.float32).tiny)" + "def print_status_bar(iteration, total, loss, metrics=None):\n", + " metrics = \" - \".join([\"{}: {:.4f}\".format(m.name, m.result())for m in [loss] + (metrics or [])])\n", + " end = \"\" if iteration < total else \"\\n\"\n", + " print(\"\\r{}/{} - \".format(iteration, total) + metrics, end=end)" ] }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "epoch 0; iter: 0; batch classifier loss: 2.678376; batch adversarial loss: 3.734007\n", - "epoch 0; iter: 1; batch classifier loss: 2.585400; batch adversarial loss: 3.721679\n", - "epoch 0; iter: 2; batch classifier loss: 2.546495; batch adversarial loss: 3.730783\n", - "epoch 0; iter: 3; batch classifier loss: 2.499578; batch adversarial loss: 3.728687\n", - "epoch 0; iter: 4; batch classifier loss: 2.417089; batch adversarial loss: 3.737558\n", - "epoch 0; iter: 5; batch classifier loss: 2.289476; batch adversarial loss: 3.731979\n", - "epoch 0; iter: 6; batch classifier loss: 2.252623; batch adversarial loss: 3.734449\n", - "epoch 0; iter: 7; batch classifier loss: 2.161414; batch adversarial loss: 3.727175\n", - "epoch 0; iter: 8; batch classifier loss: 2.166079; batch adversarial loss: 3.722918\n", - "epoch 0; iter: 9; batch classifier loss: 2.101038; batch adversarial loss: 3.724814\n", - "epoch 0; iter: 10; batch classifier loss: 2.035396; batch adversarial loss: 3.725134\n", - "epoch 0; iter: 11; batch classifier loss: 1.998918; batch adversarial loss: 3.726599\n", - "epoch 0; iter: 12; batch classifier loss: 1.931057; batch adversarial loss: 3.723634\n", - "epoch 0; iter: 13; batch classifier loss: 1.819195; batch adversarial loss: 3.734349\n", - "epoch 0; iter: 14; batch classifier loss: 1.808850; batch adversarial loss: 3.730552\n", - "epoch 0; iter: 15; batch classifier loss: 1.762997; batch adversarial loss: 3.734618\n", - "epoch 0; iter: 16; batch classifier loss: 1.707420; batch adversarial loss: 3.734702\n", - "epoch 0; iter: 17; batch classifier loss: 1.643114; batch adversarial loss: 3.716355\n", - "epoch 0; iter: 18; batch classifier loss: 1.570874; batch adversarial loss: 3.739947\n", - "epoch 0; iter: 19; batch classifier loss: 1.546776; batch adversarial loss: 3.720227\n", - "epoch 0; iter: 20; batch classifier loss: 1.488895; batch adversarial loss: 3.728689\n", - "epoch 0; iter: 21; batch classifier loss: 1.446590; batch adversarial loss: 3.730140\n", - "epoch 0; iter: 22; batch classifier loss: 1.414689; batch adversarial loss: 3.727072\n", - "epoch 0; iter: 23; batch classifier loss: 1.382561; batch adversarial loss: 3.728362\n", - "epoch 0; iter: 24; batch classifier loss: 1.364309; batch adversarial loss: 3.729270\n", - "epoch 0; iter: 25; batch classifier loss: 1.285569; batch adversarial loss: 3.721978\n", - "epoch 0; iter: 26; batch classifier loss: 1.275097; batch adversarial loss: 3.718330\n", - "epoch 0; iter: 27; batch classifier loss: 1.234471; batch adversarial loss: 3.723352\n", - "epoch 0; iter: 28; batch classifier loss: 1.148277; batch adversarial loss: 3.728103\n", - "epoch 0; iter: 29; batch classifier loss: 1.121372; batch adversarial loss: 3.718870\n", - "epoch 0; iter: 30; batch classifier loss: 1.128097; batch adversarial loss: 3.732084\n", - "epoch 1; iter: 0; batch classifier loss: 1.036381; batch adversarial loss: 
3.732941\n", - "epoch 1; iter: 1; batch classifier loss: 1.042026; batch adversarial loss: 3.727338\n", - "epoch 1; iter: 2; batch classifier loss: 0.952354; batch adversarial loss: 3.720267\n", - "epoch 1; iter: 3; batch classifier loss: 0.996197; batch adversarial loss: 3.724107\n", - "epoch 1; iter: 4; batch classifier loss: 0.931898; batch adversarial loss: 3.728052\n", - "epoch 1; iter: 5; batch classifier loss: 0.954432; batch adversarial loss: 3.721893\n", - "epoch 1; iter: 6; batch classifier loss: 0.950264; batch adversarial loss: 3.719858\n", - "epoch 1; iter: 7; batch classifier loss: 0.868991; batch adversarial loss: 3.719575\n", - "epoch 1; iter: 8; batch classifier loss: 0.874406; batch adversarial loss: 3.732162\n", - "epoch 1; iter: 9; batch classifier loss: 0.846509; batch adversarial loss: 3.721153\n", - "epoch 1; iter: 10; batch classifier loss: 0.844095; batch adversarial loss: 3.712003\n", - "epoch 1; iter: 11; batch classifier loss: 0.810338; batch adversarial loss: 3.724535\n", - "epoch 1; iter: 12; batch classifier loss: 0.783894; batch adversarial loss: 3.725021\n", - "epoch 1; iter: 13; batch classifier loss: 0.763638; batch adversarial loss: 3.725904\n", - "epoch 1; iter: 14; batch classifier loss: 0.753987; batch adversarial loss: 3.720403\n", - "epoch 1; iter: 15; batch classifier loss: 0.731780; batch adversarial loss: 3.709384\n", - "epoch 1; iter: 16; batch classifier loss: 0.776539; batch adversarial loss: 3.714598\n", - "epoch 1; iter: 17; batch classifier loss: 0.743523; batch adversarial loss: 3.729327\n", - "epoch 1; iter: 18; batch classifier loss: 0.700741; batch adversarial loss: 3.721736\n", - "epoch 1; iter: 19; batch classifier loss: 0.632679; batch adversarial loss: 3.724268\n", - "epoch 1; iter: 20; batch classifier loss: 0.706350; batch adversarial loss: 3.709811\n", - "epoch 1; iter: 21; batch classifier loss: 0.683652; batch adversarial loss: 3.716700\n", - "epoch 1; iter: 22; batch classifier loss: 0.672041; batch adversarial loss: 3.719959\n", - "epoch 1; iter: 23; batch classifier loss: 0.662070; batch adversarial loss: 3.711993\n", - "epoch 1; iter: 24; batch classifier loss: 0.630526; batch adversarial loss: 3.728816\n", - "epoch 1; iter: 25; batch classifier loss: 0.574854; batch adversarial loss: 3.728230\n", - "epoch 1; iter: 26; batch classifier loss: 0.562681; batch adversarial loss: 3.742302\n", - "epoch 1; iter: 27; batch classifier loss: 0.540649; batch adversarial loss: 3.716567\n", - "epoch 1; iter: 28; batch classifier loss: 0.576717; batch adversarial loss: 3.715358\n", - "epoch 1; iter: 29; batch classifier loss: 0.503479; batch adversarial loss: 3.725264\n", - "epoch 1; iter: 30; batch classifier loss: 0.556868; batch adversarial loss: 3.727194\n", - "epoch 2; iter: 0; batch classifier loss: 0.537611; batch adversarial loss: 3.713457\n", - "epoch 2; iter: 1; batch classifier loss: 0.509668; batch adversarial loss: 3.730114\n", - "epoch 2; iter: 2; batch classifier loss: 0.476044; batch adversarial loss: 3.721471\n", - "epoch 2; iter: 3; batch classifier loss: 0.522587; batch adversarial loss: 3.727117\n", - "epoch 2; iter: 4; batch classifier loss: 0.490088; batch adversarial loss: 3.721257\n", - "epoch 2; iter: 5; batch classifier loss: 0.465905; batch adversarial loss: 3.740016\n", - "epoch 2; iter: 6; batch classifier loss: 0.447686; batch adversarial loss: 3.729166\n", - "epoch 2; iter: 7; batch classifier loss: 0.458686; batch adversarial loss: 3.725034\n", - "epoch 2; iter: 8; batch classifier loss: 0.430064; batch 
adversarial loss: 3.741273\n", - "epoch 2; iter: 9; batch classifier loss: 0.425212; batch adversarial loss: 3.732696\n", - "epoch 2; iter: 10; batch classifier loss: 0.405280; batch adversarial loss: 3.734439\n", - "epoch 2; iter: 11; batch classifier loss: 0.423221; batch adversarial loss: 3.728695\n", - "epoch 2; iter: 12; batch classifier loss: 0.426862; batch adversarial loss: 3.707139\n", - "epoch 2; iter: 13; batch classifier loss: 0.410972; batch adversarial loss: 3.722204\n", - "epoch 2; iter: 14; batch classifier loss: 0.394498; batch adversarial loss: 3.713705\n", - "epoch 2; iter: 15; batch classifier loss: 0.413409; batch adversarial loss: 3.737690\n", - "epoch 2; iter: 16; batch classifier loss: 0.414758; batch adversarial loss: 3.726439\n", - "epoch 2; iter: 17; batch classifier loss: 0.399764; batch adversarial loss: 3.715286\n", - "epoch 2; iter: 18; batch classifier loss: 0.342232; batch adversarial loss: 3.741620\n", - "epoch 2; iter: 19; batch classifier loss: 0.392566; batch adversarial loss: 3.739787\n", - "epoch 2; iter: 20; batch classifier loss: 0.389585; batch adversarial loss: 3.715882\n", - "epoch 2; iter: 21; batch classifier loss: 0.378386; batch adversarial loss: 3.728272\n", - "epoch 2; iter: 22; batch classifier loss: 0.392694; batch adversarial loss: 3.724486\n", - "epoch 2; iter: 23; batch classifier loss: 0.364721; batch adversarial loss: 3.715400\n", - "epoch 2; iter: 24; batch classifier loss: 0.363144; batch adversarial loss: 3.718371\n", - "epoch 2; iter: 25; batch classifier loss: 0.326667; batch adversarial loss: 3.734735\n", - "epoch 2; iter: 26; batch classifier loss: 0.348976; batch adversarial loss: 3.718245\n", - "epoch 2; iter: 27; batch classifier loss: 0.344584; batch adversarial loss: 3.722294\n", - "epoch 2; iter: 28; batch classifier loss: 0.347387; batch adversarial loss: 3.725701\n", - "epoch 2; iter: 29; batch classifier loss: 0.310200; batch adversarial loss: 3.722409\n", - "epoch 2; iter: 30; batch classifier loss: 0.316144; batch adversarial loss: 3.739583\n", - "epoch 3; iter: 0; batch classifier loss: 0.313580; batch adversarial loss: 3.734560\n", - "epoch 3; iter: 1; batch classifier loss: 0.315343; batch adversarial loss: 3.734612\n", - "epoch 3; iter: 2; batch classifier loss: 0.292728; batch adversarial loss: 3.741746\n", - "epoch 3; iter: 3; batch classifier loss: 0.316134; batch adversarial loss: 3.730970\n" + "Epoch 0/8\n" ] }, { - "name": "stdout", - "output_type": "stream", - "text": [ - "epoch 3; iter: 4; batch classifier loss: 0.306072; batch adversarial loss: 3.728918\n", - "epoch 3; iter: 5; batch classifier loss: 0.307814; batch adversarial loss: 3.742266\n", - "epoch 3; iter: 6; batch classifier loss: 0.313967; batch adversarial loss: 3.729022\n", - "epoch 3; iter: 7; batch classifier loss: 0.299893; batch adversarial loss: 3.728607\n", - "epoch 3; iter: 8; batch classifier loss: 0.287814; batch adversarial loss: 3.718416\n", - "epoch 3; iter: 9; batch classifier loss: 0.287074; batch adversarial loss: 3.725401\n", - "epoch 3; iter: 10; batch classifier loss: 0.279586; batch adversarial loss: 3.746905\n", - "epoch 3; iter: 11; batch classifier loss: 0.273372; batch adversarial loss: 3.758242\n", - "epoch 3; iter: 12; batch classifier loss: 0.268778; batch adversarial loss: 3.740175\n", - "epoch 3; iter: 13; batch classifier loss: 0.271226; batch adversarial loss: 3.721240\n", - "epoch 3; iter: 14; batch classifier loss: 0.253401; batch adversarial loss: 3.761651\n", - "epoch 3; iter: 15; batch classifier loss: 
0.261009; batch adversarial loss: 3.731069\n", - "epoch 3; iter: 16; batch classifier loss: 0.279937; batch adversarial loss: 3.721155\n", - "epoch 3; iter: 17; batch classifier loss: 0.254943; batch adversarial loss: 3.763981\n", - "epoch 3; iter: 18; batch classifier loss: 0.255405; batch adversarial loss: 3.744275\n", - "epoch 3; iter: 19; batch classifier loss: 0.271349; batch adversarial loss: 3.713373\n", - "epoch 3; iter: 20; batch classifier loss: 0.262028; batch adversarial loss: 3.721984\n", - "epoch 3; iter: 21; batch classifier loss: 0.269227; batch adversarial loss: 3.732561\n", - "epoch 3; iter: 22; batch classifier loss: 0.248011; batch adversarial loss: 3.753270\n", - "epoch 3; iter: 23; batch classifier loss: 0.227638; batch adversarial loss: 3.749408\n", - "epoch 3; iter: 24; batch classifier loss: 0.237888; batch adversarial loss: 3.763052\n", - "epoch 3; iter: 25; batch classifier loss: 0.229842; batch adversarial loss: 3.753456\n", - "epoch 3; iter: 26; batch classifier loss: 0.241909; batch adversarial loss: 3.736616\n", - "epoch 3; iter: 27; batch classifier loss: 0.227037; batch adversarial loss: 3.748862\n", - "epoch 3; iter: 28; batch classifier loss: 0.229994; batch adversarial loss: 3.738607\n", - "epoch 3; iter: 29; batch classifier loss: 0.214898; batch adversarial loss: 3.764717\n", - "epoch 3; iter: 30; batch classifier loss: 0.239942; batch adversarial loss: 3.731864\n", - "epoch 4; iter: 0; batch classifier loss: 0.215722; batch adversarial loss: 3.755016\n", - "epoch 4; iter: 1; batch classifier loss: 0.235006; batch adversarial loss: 3.740727\n", - "epoch 4; iter: 2; batch classifier loss: 0.223695; batch adversarial loss: 3.737964\n", - "epoch 4; iter: 3; batch classifier loss: 0.219583; batch adversarial loss: 3.754513\n", - "epoch 4; iter: 4; batch classifier loss: 0.206847; batch adversarial loss: 3.772835\n", - "epoch 4; iter: 5; batch classifier loss: 0.225351; batch adversarial loss: 3.744326\n", - "epoch 4; iter: 6; batch classifier loss: 0.194216; batch adversarial loss: 3.752586\n", - "epoch 4; iter: 7; batch classifier loss: 0.203559; batch adversarial loss: 3.741700\n", - "epoch 4; iter: 8; batch classifier loss: 0.218903; batch adversarial loss: 3.767123\n", - "epoch 4; iter: 9; batch classifier loss: 0.192884; batch adversarial loss: 3.761532\n", - "epoch 4; iter: 10; batch classifier loss: 0.210927; batch adversarial loss: 3.754419\n", - "epoch 4; iter: 11; batch classifier loss: 0.189782; batch adversarial loss: 3.791569\n", - "epoch 4; iter: 12; batch classifier loss: 0.213103; batch adversarial loss: 3.763950\n", - "epoch 4; iter: 13; batch classifier loss: 0.202236; batch adversarial loss: 3.775052\n", - "epoch 4; iter: 14; batch classifier loss: 0.175027; batch adversarial loss: 3.775881\n", - "epoch 4; iter: 15; batch classifier loss: 0.213836; batch adversarial loss: 3.743286\n", - "epoch 4; iter: 16; batch classifier loss: 0.202140; batch adversarial loss: 3.749633\n", - "epoch 4; iter: 17; batch classifier loss: 0.200901; batch adversarial loss: 3.780682\n", - "epoch 4; iter: 18; batch classifier loss: 0.190720; batch adversarial loss: 3.780779\n", - "epoch 4; iter: 19; batch classifier loss: 0.199202; batch adversarial loss: 3.754910\n", - "epoch 4; iter: 20; batch classifier loss: 0.178533; batch adversarial loss: 3.774585\n", - "epoch 4; iter: 21; batch classifier loss: 0.175255; batch adversarial loss: 3.792412\n", - "epoch 4; iter: 22; batch classifier loss: 0.187553; batch adversarial loss: 3.789207\n", - "epoch 4; iter: 23; 
batch classifier loss: 0.192283; batch adversarial loss: 3.767899\n", - "epoch 4; iter: 24; batch classifier loss: 0.178437; batch adversarial loss: 3.779642\n", - "epoch 4; iter: 25; batch classifier loss: 0.192460; batch adversarial loss: 3.762065\n", - "epoch 4; iter: 26; batch classifier loss: 0.189481; batch adversarial loss: 3.753385\n", - "epoch 4; iter: 27; batch classifier loss: 0.174796; batch adversarial loss: 3.777086\n", - "epoch 4; iter: 28; batch classifier loss: 0.167956; batch adversarial loss: 3.797268\n", - "epoch 4; iter: 29; batch classifier loss: 0.164461; batch adversarial loss: 3.790265\n", - "epoch 4; iter: 30; batch classifier loss: 0.172177; batch adversarial loss: 3.776642\n" + "ename": "AttributeError", + "evalue": "'numpy.float32' object has no attribute 'name'", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mAttributeError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mfit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtrain_data\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnum_epochs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbatch_size\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", + "\u001b[0;32m\u001b[0m in \u001b[0;36mfit\u001b[0;34m(train_data, epochs, batch_size, debias, protect_loss_weight, pred_learning_rate, protect_learning_rate)\u001b[0m\n\u001b[1;32m 70\u001b[0m ], feed_dict=batch_feed_dict)\n\u001b[1;32m 71\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 72\u001b[0;31m \u001b[0mprint_status_bar\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mi\u001b[0m \u001b[0;34m*\u001b[0m \u001b[0mbatch_size\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtrain_data\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m\"nerLabels\"\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mpred_labels_loss_value\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 73\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 74\u001b[0m \u001b[0mprint_status_bar\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtrain_data\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m\"nerLabels\"\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtrain_data\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m\"nerLabels\"\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mpred_labels_loss_value\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m\u001b[0m in \u001b[0;36mprint_status_bar\u001b[0;34m(iteration, total, loss, metrics)\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mprint_status_bar\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0miteration\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtotal\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mloss\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmetrics\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mNone\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m \u001b[0mmetrics\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m\" - 
\"\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mjoin\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m\"{}: {:.4f}\"\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mformat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mm\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mname\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mm\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mresult\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;32mfor\u001b[0m \u001b[0mm\u001b[0m \u001b[0;32min\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0mloss\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mmetrics\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 3\u001b[0m \u001b[0mend\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m\"\"\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0miteration\u001b[0m \u001b[0;34m<\u001b[0m \u001b[0mtotal\u001b[0m \u001b[0;32melse\u001b[0m \u001b[0;34m\"\\n\"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"\\r{}/{} - \"\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mformat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0miteration\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtotal\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0mmetrics\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mend\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mend\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m(.0)\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mprint_status_bar\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0miteration\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtotal\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mloss\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmetrics\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mNone\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m \u001b[0mmetrics\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m\" - \"\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mjoin\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m\"{}: {:.4f}\"\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mformat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mm\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mname\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mm\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mresult\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;32mfor\u001b[0m \u001b[0mm\u001b[0m \u001b[0;32min\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0mloss\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mmetrics\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 3\u001b[0m \u001b[0mend\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m\"\"\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0miteration\u001b[0m \u001b[0;34m<\u001b[0m \u001b[0mtotal\u001b[0m \u001b[0;32melse\u001b[0m \u001b[0;34m\"\\n\"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"\\r{}/{} - \"\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mformat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0miteration\u001b[0m\u001b[0;34m,\u001b[0m 
\u001b[0mtotal\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0mmetrics\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mend\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mend\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mAttributeError\u001b[0m: 'numpy.float32' object has no attribute 'name'" ] } ], "source": [ - "ids_ph = tf.placeholder(tf.float32, shape=[32,128])\n", - "masks_ph = tf.placeholder(tf.float32, shape=[32,128])\n", - "sentenceIds_ph = tf.placeholder(tf.float32, shape=[32,128])\n", - "\n", - "gender_ph = tf.placeholder(tf.float32, shape=[32,128])\n", - "ner_labels_ph = tf.placeholder(tf.float32, shape=[32,128])\n", - "\n", - "global_step = tf.Variable(0, trainable=False)\n", - "starter_learning_rate = 0.001\n", - "learning_rate = tf.train.exponential_decay(starter_learning_rate, global_step, 1000, 0.96, staircase=True)\n", - "\n", - "protect_vars = [var for var in tf.trainable_variables() if 'gender' in var.name]\n", - "pred_vars = model.layers[3]._trainable_weights + [var for var in tf.trainable_variables() if any(x in var.name for x in [\"pred_dense\",\"ner\"])]\n", - " \n", - "y_pred = model([ids_ph, masks_ph, sentenceIds_ph], training=True)\n", - "\n", - "ner_loss = model_utils.custom_loss(ner_labels_ph, y_pred[\"ner\"])\n", - "gender_loss = model_utils.custom_loss_protected(gender_ph, y_pred[\"gender\"])\n", - "\n", - "protect_opt = tf.train.AdamOptimizer(protect_learning_rate)\n", - "pred_opt = tf.train.AdamOptimizer(pred_learning_rate)\n", - "\n", - "protect_grads = {var: grad for (grad, var) in protect_opt.compute_gradients(gender_loss,var_list=pred_vars)}\n", - "pred_grads = []\n", - "\n", - "for (grad, var) in pred_opt.compute_gradients(ner_loss, var_list=pred_vars):\n", - " unit_protect = tf_normalize(protect_grads[var])\n", - " # the two lines below can be commented out to train without debiasing\n", - " grad -= tf.reduce_sum(grad * unit_protect) * unit_protect\n", - " grad -= tf.math.scalar_mul(protect_loss_weight, protect_grads[var])\n", - " pred_grads.append((grad, var))\n", - "\n", - "pred_min = pred_opt.apply_gradients(pred_grads, global_step=global_step)\n", - "protect_min = protect_opt.minimize(gender_loss, var_list=[protect_vars], global_step=global_step)\n", - "\n", - "model_utils.initialize_vars(sess)\n", - "\n", - "# Begin training\n", - "for epoch in range(num_epochs):\n", - " \n", - " shuffled_ids = np.random.choice(num_train_samples, num_train_samples)\n", - "\n", - " for i in range(num_train_samples//32):\n", - "\n", - " ids, masks, sentence_ids, ner_labels, gender_labels, race_labels = random_batch(train_data)\n", - "\n", - " batch_feed_dict = {ids_ph: ids, \n", - " masks_ph: masks,\n", - " sentenceIds_ph: sentence_ids,\n", - " gender_ph: gender_labels,\n", - " ner_labels_ph: race_labels}\n", - "\n", - "\n", - " _, _, pred_labels_loss_value, pred_protected_attributes_loss_vale = sess.run([\n", - " pred_min,\n", - " protect_min,\n", - " ner_loss,\n", - " gender_loss\n", - " ], feed_dict=batch_feed_dict)\n", - "\n", - " #if i % 200 == 0:\n", - " print(\"epoch %d; iter: %d; batch classifier loss: %f; batch adversarial loss: %f\" % (epoch, i, pred_labels_loss_value,\n", - " pred_protected_attributes_loss_vale))\n" + "fit(train_data, num_epochs, batch_size, True)" ] }, { diff --git a/notebooks/Script Tester.ipynb b/notebooks/Script Tester.ipynb index 20db5cd..bc7cbf2 100644 --- a/notebooks/Script Tester.ipynb +++ b/notebooks/Script Tester.ipynb @@ -2,27 +2,70 @@ "cells": [ { "cell_type": "code", - 
"execution_count": null, + "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "from importlib import reload\n", "import sys\n", + "\n", "src_path = '../src' # change as needed\n", - "sys.path.insert(0,src_path)" + "sys.path.insert(0,src_path)\n", + "\n", + "max_length = 128" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 2, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "INFO:absl:Using /tmp/tfhub_modules to cache modules.\n", + "100%|██████████| 397080/397080 [02:03<00:00, 3212.44it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + " tag cat occurences\n", + "0 B-MISC 0 800\n", + "1 I-LOC 1 328160\n", + "2 I-MISC 2 139600\n", + "3 I-ORG 3 237760\n", + "4 I-PER 4 977800\n", + "5 O 5 6661880\n", + "6 [nerCLS] 6 397080\n", + "7 [nerPAD] 7 39623065\n", + "8 [nerSEP] 8 397080\n", + "9 [nerX] 9 2063015\n", + "\n", + " tag cat occurences\n", + "0 AFRICAN-AMERICAN 0 436788\n", + "1 EUROPEAN 1 208467\n", + "2 [raceCLS] 2 397080\n", + "3 [racePAD] 3 39623065\n", + "4 [raceSEP] 4 397080\n", + "5 [raceX] 5 9763760\n", + "\n", + " tag cat occurences\n", + "0 FEMALE 0 367299\n", + "1 MALE 1 277956\n", + "2 [genderCLS] 2 397080\n", + "3 [genderPAD] 3 39623065\n", + "4 [genderSEP] 4 397080\n", + "5 [genderX] 5 9763760\n", + "\n" + ] + } + ], "source": [ "import data_generator;reload(data_generator)\n", "\n", - "#Start session\n", - "max_length = 128\n", - "\n", "train_data, val_data, test_data = data_generator.GetData(max_length)" ] }, @@ -37,25 +80,31 @@ "import tensorflow as tf\n", "tf.logging.set_verbosity(tf.logging.ERROR)\n", "import model_utils; reload(model_utils)\n", - "\n", - "adam_customized = tf.keras.optimizers.Adam(lr=0.001, beta_1=0.91, beta_2=0.999, epsilon=None, decay=0.1, amsgrad=False)\n", " \n", "config = tf.ConfigProto()\n", "config.gpu_options.allow_growth = True\n", "sess = tf.Session(config=config)\n", "\n", - "model = model_utils.NER()\n", + "model = model_utils.NER(max_length)\n", " \n", - "model.generate(max_length, train_layers=4, optimizer = adam_customized, debias=False, debiasWeight=0.95)\n", - "\n", - "# Instantiate variables\n", - "model_utils.initialize_vars(sess)\n", - "\n", + "model.generate(bert_train_layers=4)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "scrolled": true + }, + "outputs": [], + "source": [ "model.fit(\n", - " train_data, \n", + " sess,\n", + " train_data,\n", " val_data,\n", " epochs=8,\n", - " batch_size=32\n", + " batch_size=32,\n", + " debias=True\n", ")" ] }, @@ -65,7 +114,7 @@ "metadata": {}, "outputs": [], "source": [ - "cm = model.score(test_data)" + "bias = model.getBiasedPValues(test_data, num_iterations=1000)" ] }, { @@ -74,7 +123,7 @@ "metadata": {}, "outputs": [], "source": [ - "bias = model.getBiasedPValues(test_data, num_iterations=10000)" + "bias" ] }, { @@ -83,9 +132,16 @@ "metadata": {}, "outputs": [], "source": [ - "bias" + "sess.close()" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, { "cell_type": "code", "execution_count": null, diff --git a/src/__pycache__/data_generator.cpython-37.pyc b/src/__pycache__/data_generator.cpython-37.pyc index e1d9a16..726f040 100644 Binary files a/src/__pycache__/data_generator.cpython-37.pyc and b/src/__pycache__/data_generator.cpython-37.pyc differ diff --git a/src/__pycache__/model_utils.cpython-37.pyc b/src/__pycache__/model_utils.cpython-37.pyc index 
2295405..3c07717 100644 Binary files a/src/__pycache__/model_utils.cpython-37.pyc and b/src/__pycache__/model_utils.cpython-37.pyc differ diff --git a/src/__pycache__/token_generator.cpython-37.pyc b/src/__pycache__/token_generator.cpython-37.pyc index 8b01d7f..f89b369 100644 Binary files a/src/__pycache__/token_generator.cpython-37.pyc and b/src/__pycache__/token_generator.cpython-37.pyc differ diff --git a/src/model_utils.py b/src/model_utils.py index e19e609..3fc036f 100644 --- a/src/model_utils.py +++ b/src/model_utils.py @@ -11,7 +11,7 @@ from scipy import spatial import matplotlib.pyplot as plt -from tqdm import tqdm +from tqdm.auto import tqdm def initialize_vars(sess): sess.run(tf.local_variables_initializer()) @@ -36,8 +36,7 @@ def custom_acc_orig_tokens(y_true, y_pred): mask = (y_label < 6) y_label_masked = tf.boolean_mask(y_label, mask) - y_predicted = tf.math.argmax(input = tf.reshape(tf.layers.Flatten()(tf.cast(y_pred, tf.float64)),\ - [-1, 10]), axis=1) + y_predicted = tf.math.argmax(input = tf.reshape(tf.layers.Flatten()(tf.cast(y_pred, tf.float64)), [-1, 10]), axis=1) y_predicted_masked = tf.boolean_mask(y_predicted, mask) @@ -67,6 +66,29 @@ def custom_acc_orig_non_other_tokens(y_true, y_pred): return tf.reduce_mean(tf.cast(tf.equal(y_predicted_masked,y_label_masked) , dtype=tf.float64)) +def custom_acc_protected(y_true, y_pred): + """ + calculate loss dfunction filtering out also the newly inserted labels + + y_true: Shape: (batch x (max_length) ) + y_pred: predictions. Shape: (batch x x (max_length + 1) x num_distinct_ner_tokens ) + + returns: accuracy + """ + + #get labels and predictions + + y_label = tf.reshape(tf.layers.Flatten()(tf.cast(y_true, tf.int64)),[-1]) + + mask = (y_label < 2) + y_label_masked = tf.boolean_mask(y_label, mask) + + y_predicted = tf.math.argmax(input = tf.reshape(tf.layers.Flatten()(tf.cast(y_pred, tf.float64)), [-1, 6]), axis=1) + + y_predicted_masked = tf.boolean_mask(y_predicted, mask) + + return tf.reduce_mean(tf.cast(tf.equal(y_predicted_masked,y_label_masked) , dtype=tf.float64)) + def custom_loss(y_true, y_pred): """ calculate loss function explicitly, filtering out 'extra inserted labels' @@ -200,92 +222,181 @@ def call(self, inputs): return result + def compute_output_shape(self, input_shape): - return (input_shape[0], self.output_size) + return (None, 128, 768) + class NER(): - def __init__(self, filename=None): + def __init__(self, max_input_length, filename=None): + + self.max_input_length = max_input_length + if filename: filename = "../models/"+filename self.model = tf.keras.load(filename) - def generate(self, max_input_length, train_layers, optimizer, debias, debiasWeight=0.5): + def generate(self, bert_train_layers): - in_id = tf.keras.layers.Input(shape=(max_input_length,), name="input_ids") - in_mask = tf.keras.layers.Input(shape=(max_input_length,), name="input_masks") - in_segment = tf.keras.layers.Input(shape=(max_input_length,), name="segment_ids") - - bert_inputs = [in_id, in_mask, in_segment] + in_id = tf.keras.layers.Input(shape=(self.max_input_length,), name="input_ids") + in_mask = tf.keras.layers.Input(shape=(self.max_input_length,), name="input_masks") + in_segment = tf.keras.layers.Input(shape=(self.max_input_length,), name="segment_ids") + in_nerLabels = tf.keras.layers.Input(shape=(self.max_input_length, 10), name="ner_labels_true") + + bert_sequence = BertLayer(n_fine_tune_layers=bert_train_layers)([in_id, in_mask, in_segment]) + + dense = tf.keras.layers.Dense(256, activation='relu', 
name='pred_dense')(bert_sequence) + + dense = tf.keras.layers.Dropout(rate=0.1)(dense) + + pred = tf.keras.layers.Dense(10, activation='softmax', name='ner')(dense) + + reshape = tf.keras.layers.Reshape((self.max_input_length, 10))(pred) + + concatenate = tf.keras.layers.Concatenate(axis=-1)([in_nerLabels, reshape]) - bert_sequence = BertLayer(n_fine_tune_layers=train_layers)(bert_inputs) - - dense = tf.keras.layers.Dense(256, activation='relu', name='dense')(bert_sequence) + genderPred = tf.keras.layers.Dense(6, activation='softmax', name='gender')(concatenate) + + racePred = tf.keras.layers.Dense(6, activation='softmax', name='race')(concatenate) - dense = tf.keras.layers.Dropout(rate=0.1)(dense) + self.model = tf.keras.models.Model(inputs=[in_id, in_mask, in_segment, in_nerLabels], outputs={ + "ner": pred, + "race": racePred, + "gender": genderPred + }) - pred = tf.keras.layers.Dense(10, activation='softmax', name='ner')(dense) + self.model.summary() - if(debias): + def fit(self, sess, train_data, val_data, epochs, batch_size, debias, + gender_loss_weight = 0.1, race_loss_weight = 0.1, pred_learning_rate = 2**-16, protect_learning_rate = 2**-16): - genderPred = tf.keras.layers.Dense(6, activation='softmax', name='gender')(pred) + num_train_samples = len(train_data["nerLabels"]) - racePred = tf.keras.layers.Dense(6, activation='softmax', name='race')(pred) + ids_ph = tf.placeholder(tf.float32, shape=[batch_size, self.max_input_length]) + masks_ph = tf.placeholder(tf.float32, shape=[batch_size, self.max_input_length]) + sentenceIds_ph = tf.placeholder(tf.float32, shape=[batch_size, self.max_input_length]) - losses = { - "ner": custom_loss, - "race": custom_loss_protected, - "gender": custom_loss_protected - } + ner_ph = tf.placeholder(tf.float32, shape=[batch_size, self.max_input_length]) + gender_ph = tf.placeholder(tf.float32, shape=[batch_size, self.max_input_length]) + race_ph = tf.placeholder(tf.float32, shape=[batch_size, self.max_input_length]) + ner_onehot_ph = tf.placeholder(tf.float32, shape=[batch_size, self.max_input_length, 10]) - lossWeights = { - "ner": 1.0-debiasWeight, - "race": debiasWeight/2.0, - "gender": debiasWeight/2.0 - } + global_step = tf.Variable(0, trainable=False) + starter_learning_rate = 0.001 + learning_rate = tf.train.exponential_decay(starter_learning_rate, global_step, 1000, 0.96, staircase=True) - self.model = tf.keras.models.Model(inputs=bert_inputs, outputs={ - "ner": pred, - "race": racePred, - "gender": genderPred - }) + gender_vars = [var for var in tf.trainable_variables() if 'gender' in var.name] + race_vars = [var for var in tf.trainable_variables() if 'race' in var.name] + ner_vars = self.model.layers[3]._trainable_weights + [var for var in tf.trainable_variables() if any(x in var.name for x in ["pred_dense","ner"])] - self.model.compile( - loss=losses, - loss_weights=lossWeights, - optimizer=optimizer, - metrics={"ner": [custom_acc_orig_tokens,custom_acc_orig_non_other_tokens]}) + y_pred = self.model([ids_ph, masks_ph, sentenceIds_ph, ner_onehot_ph], training=True) - else: - - self.model = tf.keras.models.Model(inputs=bert_inputs, outputs=pred) + ner_loss = custom_loss(ner_ph, y_pred["ner"]) + gender_loss = custom_loss_protected(gender_ph, y_pred["gender"]) + race_loss = custom_loss_protected(race_ph, y_pred["race"]) + + ner_opt = tf.train.AdamOptimizer(pred_learning_rate) + gender_opt = tf.train.AdamOptimizer(protect_learning_rate) + race_opt = tf.train.AdamOptimizer(protect_learning_rate) + + gender_grads = {var: grad for (grad, var) in 
ner_opt.compute_gradients( + gender_loss, + var_list=ner_vars + )} + + race_grads = {var: grad for (grad, var) in ner_opt.compute_gradients( + race_loss, + var_list=ner_vars + )} + + ner_grads = [] + + tf_normalize = lambda x: x / (tf.norm(x) + np.finfo(np.float32).tiny) + + for (grad, var) in ner_opt.compute_gradients(ner_loss, var_list=ner_vars): + + if debias: + + gender_unit_protect = tf_normalize(gender_grads[var]) + race_unit_protect = tf_normalize(race_grads[var]) + + grad -= tf.reduce_sum(grad * gender_unit_protect) * gender_unit_protect + grad -= tf.math.scalar_mul(gender_loss_weight, gender_grads[var]) + + grad -= tf.reduce_sum(grad * race_unit_protect) * race_unit_protect + grad -= tf.math.scalar_mul(race_loss_weight, race_grads[var]) + + ner_grads.append((grad, var)) + + ner_min = ner_opt.apply_gradients(ner_grads, global_step=global_step) + + gender_min = gender_opt.minimize(gender_loss, var_list=[gender_vars], global_step=global_step) + + race_min = race_opt.minimize(race_loss, var_list=[race_vars], global_step=global_step) + + initialize_vars(sess) + + epoch_pb = tqdm(range(1, epochs+1)) - self.model.compile(loss=custom_loss, optimizer=optimizer, metrics=[custom_acc_orig_tokens, - custom_acc_orig_non_other_tokens]) + for epoch in epoch_pb: + + epoch_pb.set_description("Epoch %s" % epoch) + + shuffled_ids = np.random.choice(num_train_samples, num_train_samples) + + run_pb = tqdm(range(num_train_samples//batch_size)) + + for i in run_pb: + + batch_ids = shuffled_ids[batch_size*i: batch_size*(i+1)] + + batch_feed_dict = {ids_ph: train_data["inputs"][0][batch_ids], + masks_ph: train_data["inputs"][1][batch_ids], + sentenceIds_ph: train_data["inputs"][2][batch_ids], + ner_onehot_ph: np.array([np.eye(10)[i.reshape(-1)] for i in train_data["nerLabels"][batch_ids]]), + gender_ph: train_data["genderLabels"][batch_ids], + race_ph: train_data["raceLabels"][batch_ids], + ner_ph: train_data["nerLabels"][batch_ids]} + + _, _, _, ner_loss_value, gender_loss_value, race_loss_value = sess.run([ + ner_min, + gender_min, + race_min, + ner_loss, + gender_loss, + race_loss + ], feed_dict=batch_feed_dict) + + run_pb.set_description("nl: %.2f; gl: %.2f; rl: %.2f" % \ + (ner_loss_value, gender_loss_value, race_loss_value)) + + inputs = val_data["inputs"] - self.model.summary() - - def fit(self, train_data, val_data, epochs, batch_size): - - self.model.fit( - train_data["inputs"], - { - "ner": train_data["nerLabels"], - "gender": train_data["genderLabels"], - "race": train_data["raceLabels"] - }, - validation_data=(val_data["inputs"], { - "ner": val_data["nerLabels"], - "gender": val_data["genderLabels"], - "race": val_data["raceLabels"] - }), - epochs=epochs, - batch_size=batch_size - ) + inputs.append(np.array([np.eye(10)[i.reshape(-1)] for i in train_data["nerLabels"]])) + + val_y_pred = self.model.predict(inputs, batch_size=32) + + ner_pred = val_y_pred[1] + ner_true = val_data["nerLabels"] + + acc_orig_tokens = custom_acc_orig_tokens(ner_true, ner_pred).eval(session=sess) + acc_orig_non_other_tokens = custom_acc_orig_non_other_tokens(ner_true, ner_pred).eval(session=sess) + + gender_pred = val_y_pred[0] + gender_true = val_data["genderLabels"] + + acc_gender = custom_acc_protected(gender_true, gender_pred).eval(session=sess) + + race_pred = val_y_pred[2] + race_true = val_data["raceLabels"] + + acc_race = custom_acc_protected(race_true, race_pred).eval(session=sess) + print("acc_ner: %.2f; acc_ner_non_other: %.2f; acc_gender: %.2f; acc_race: %.2f" % (acc_orig_tokens, acc_orig_non_other_tokens, 
acc_gender, acc_race)) def score(self, data, batch_size=32): - y_pred = self.model.predict(data["inputs"], batch_size=batch_size) + y_pred = self.model.predict(data["inputs"], batch_size=batch_size)[1] y_true = data["nerLabels"] @@ -361,7 +472,7 @@ def getCosineDistances(self, inputs, name_masks): return np.array(distances) - def getBiasedPValues(self, data, num_iterations=1000): + def getBiasedPValues(self, data, num_iterations=10000): distances = self.getCosineDistances(data["inputs"], data["nameMasks"])
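
The debiasing update implemented in both the notebook's fit() and model_utils.NER.fit above follows the adversarial-debiasing recipe: for every shared trainable variable, the main-task (NER) gradient first has its component along the adversary's gradient removed (the tf_normalize / tf.reduce_sum projection), and then a weighted copy of the adversary's gradient subtracted (protect_loss_weight, gender_loss_weight, race_loss_weight). A minimal NumPy-only sketch of that per-variable update, using hypothetical gradient values purely for illustration:

import numpy as np

def debias_gradient(pred_grad, protect_grad, protect_loss_weight=0.1):
    """Adjust the main-task gradient so the update does not help the adversary.

    pred_grad:    gradient of the NER loss w.r.t. one shared variable
    protect_grad: gradient of the protected-attribute (adversary) loss
                  w.r.t. the same variable
    """
    tiny = np.finfo(np.float32).tiny                      # same guard as tf_normalize
    unit_protect = protect_grad / (np.linalg.norm(protect_grad) + tiny)
    # remove the component of pred_grad that points along the adversary's gradient
    grad = pred_grad - np.sum(pred_grad * unit_protect) * unit_protect
    # push away from the adversary by subtracting a scaled copy of its gradient
    grad = grad - protect_loss_weight * protect_grad
    return grad

# hypothetical gradients for a single 3-element variable
g_ner = np.array([0.5, -0.2, 0.1], dtype=np.float32)
g_gender = np.array([0.3, 0.3, 0.0], dtype=np.float32)
print(debias_gradient(g_ner, g_gender))

The projection step keeps the NER update from moving in a direction that would also improve the adversary's gender/race predictions, while the second term actively pushes the shared weights toward being uninformative about the protected attribute; in the diff this same update is applied per (grad, var) pair returned by tf.train.AdamOptimizer.compute_gradients.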
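
The AttributeError recorded in the last executed notebook cell ("'numpy.float32' object has no attribute 'name'") arises because the raw loss values returned by sess.run are passed to print_status_bar, which formats tf.keras metric objects via .name and .result(). One possible adjustment that keeps the same call sites is to accept either kind of argument — a sketch under that assumption, not the notebook's actual fix:

def print_status_bar(iteration, total, loss, metrics=None):
    def fmt(m, default_name):
        # Keras metric objects expose .name/.result(); plain floats do not.
        if hasattr(m, "result"):
            return "{}: {:.4f}".format(m.name, float(m.result()))
        return "{}: {:.4f}".format(default_name, float(m))

    parts = [fmt(loss, "loss")] + [fmt(m, "metric") for m in (metrics or [])]
    end = "" if iteration < total else "\n"
    print("\r{}/{} - ".format(iteration, total) + " - ".join(parts), end=end)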