-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #1 from adamingas/development
Minor fixes to gradient and hessian calculation
- Loading branch information
Showing
7 changed files
with
216 additions
and
22 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,128 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"# Comparing ordinal with usual classification" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"In this notebook we use the sklearn diabetes dataset as a comparison between the LGBMOrdinal, LGBMClassifier, and Logistic regression models. We convert the continuous label to classes by binnging it using quantiles.\n", | ||
"\n", | ||
"We then train and test the models several times with different train/test splits and evaluate their mean absolute deviation instead of accuracy. This metric penalises wrong predictions that are further appart from the true label more than those which are closer." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 1, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"import numpy as np\n", | ||
"import pandas as pd\n", | ||
"from lightgbm import LGBMClassifier\n", | ||
"from sklearn.datasets import load_diabetes\n", | ||
"from sklearn.linear_model import LinearRegression, LogisticRegression\n", | ||
"from sklearn.model_selection import train_test_split\n", | ||
"\n", | ||
"from ordinalgbt.lgb import LGBMOrdinal\n" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 2, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"data = load_diabetes()\n", | ||
"X = pd.DataFrame(data[\"data\"], columns = data[\"feature_names\"])\n", | ||
"y = data[\"target\"]" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 3, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"nq = 10\n", | ||
"thresholds = np.append(np.append(y.min()-1,np.quantile(y,np.arange(0,1,1/nq)[1:])),y.max()+1)\n", | ||
"yq = pd.cut(x=y,bins=thresholds,right=True,labels=['q'+str(z+1) for z in range(nq)])\n", | ||
"yord = yq.astype('category').codes\n" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 4, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
" mdl MAE\n", | ||
"0 LGBMOrdinal 2.0\n", | ||
"1 SKlearn Multinomial 2.5\n", | ||
"2 LGBMClassifier 2.1\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"holder, coef = [], []\n", | ||
"nsim = 10\n", | ||
"for ii in range(nsim):\n", | ||
" # Do a train/test split (80/20)\n", | ||
" ytrain, ytest, Xtrain, Xtest = train_test_split(yord, X, stratify=yord,test_size=0.2,\n", | ||
" random_state=ii)\n", | ||
" # Ordinal model\n", | ||
" mdl_ord = LGBMOrdinal()\n", | ||
" mdl_ord.fit(Xtrain, ytrain)\n", | ||
" # Multinomial LGBM model\n", | ||
" mdl_class = LGBMClassifier()\n", | ||
" mdl_class.fit(Xtrain, ytrain)\n", | ||
" # Multinomial Regression model\n", | ||
" mdl_multi = LogisticRegression(penalty='l2',solver='lbfgs',max_iter=1000)\n", | ||
" mdl_multi.fit(Xtrain,ytrain)\n", | ||
" # Make predictions\n", | ||
" yhat_ord = mdl_ord.predict(Xtest)\n", | ||
" yhat_multi = mdl_multi.predict(Xtest)\n", | ||
" yhat_class = mdl_class.predict(Xtest)\n", | ||
" # Get MAE\n", | ||
" acc_class = np.abs(yhat_class - ytest).mean()\n", | ||
" acc_multi = np.abs(yhat_multi - ytest).mean()\n", | ||
" acc_ord = np.abs(yhat_ord - ytest).mean()\n", | ||
" holder.append(pd.DataFrame({'ord':acc_ord,'multi':acc_multi,'class':acc_class},index=[ii]))\n", | ||
"\n", | ||
"df_mae = pd.concat(holder).mean(axis=0).reset_index().rename(columns={'index':'mdl',0:'MAE'})\n", | ||
"di_lbls = {'ord':'LGBMOrdinal','multi':'SKlearn Multinomial','class':'LGBMClassifier'}\n", | ||
"df_mae = df_mae.assign(mdl=lambda x: x.mdl.map(di_lbls))\n", | ||
"print(np.round(df_mae,1))" | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "venv", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.8.17" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 2 | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters