-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
5 changed files
with
212 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
FROM r30min

## Install the system libraries that Stan and cmdstanr require.
## The package index is refreshed in the same layer so that the install does
## not fail when the base image ships without (or with stale) apt lists, and
## the lists are removed afterwards to keep the layer small.
RUN apt-get update \
    && apt-get install -y libtbb2 libxt-dev \
    && rm -rf /var/lib/apt/lists/*

## Install R packages without specifying their versions (latest is used)
RUN Rscript -e 'remotes::install_version("cmdstanr", repos = c("https://mc-stan.org/r-packages/", getOption("repos")))'
RUN Rscript -e 'remotes::install_version("posterior")'
RUN Rscript -e 'remotes::install_version("bayesplot")'

## Create a mount point under the home directory of the given user
ARG R_USERNAME=rstudio
RUN mkdir -p "/home/${R_USERNAME}/work"
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
// Mixture of two normal distributions with separate means and scales.
// Based on
//   https://mc-stan.org/users/documentation/case-studies/identifying_mixture_models.html
//   https://mc-stan.org/docs/2_27/stan-users-guide/vectorizing-mixtures.html
data {
  int<lower=1> N;  // number of observations
  vector[N] Y;     // observed values
}

parameters {
  // Component means; the ordered type constrains mus[1] < mus[2]
  ordered[2] mus;
  real<lower=0> sigma_s;          // scale of the component centred at mus[1]
  real<lower=0> sigma_l;          // scale of the component centred at mus[2]
  real<lower=0, upper=1> theta;   // mixing weight of the mus[1] component
}

model {
  // Priors (mus has no explicit sampling statement in this model)
  sigma_s ~ exponential(1);
  sigma_l ~ exponential(1);
  theta ~ beta(2, 2);

  // Likelihood: each observation comes from one of the two components,
  // weighted by theta and (1 - theta) respectively
  for (n in 1:N) {
    target += log_mix(theta,
                      normal_lpdf(Y[n] | mus[1], sigma_s),
                      normal_lpdf(Y[n] | mus[2], sigma_l));
  }
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,165 @@ | ||
---
title: "Rプログラマが30分ではじめるStan"
author: "Zettsu Tatsuya"
date: '`r format(Sys.time(), "%Y/%m/%d")`'
output:
  html_document:
    toc: true
    # "collapsed" is a sub-option of toc_float; a separate toc_collapsed key
    # is not a documented html_document option.
    toc_float:
      collapsed: true
  pdf_document:
    latex_engine: xelatex
  beamer_presentation:
    # Pandoc 2.0 renamed --latex-engine to --pdf-engine; setting latex_engine
    # lets rmarkdown pass the engine to Pandoc itself.
    latex_engine: xelatex
header-includes:
  \usepackage{float}
documentclass: bxjsarticle
classoption: xelatex,ja=standard
urlcolor: blue
---
|
||
```{r setup_r_packages, include=FALSE, cache=FALSE}
## Packages used by knitr while rendering this document
library(kableExtra)
library(reticulate)
library(xfun)
## Point reticulate at python3 if needed.
## This chunk is not cached: knitr skips the code of a cached chunk on later
## knits, so this side-effect-only call would not run again.
reticulate::use_python("/usr/bin/python3.8")
```
|
||
R から Stan を使う方法について説明する。 cmdstanr を導入する方法は、[こちらの記事](https://norimune.net/3609)を参考にした。Windows で cmdstanr を使うと grep がエラーになるので、Docker上で RStudio Serverを起動して、ウェブブラウザから localhost:8787 にアクセスして使う。 | ||
|
||
## Dockerイメージを作って起動する | ||
|
||
RStudioにログインするとき、Username は rstudio 、PasswordはDockerコンテナを起動するときに下記(-e PASSWORD=)で指定するパスワードを入力する。 -v オプションでWindowsのファイルシステムをマウントすると便利だろう。 | ||
|
||
```{bash docker, eval=FALSE, echo=TRUE, cache=FALSE}
# Build the base image first, then the Stan image from Dockerfile_stan
docker build -t r30min .
docker build -f Dockerfile_stan -t stan .

# Start RStudio Server on port 8787 with a host directory mounted as work
docker run \
  -e PASSWORD=yourpassword \
  -p 8787:8787 \
  -d \
  -v C:/path/to/Rin30minutes:/home/rstudio/work \
  -it stan
```
|
||
## Stanを使う準備をする | ||
|
||
本 R Markdown ファイルをknitすると、下記の内容に従って cmdstan をインストールするので注意すること。 | ||
|
||
### Rのパッケージを使う | ||
|
||
Stanをインストールする。 | ||
|
||
```{r setup_stan, message=FALSE, warning=FALSE, cache=FALSE}
## Keep this attach order: later packages mask same-named functions of
## earlier ones (presumably bayesplot's rhat() is meant to win over
## posterior's -- confirm before reordering).
library(tidyverse)
library(cmdstanr)
library(posterior)
library(bayesplot)

## Install CmdStan only when cmdstanr cannot find an existing installation,
## so that re-knitting does not invoke the installer every time.
if (is.null(cmdstanr::cmdstan_version(error_on_NA = FALSE))) {
  invisible(cmdstanr::install_cmdstan())
}
```
|
||
### モデルをコンパイルする | ||
|
||
2つの正規分布が混じった観測値を分ける、Stanのコードを示す。[Stan公式の例](https://mc-stan.org/users/documentation/case-studies/identifying_mixture_models.html)とほぼ同じである。より簡潔な実装例は[こちら](https://mc-stan.org/docs/2_27/stan-users-guide/vectorizing-mixtures.html)を参照。 | ||
|
||
log_mix は2分布の log_sum_exp を簡単に書く記法である。ラベルスイッチングを防ぐために、平均値パラメータ mus の大小を固定している(添え字が大きい方が平均値が大きい)。 | ||
|
||
```{r show_model, eval=TRUE, echo=FALSE, comment=NA, cache=TRUE, cache.extra=tools::md5sum("mixture_normals.stan")}
## Print the Stan source verbatim.
## The cache key is the file's content hash: file.info() also reports the
## access time, which can change without the file being edited.
cat(xfun::read_utf8("mixture_normals.stan"), sep = "\n")
```
|
||
Stanモデルをコンパイルする。 | ||
|
||
```{r setup_model, cache=TRUE, cache.extra=tools::md5sum("mixture_normals.stan")}
## Compile the Stan model; the cache is keyed on the content hash of the
## Stan file, so it is rebuilt only when the model source changes.
model <- cmdstanr::cmdstan_model("mixture_normals.stan")
```
|
||
### データを準備する | ||
|
||
Small, large の2つの正規分布を混ぜる。 | ||
|
||
```{r setup_data, message=FALSE, warning=FALSE, cache=TRUE}
## Fix the RNG seed so that the simulated data set is reproducible
given_seed <- 12345
set.seed(given_seed)

## Sample sizes: a share `ratio` of the n observations is "small"
n <- 2000
ratio <- 0.4
n_small <- ceiling(n * ratio)
n_large <- n - n_small

## True parameters of the two components
mu_small <- 0.0
mu_large <- 2.5
sigma_small <- 0.75
sigma_large <- 1.25

## Draw n values from Normal(mu, sigma) and tag every row with `label`
make_data_set <- function(n, mu, sigma, label) {
  tibble::tibble(y = rnorm(n, mu, sigma), label = label)
}

## "small" is drawn before "large"; keep this order so that the random
## numbers are consumed in the same sequence
df_small <- make_data_set(n_small, mu_small, sigma_small, "small")
df_large <- make_data_set(n_large, mu_large, sigma_large, "large")

## Stack both components, order the label levels as small < large,
## and shuffle the rows
df <- dplyr::bind_rows(df_small, df_large) %>%
  dplyr::mutate(label = factor(label, levels = c("small", "large"))) %>%
  dplyr::sample_n(NROW(.))

## Data passed to the Stan model (N observations, values Y)
input_data <- list(N = NROW(df), Y = df$y)
```
|
||
```{r draw_data, echo=FALSE, message=FALSE, warning=FALSE, cache=TRUE, dependson="setup_data"}
## Histogram of the observations, coloured by the component they came from.
## bins = 30 is the ggplot2 default, stated explicitly.
g <- ggplot(df)
g <- g + geom_histogram(aes(x = y, fill = label), position = "dodge", bins = 30)
g <- g + scale_fill_manual(values = c("royalblue3", "orange"))
plot(g)

## Per-component mean, standard deviation and count of y
## (explicit summarize() replaces the superseded summarize_all())
df %>%
  dplyr::group_by(label) %>%
  dplyr::summarize(
    mean = mean(y),
    stddev = sd(y),
    n = dplyr::n(),
    .groups = "drop"
  ) %>%
  kable() %>%
  kable_styling()
```
|
||
## Stanを実行する | ||
|
||
### Stanでモデルのパラメータを求める | ||
|
||
数分掛かるので待つ。 | ||
|
||
```{r fit_by_stan, message=FALSE, warning=FALSE, cache=TRUE, dependson=c("setup_model", "setup_data")}
## Run 4 chains (2 at a time). dependson makes knitr refit when the cached
## model or data chunk changes.
fit <- model$sample(
  data = input_data, seed = given_seed, chains = 4, parallel_chains = 2,
  iter_warmup = 5000, iter_sampling = 2500, refresh = 2500
)
## cmdstanr reads the CmdStan CSV output lazily, and by default the files are
## written to a temporary directory. Load the draws and diagnostics now so
## that the cached fit object does not depend on those files.
invisible(fit$draws())
invisible(fit$sampler_diagnostics())
```
|
||
結果を見る。 | ||
|
||
```{r summarize_stan, message=FALSE, warning=FALSE, cache=TRUE, dependson="fit_by_stan"}
## Posterior summary; dependson refreshes it whenever the fit is recomputed
fit$summary()
```
|
||
### 事後分布を表示する | ||
|
||
R hatをみて、Stanモデルのパラメータが収束したことを確認する。 | ||
|
||
```{r set_stan_color, message=FALSE, warning=FALSE, cache=FALSE}
## Set the bayesplot colour scheme. Not cached: this call only has a side
## effect, and knitr does not re-run the code of a cached chunk.
bayesplot::color_scheme_set("brightblue")
```
|
||
```{r draw_rhat, message=FALSE, warning=FALSE, cache=TRUE, dependson="fit_by_stan"}
## Plot R-hat for every parameter. rhat() is namespace-qualified so that the
## result does not depend on the order in which packages were attached.
bayesplot::mcmc_rhat(bayesplot::rhat(fit))
```
|
||
Stanモデルのパラメータの事後分布をみる。元の分布と大体同じである。 | ||
|
||
```{r draw_hist, message=FALSE, warning=FALSE, cache=TRUE, dependson="fit_by_stan"}
## Posterior histograms of each parameter
bayesplot::mcmc_hist(fit$draws("mus"))
bayesplot::mcmc_hist(fit$draws("sigma_s"))
bayesplot::mcmc_hist(fit$draws("sigma_l"))
bayesplot::mcmc_hist(fit$draws("theta"))
## Trace plots of each parameter
bayesplot::mcmc_trace(fit$draws("mus"))
bayesplot::mcmc_trace(fit$draws("sigma_s"))
bayesplot::mcmc_trace(fit$draws("sigma_l"))
bayesplot::mcmc_trace(fit$draws("theta"))
```