Skip to content

Commit

Permalink
feat: 贝叶斯回归
Browse files Browse the repository at this point in the history
  • Loading branch information
tianxuzhang committed Dec 6, 2023
1 parent db84e2d commit 9f962f9
Showing 1 changed file with 70 additions and 6 deletions.
76 changes: 70 additions & 6 deletions docs/回归/贝叶斯回归.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,11 @@
"source": [
"\n",
"## 贝叶斯回归\n",
"贝叶斯回归是一种基于贝叶斯统计推断的回归方法。它通过引入先验分布来表达对参数的不确定性,并利用观测数据来更新参数的后验分布。假设我们有一个训练集包含N个样本,每个样本由输入特征X和对应的输出标签y组成。具体步骤如下:\n",
"贝叶斯回归是一种基于贝叶斯统计推断的回归方法。它通过引入先验分布来表达对参数的不确定性,并利用观测数据来更新参数的后验分布。假设我们有一个训练集包含$N$个样本,每个样本由输入特征$X$和对应的输出标签$y$组成。具体步骤如下:\n",
"\n",
"* 参数建模,定义先验分布:选择适当的先验分布来表示参数的先验知识或假设。\n",
"\n",
" * 建立参数 $w$ 的先验分布:$p(w)$。通常选择高斯分布作为w的先验,即 $p(w) = N(w|0, Σ0)$,其中 $0$ 是均值向量,$Σ0$ 是协方差矩阵。\n",
" * 建立参数 $w$ 的先验分布:$p(w)$。通常选择高斯分布作为$w$的先验,即 $p(w) = N(w|0, Σ0)$,其中 $0$ 是均值向量,$Σ0$ 是协方差矩阵。\n",
" * 建立输出标签y的条件分布:$p(y|X, w)$。通常假设 $y$ 服从高斯分布,即 $p(y|X, w) = N(y|Xw, σ^2I)$,其中 $σ^2$ 是噪声方差,$I$ 是单位矩阵。\n",
"\n",
"* 后验推断,计算后验分布:根据观测数据和先验分布,使用贝叶斯定理计算参数的后验分布。\n",
Expand All @@ -20,8 +20,8 @@
"\n",
"* 参数估计和预测,推断和预测:利用后验分布进行参数估计和预测。可以使用后验分布的均值、中位数等作为点估计,还可以计算预测分布来预测新数据。\n",
"\n",
" * 参数估计:根据后验分布,可以获得参数w的点估计,如后验均值或最大后验估计(MAP)。\n",
" * 预测:通过获取参数w的后验分布,可以计算新数据点的预测分布,即 p(y*|x*, X, y) = ∫ p(y*|x*, w) * p(w|X, y) dw。这里,y表示预测的输出标签,x表示新的输入特征\n",
" * 参数估计:根据后验分布,可以获得参数$w$的点估计,如后验均值或最大后验估计(MAP)。\n",
" * 预测:通过获取参数$w$的后验分布,可以计算新数据点的预测分布,即 $p(y*|x*, X, y) = ∫ p(y*|x*, w) * p(w|X, y) dw$。这里,$y$表示预测的输出标签,$x$表示新的输入特征\n",
"\n",
"贝叶斯回归提供了全面的概率建模方式,能够量化参数的不确定性,并灵活地引入先验知识。它对小样本、高噪声数据以及需要考虑模型不确定性的情况特别有帮助。\n",
"\n",
Expand All @@ -32,10 +32,74 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 2,
"id": "e84cfb5a",
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Auto-assigning NUTS sampler...\n",
"Initializing NUTS using jitter+adapt_diag...\n",
"Multiprocess sampling (4 chains in 4 jobs)\n",
"NUTS: [sigma, w]\n",
"Sampling 4 chains: 100%|██████████| 8000/8000 [00:02<00:00, 3036.46draws/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"参数估计结果:\n",
"w__0 2.890521\n",
"w__1 4.865710\n",
"sigma 0.945172\n",
"Name: mean, dtype: float64\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 1000/1000 [00:01<00:00, 503.68it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"预测结果:\n",
"[ 1.03153314 -3.61562589 2.16812673 -3.11588044 9.60127944\n",
" 12.16905026 -0.23648124 4.25748233 -4.45208439 -7.32287496\n",
" 3.27508941 6.5172929 -0.836755 2.46292051 -6.08647484\n",
" 0.58920231 -3.83147137 6.42264627 -3.61320578 1.41008255\n",
" 6.48528022 0.24595292 -0.12118109 0.07075903 -8.62226662\n",
" 11.63699736 3.94074076 -0.54817445 1.58446021 0.80333973\n",
" -8.5854543 -2.80430425 -5.25514791 -4.58258886 -14.3224442\n",
" -1.84326201 4.46406574 6.46984307 0.21513803 -14.46026694\n",
" -2.29835858 0.19768268 2.16854864 -4.15003522 9.75786452\n",
" -3.88164949 -7.4700798 3.36505357 4.478443 1.69431611\n",
" -2.70376263 7.05479675 1.63607418 -4.06979375 -4.62857722\n",
" -2.06437145 5.40670935 3.61310883 3.05419556 -1.92042677\n",
" -3.81055859 -0.7341311 -2.46628125 0.46042887 6.3505802\n",
" -5.52622373 -3.90441137 -4.68569514 2.0306271 0.6378601\n",
" 2.38703032 -2.07502364 4.05460138 6.58315489 -0.94999262\n",
" 3.896083 -12.56547747 -0.50816651 -1.16460501 10.06830438\n",
" -0.7681246 -1.35430004 -13.34554212 -5.37841587 -1.85885039\n",
" -8.55399374 -1.31399772 0.86910898 -1.72386609 1.79894702\n",
" -8.58868727 4.80504887 9.87574364 -5.17028682 6.4685171\n",
" 1.12598447 8.74019895 0.1646509 -5.30968101 8.1008152 ]\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n"
]
}
],
"source": [
"import pymc3 as pm\n",
"import numpy as np\n",
Expand Down

0 comments on commit 9f962f9

Please sign in to comment.