diff --git "a/docs/\345\233\236\345\275\222/\350\264\235\345\217\266\346\226\257\345\233\236\345\275\222.ipynb" "b/docs/\345\233\236\345\275\222/\350\264\235\345\217\266\346\226\257\345\233\236\345\275\222.ipynb" index 3f8c02d..d648b47 100644 --- "a/docs/\345\233\236\345\275\222/\350\264\235\345\217\266\346\226\257\345\233\236\345\275\222.ipynb" +++ "b/docs/\345\233\236\345\275\222/\350\264\235\345\217\266\346\226\257\345\233\236\345\275\222.ipynb" @@ -7,11 +7,11 @@ "source": [ "\n", "## 贝叶斯回归\n", - "贝叶斯回归是一种基于贝叶斯统计推断的回归方法。它通过引入先验分布来表达对参数的不确定性,并利用观测数据来更新参数的后验分布。假设我们有一个训练集包含N个样本,每个样本由输入特征X和对应的输出标签y组成。具体步骤如下:\n", + "贝叶斯回归是一种基于贝叶斯统计推断的回归方法。它通过引入先验分布来表达对参数的不确定性,并利用观测数据来更新参数的后验分布。假设我们有一个训练集包含$N$个样本,每个样本由输入特征$X$和对应的输出标签$y$组成。具体步骤如下:\n", "\n", "* 参数建模,定义先验分布:选择适当的先验分布来表示参数的先验知识或假设。\n", "\n", - " * 建立参数 $w$ 的先验分布:$p(w)$。通常选择高斯分布作为w的先验,即 $p(w) = N(w|0, Σ0)$,其中 $0$ 是均值向量,$Σ0$ 是协方差矩阵。\n", + " * 建立参数 $w$ 的先验分布:$p(w)$。通常选择高斯分布作为$w$的先验,即 $p(w) = N(w|0, Σ0)$,其中 $0$ 是均值向量,$Σ0$ 是协方差矩阵。\n", " * 建立输出标签y的条件分布:$p(y|X, w)$。通常假设 $y$ 服从高斯分布,即 $p(y|X, w) = N(y|Xw, σ^2I)$,其中 $σ^2$ 是噪声方差,$I$ 是单位矩阵。\n", "\n", "* 后验推断,计算后验分布:根据观测数据和先验分布,使用贝叶斯定理计算参数的后验分布。\n", @@ -20,8 +20,8 @@ "\n", "* 参数估计和预测,推断和预测:利用后验分布进行参数估计和预测。可以使用后验分布的均值、中位数等作为点估计,还可以计算预测分布来预测新数据。\n", "\n", - " * 参数估计:根据后验分布,可以获得参数w的点估计,如后验均值或最大后验估计(MAP)。\n", - " * 预测:通过获取参数w的后验分布,可以计算新数据点的预测分布,即 p(y*|x*, X, y) = ∫ p(y*|x*, w) * p(w|X, y) dw。这里,y表示预测的输出标签,x表示新的输入特征。\n", + " * 参数估计:根据后验分布,可以获得参数$w$的点估计,如后验均值或最大后验估计(MAP)。\n", + " * 预测:通过获取参数$w$的后验分布,可以计算新数据点的预测分布,即 $p(y*|x*, X, y) = ∫ p(y*|x*, w) * p(w|X, y) dw$。这里,$y$表示预测的输出标签,$x$表示新的输入特征。\n", "\n", "贝叶斯回归提供了全面的概率建模方式,能够量化参数的不确定性,并灵活地引入先验知识。它对小样本、高噪声数据以及需要考虑模型不确定性的情况特别有帮助。\n", "\n", @@ -32,10 +32,74 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 2, "id": "e84cfb5a", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Auto-assigning NUTS sampler...\n", + "Initializing NUTS using jitter+adapt_diag...\n", + "Multiprocess sampling (4 chains in 4 jobs)\n", + "NUTS: [sigma, w]\n", + "Sampling 4 chains: 100%|██████████| 8000/8000 [00:02<00:00, 3036.46draws/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "参数估计结果:\n", + "w__0 2.890521\n", + "w__1 4.865710\n", + "sigma 0.945172\n", + "Name: mean, dtype: float64\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 1000/1000 [00:01<00:00, 503.68it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "预测结果:\n", + "[ 1.03153314 -3.61562589 2.16812673 -3.11588044 9.60127944\n", + " 12.16905026 -0.23648124 4.25748233 -4.45208439 -7.32287496\n", + " 3.27508941 6.5172929 -0.836755 2.46292051 -6.08647484\n", + " 0.58920231 -3.83147137 6.42264627 -3.61320578 1.41008255\n", + " 6.48528022 0.24595292 -0.12118109 0.07075903 -8.62226662\n", + " 11.63699736 3.94074076 -0.54817445 1.58446021 0.80333973\n", + " -8.5854543 -2.80430425 -5.25514791 -4.58258886 -14.3224442\n", + " -1.84326201 4.46406574 6.46984307 0.21513803 -14.46026694\n", + " -2.29835858 0.19768268 2.16854864 -4.15003522 9.75786452\n", + " -3.88164949 -7.4700798 3.36505357 4.478443 1.69431611\n", + " -2.70376263 7.05479675 1.63607418 -4.06979375 -4.62857722\n", + " -2.06437145 5.40670935 3.61310883 3.05419556 -1.92042677\n", + " -3.81055859 -0.7341311 -2.46628125 0.46042887 6.3505802\n", + " -5.52622373 -3.90441137 -4.68569514 2.0306271 0.6378601\n", + " 2.38703032 -2.07502364 4.05460138 6.58315489 -0.94999262\n", + " 3.896083 -12.56547747 -0.50816651 -1.16460501 10.06830438\n", + " -0.7681246 -1.35430004 -13.34554212 -5.37841587 -1.85885039\n", + " -8.55399374 -1.31399772 0.86910898 -1.72386609 1.79894702\n", + " -8.58868727 4.80504887 9.87574364 -5.17028682 6.4685171\n", + " 1.12598447 8.74019895 0.1646509 -5.30968101 8.1008152 ]\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n" + ] + } + ], "source": [ "import pymc3 as pm\n", "import numpy as np\n",