feat: multi-task Lasso regression
tianxuzhang committed Dec 3, 2023
1 parent 12fd94c commit b6a05a8
Showing 2 changed files with 91 additions and 7 deletions.
70 changes: 70 additions & 0 deletions docs/回归/正则化线性回归.ipynb
@@ -147,6 +147,76 @@
"print(best_alpha_lassolars)\n"
]
},
{
"cell_type": "markdown",
"id": "8c102d8d",
"metadata": {},
"source": [
"### 多任务lasso回归\n",
"\n",
"多任务Lasso(Multi-Task Lasso)是一种用于多任务回归的正则化方法。它是对Lasso回归的扩展,可以同时处理多个相关联的目标变量。\n",
"\n",
"在多任务Lasso中,目标函数的表达式如下:\n",
"\n",
"$\\underset{W}{\\text{minimize }} \\frac{1}{2n_{samples}} ||X W - Y||F ^ 2 + \\alpha ||W||{21}$\n",
"\n",
"其中,$X$ 是输入特征矩阵,$Y$ 是多个目标变量的观测值矩阵,**$W$** 是模型的系数矩阵,(\\alpha) 是正则化参数。\n",
"\n",
"* 第一项 $\\frac{1}{2n_{samples}} ||X W - Y||F ^ 2$ 衡量了模型预测值与真实观测值之间的误差。这里使用了Frobenius范数来计算误差的平方和,同时将其除以 (2n{samples}) 进行归一化,其中 (n_{samples}) 是样本数量。\n",
"\n",
"* 第二项 $\\alpha ||W||_{21}$ 是指系数矩阵 $W$ 的 $L2,1$ 范数,也称为分组Lasso范数。$L2,1$ 范数将每个任务之间的系数向量视为一个分组,并对每个分组的L2范数进行惩罚。这鼓励模型选择相同或相似的特征子集来适应多个相关任务。\n",
"\n",
"通过调整正则化参数 $\\alpha$ 的大小,可以控制模型的拟合程度和稀疏性。\n",
"\n",
"多任务Lasso适用于在多个相关任务上进行回归建模,并且鼓励共享特征选择。它可以通过联合优化多个相关任务来提高模型的性能,并且可以识别出对所有任务都有影响的重要特征。\n",
"\n",
"在Scikit-learn中,可以使用MultiTaskLasso类来实现多任务Lasso回归,并进行参数估计和模型拟合。"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "2bb9e85b",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"模型系数:\n",
"[[ 0.02731864 0.08910204 0.03778079 0.00933277 -0.01963377 -0.00207699\n",
" 0.02826692 0.00827997 -0.12920136 -0.04683641]\n",
" [-0.06873417 -0.0625006 0.02790663 -0.00367414 0.11322009 0.00652524\n",
" -0.00146783 -0.05008547 -0.02371156 0.03941871]\n",
" [-0.02352807 -0.08121243 0.06746598 0.02794926 -0.04672092 -0.03314402\n",
" 0.02533989 -0.03688576 0.08089949 0.032083 ]\n",
" [ 0.029961 -0.03568944 0.07128624 -0.01492458 0.0525043 0.09876974\n",
" 0.10219089 -0.01597094 -0.02868559 0.05300582]\n",
" [-0.03094184 -0.12592359 -0.17639916 0.00561484 -0.01498715 -0.06237257\n",
" -0.00310203 0.01522918 0.07753082 0.02590595]]\n"
]
}
],
"source": [
"from sklearn.linear_model import MultiTaskLasso\n",
"import numpy as np\n",
"\n",
"# 生成样本数据\n",
"np.random.seed(0)\n",
"n_samples, n_features = 100, 10\n",
"X = np.random.randn(n_samples, n_features)\n",
"Y = np.random.randn(n_samples, 5) # 生成5个相关联的目标变量\n",
"\n",
"# 创建MultiTaskLasso对象并进行拟合\n",
"alpha = 0.1 # 正则化参数\n",
"multi_task_lasso = MultiTaskLasso(alpha=alpha)\n",
"multi_task_lasso.fit(X, Y)\n",
"\n",
"# 输出模型系数\n",
"print(\"模型系数:\")\n",
"print(multi_task_lasso.coef_)"
]
},
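The cell above fits MultiTaskLasso to purely random targets, so the shared-sparsity effect the text describes is hard to see. As a supplementary sketch (not part of the original notebook; the seed, `alpha=0.5`, and noise level are illustrative assumptions), the snippet below generates five targets that all depend on the same three features and checks which features the fitted model keeps:

```python
from sklearn.linear_model import MultiTaskLasso
import numpy as np

# Illustrative sketch: all 5 tasks depend on the same 3 features,
# so multi-task Lasso should keep those features and drop the rest.
rng = np.random.RandomState(42)
n_samples, n_features, n_tasks = 100, 10, 5
X = rng.randn(n_samples, n_features)

W_true = np.zeros((n_features, n_tasks))
W_true[:3, :] = rng.randn(3, n_tasks)        # only the first 3 features carry signal
Y = X @ W_true + 0.1 * rng.randn(n_samples, n_tasks)

model = MultiTaskLasso(alpha=0.5).fit(X, Y)

# coef_ has shape (n_tasks, n_features); a feature counts as selected
# if any task assigns it a non-zero coefficient.
selected = np.flatnonzero(np.any(model.coef_ != 0, axis=0))
print("Selected feature indices:", selected)  # typically [0 1 2]
```

Because the $L_{2,1}$ penalty acts on whole feature groups, the discarded features are typically zero for every task at once, which is the joint feature selection described above.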
{
"cell_type": "markdown",
"id": "12cc34bc",
28 changes: 21 additions & 7 deletions docs/回归/距离和范数.ipynb
@@ -250,25 +250,39 @@
"\n",
"以下是几种常见的范数及其公式:\n",
"\n",
"L1范数(曼哈顿距离或绝对值范数):\n",
"* L1范数(曼哈顿距离或绝对值范数):\n",
"\n",
"对于一个n维向量x,L1范数表示为:\n",
"$$|x|_1 = \\sum_{i=1}^{n} |x_i|$$\n",
"\n",
"L2范数(欧几里得范数):\n",
"* L2范数(欧几里得范数):\n",
"\n",
"对于一个n维向量x,L2范数表示为:\n",
"$$|x|_2 = \\sqrt{\\sum_{i=1}^{n} |x_i|^2}$$\n",
"\n",
"Lp范数(p范数):\n",
"* Lp范数(p范数):\n",
"\n",
"对于一个n维向量x,Lp范数表示为:\n",
"$$|x|_p = \\left(\\sum_{i=1}^{n} |x_i|^p\\right)^{\\frac{1}{p}}$$\n",
"\n",
"无穷范数(最大值范数):\n",
"* 无穷范数(最大值范数):\n",
"\n",
"对于一个n维向量x,无穷范数表示为:\n",
"$$|x|_{\\infty} = \\max(|x_1|, |x_2|, …, |x_n|)$$\n",
"\n",
"矩阵Frobenius范数:\n",
"* 矩阵Frobenius范数:\n",
"\n",
"弗罗贝尼乌斯\n",
"\n",
"对于一个m×n矩阵A,Frobenius范数表示为:\n",
"$$|A|_F = \\sqrt{\\sum_{i=1}^{m}\\sum_{j=1}^{n} |a_{ij}|^2}$$"
"$$|A|_F = \\sqrt{\\sum_{i=1}^{m}\\sum_{j=1}^{n} |a_{ij}|^2}$$\n",
"\n",
"\n",
"* L2,1范数(多任务学习的正则化项):\n",
"\n",
"对于一个m×n的系数矩阵 $A$,L2,1范数定义为各个任务之间的L2范数的和。其表达式为:\n",
"\n",
"$$ |A|{2,1} = \\sum{j=1}^{m} \\sqrt{\\sum_{i=1}^{n} w_{ij}^2} $$"
]
},
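As a small supplementary sketch (not in the original notebook; the vector and matrix values are arbitrary examples), each of the norms listed above can be checked numerically with NumPy's `np.linalg.norm`:

```python
import numpy as np

x = np.array([3.0, -4.0, 1.0])
A = np.array([[1.0, -2.0], [3.0, 4.0]])

print(np.linalg.norm(x, 1))             # L1 norm: sum of absolute values
print(np.linalg.norm(x, 2))             # L2 norm: Euclidean length
print(np.linalg.norm(x, 3))             # Lp norm with p = 3
print(np.linalg.norm(x, np.inf))        # infinity norm: largest absolute value
print(np.linalg.norm(A, "fro"))         # Frobenius norm of the matrix
print(np.linalg.norm(A, axis=1).sum())  # L2,1 norm: sum of row-wise L2 norms
```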
{
@@ -309,7 +323,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.12"
"version": "3.11.5"
}
},
"nbformat": 4,
