# Decision Trees

Suppose we have a dataset of iris flowers with four features (sepal length, sepal width, petal length, and petal width) and a class label for each sample (Setosa, Versicolor, Virginica). The simplest decision tree is just a chain of if-else-then branches. For example, the iris data can be classified like this:
```python
import numpy as np

def predict(features):
    # Hand-written decision rules; features = [sepal length, sepal width,
    # petal length, petal width], in centimeters
    if features[2] <= 2.45:
        return 'setosa'
    elif features[3] <= 1.75:
        if features[2] <= 4.95:
            if features[3] <= 1.65:
                return 'versicolor'
            else:
                return 'virginica'
        else:
            if features[3] <= 1.55:
                return 'virginica'
            else:
                return 'versicolor'
    else:
        return 'virginica'

# Test samples
X_test = np.array([[5.1, 3.5, 1.4, 0.2],
                   [6.3, 2.9, 5.6, 1.8],
                   [4.9, 3.0, 1.4, 0.2]])

for i in range(len(X_test)):
    prediction = predict(X_test[i])
    print("Sample", i + 1, "prediction:", prediction)
```

```
Sample 1 prediction: setosa
Sample 2 prediction: virginica
Sample 3 prediction: setosa
```
## Decision Tree Algorithm Flow

Automating this kind of rule of thumb with mathematics and algorithms has produced a whole family of decision tree algorithms. They all follow roughly the same process:

* Feature selection: choose the best feature in the training set to split on at the current node. Some criterion (such as information gain, the Gini index, or the information gain ratio) is typically used to evaluate how informative each feature is.

![](../../images/dicision-tree/dicision-tree-label.png)

![](../../images/dicision-tree/dicision-tree-feature.png)

* Node splitting: partition the training data into subsets according to the chosen feature. For classification, each subset corresponds to one feature value or a range of values; for regression, the split is made at a threshold on the feature.

![](../../images/dicision-tree/dicision-tree.png)

* Recursive subtree construction: apply the steps above recursively to each subset to build the subtrees. Splitting stops when all samples in a subset belong to the same class (or have similar regression values).

The entropy walkthrough below uses the dataset shown in the figures above: 17 watermelons labeled good ("yes") or not ("no"), with the texture feature taking the values clear, slightly blurry, and blurry.
"第一层\n", | ||
"\n", | ||
"根节点:被分成17份,8是%2F9否,总体的信息熵为:\n", | ||
"\n", | ||
"$\n", | ||
"H_0 = - p(是) * log_2(p(是)) - p(否) * log_2(p(否))\n", | ||
" = - 0.471 * log2(0.471) - 0.529 * log2(0.529)\n", | ||
" ≈ 0.998\n", | ||
"$\n", | ||
"\n", | ||
"第二层\n", | ||
"\n", | ||
"清晰:被分成9份,7是/2否,它的信息熵为:\n", | ||
"\n", | ||
"$H_1 = - 7 / 9 * log2(7 / 9) - 2 / 9 * log2(2 / 9) = 0.764$\n", | ||
"\n", | ||
"稍糊:被分成5份,1是/4否,它的信息熵为:\n", | ||
"\n", | ||
"$H_2 = - 1 / 5 * log2(4 / 5) - 1 / 5 * log2(4 / 5) = 0.722$\n", | ||
"\n", | ||
"模糊:被分成3份,0是/3否,它的信息熵为:\n", | ||
"\n", | ||
"$H_3 = 0$\n", | ||
"\n", | ||
"假设我们选取纹理为分类依据,把它作为根节点,那么第二层的加权信息熵可以定义为:\n", | ||
"\n", | ||
"$H’ = 9/17 * H_1 + 5/17 * H_2 + 3/17 * H_3 $\n", | ||
"\n", | ||
"因为$H’< H$,也就是随着决策的进行,其不确定度要减小才行,决策肯定是一个由不确定到确定状态的转变。\n", | ||
"\n", | ||
"\n", | ||
"* 剪枝:对生成的决策树进行剪枝操作,以减小过拟合风险。剪枝方法可以是预剪枝(在构建树时提前停止划分)或后剪枝(在完整构建树之后剪掉部分叶节点)。\n", | ||
"\n", | ||
"* 终止条件:根据停止条件,确定是否继续构建子树。常见的停止条件包括达到最大深度、样本数量不足或没有更多特征可用。\n", | ||
"\n", | ||
"* 输出决策树:得到最终的决策树模型,可以将其用于预测新的输入数据。\n", | ||
"\n", | ||
"对于集成学习算法(如随机森林和GBDT),会有一些额外步骤:\n", | ||
"\n", | ||
"* 集成学习:对多个决策树进行集成。对于随机森林,每个决策树通过自助采样从原始训练数据集中获得;对于GBDT,每个决策树都是基于前一棵树的残差进行训练。\n", | ||
"\n", | ||
"* 预测结果:对于分类问题,通过投票或多数表决来确定最终的类别;对于回归问题,则取平均或加权平均作为最终的预测值。\n" | ||
] | ||
}, | ||
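The whole pipeline above is what off-the-shelf libraries automate. As a minimal sketch (assuming scikit-learn, which the original notebook does not use), fitting a tree to the iris data and printing its learned rules recovers thresholds close to the hand-written `predict()` above:

```python
# A sketch using scikit-learn (assumed available; not part of the
# original notebook) to learn the iris tree automatically.
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, export_text

iris = load_iris()
clf = DecisionTreeClassifier(criterion="entropy", max_depth=3, random_state=0)
clf.fit(iris.data, iris.target)

# Print the learned if-else rules; thresholds such as petal length <= 2.45
# should closely match the hand-written tree.
print(export_text(clf, feature_names=iris.feature_names))
```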
## Common Decision Tree Algorithms

Information entropy is a fundamental concept of information theory: it describes the uncertainty over the possible events of an information source. In the 1940s, C. E. Shannon borrowed the concept from thermodynamics, named the average information content of a message with redundancy removed "information entropy", and gave a mathematical expression for computing it. Information entropy solved the problem of quantitatively measuring information [1].

Let S be the training set of a classification problem with C classes. For each class c, let p(c) be the probability that a sample belongs to class c. The information entropy of S is defined as:

$
H = \text{Entropy}(S) = -\sum_{c=1}^{C} p(c) \log_{2} p(c)
$

The logarithm can be taken in any base (base 2 is the usual choice), and p(c) denotes the probability that a sample belongs to class c.

Here is a simple worked example:

$S = \{ A, A, A, B, B, C \}$

$
p(A) = 3/6 = 0.5 \\
p(B) = 2/6 \approx 0.333 \\
p(C) = 1/6 \approx 0.167
$

$
H = \text{Entropy}(S) \\
= -(0.5 \log_2 0.5 + 0.333 \log_2 0.333 + 0.167 \log_2 0.167) \\
\approx -(0.5 \times (-1) + 0.333 \times (-1.585) + 0.167 \times (-2.585)) \\
\approx 1.459
$

So the entropy of this dataset is about 1.459.

By computing information entropy we can measure how uncertain a dataset's samples are. In a decision tree algorithm we want to split on the feature that yields the lowest entropy (or, equivalently, the largest information gain), so that the subsets become less uncertain and the tree's predictive power improves.

Higher entropy = more disorder = less purity = more uncertainty.
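The worked example can be verified in a few lines (a small added sketch, not from the original notebook):

```python
from collections import Counter
from math import log2

S = ["A", "A", "A", "B", "B", "C"]
counts = Counter(S)

# H = -sum_c p(c) * log2(p(c)) over the empirical class frequencies
H = -sum((n / len(S)) * log2(n / len(S)) for n in counts.values())
print(round(H, 3))  # 1.459
```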
### ID3 (Iterative Dichotomiser 3)

ID3 builds the tree using information gain as the feature selection criterion. It suits discrete features and multi-class problems.

$Gain(X, Y) = H(Y) - H(Y|X)$

For example, choosing texture as the root node in the walkthrough above splits the root three ways, giving:

$Gain(\text{texture}) = H_0 - H' = 0.998 - 0.617 = 0.381$

That is, before the texture feature is used, the entropy of "is this a good melon?" is 0.998; after splitting on texture, the weighted entropy drops to 0.617. The entropy has fallen by 0.381, so the information gain is 0.381.

If a feature has a large information gain, splitting on it provides a large amount of information, so the feature is considered relatively important.

### C4.5

C4.5 is an improved version of ID3 that uses the information gain ratio as its selection criterion. It can handle missing values and is more robust; normalizing the gain by the feature's split information counteracts ID3's bias toward features with many distinct values.
"### CART(Classification and Regression Trees)\n", | ||
"\n", | ||
"通用的决策树算法,可以处理分类和回归问题。使用基尼系数作为特征选择准则,在每个节点上生成二叉树结构。\n", | ||
"\n", | ||
"### CHAID(Chi-squared Automatic Interaction Detection)\n", | ||
"\n", | ||
"一种基于卡方检验的决策树算法,适用于分类问题。能够处理离散型和连续型特征,并支持多类别问题。\n", | ||
"\n", | ||
"### MARS(Multivariate Adaptive Regression Splines)\n", | ||
"\n", | ||
"基于样条函数的非参数回归方法,通过构建多个分段线性的子模型构建决策树。适用于回归和分类任务。\n", | ||
"\n", | ||
"### Random Forest(随机森林)\n", | ||
"一种集成学习算法,基于决策树构建多个决策树,并通过投票或平均预测结果来做出最终的分类或回归决策。具有鲁棒性和泛化能力。\n", | ||
"\n", | ||
"### GBDT(Gradient Boosting Decision Trees)\n", | ||
"一种梯度提升决策树算法,通过连续训练多个决策树来提高预测性能。每棵树都是基于前一棵树的残差进行训练。\n", | ||
"\n", | ||
"### XGBoost(eXtreme Gradient Boosting)\n", | ||
"一种梯度提升决策树算法,结合了梯度提升和正则化技术,具有较高的准确性和泛化能力。" | ||
] | ||
}, | ||
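To illustrate the two ensemble styles end to end, here is a minimal sketch (again assuming scikit-learn, which the original notebook does not use):

```python
# A sketch comparing bagging (random forest) and boosting (GBDT) on iris,
# assuming scikit-learn is available; not part of the original notebook.
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier
from sklearn.model_selection import cross_val_score

X, y = load_iris(return_X_y=True)

# Bagging: each tree sees a bootstrap sample; predictions combined by vote
rf = RandomForestClassifier(n_estimators=100, random_state=0)

# Boosting: each tree is fitted to the errors left by the previous trees
gbdt = GradientBoostingClassifier(n_estimators=100, random_state=0)

for name, model in [("random forest", rf), ("GBDT", gbdt)]:
    scores = cross_val_score(model, X, y, cv=5)
    print(f"{name}: mean accuracy = {scores.mean():.3f}")
```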
[1] 信息熵 (information entropy), https://baike.baidu.com/item/%E4%BF%A1%E6%81%AF%E7%86%B5/7302318