Skip to content

Commit

Permalink
Use cmdstanr
Browse files Browse the repository at this point in the history
  • Loading branch information
zettsu-t committed Sep 20, 2021
1 parent 528aba6 commit 8089f09
Show file tree
Hide file tree
Showing 5 changed files with 212 additions and 0 deletions.
5 changes: 5 additions & 0 deletions Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ RUN python3.8 -m pip install numpy
## Install R Packages without specifying their versions
RUN R -e 'install.packages("remotes")'
RUN Rscript -e 'remotes::install_version("knitr")'
RUN Rscript -e 'remotes::install_version("markdown")'
RUN Rscript -e 'remotes::install_version("assertthat")'
RUN Rscript -e 'remotes::install_version("extrafont")'
RUN Rscript -e 'remotes::install_version("functional")'
Expand All @@ -37,7 +38,11 @@ RUN Rscript -e 'remotes::install_version("reticulate")'
RUN Rscript -e 'remotes::install_version("rlang")'
RUN Rscript -e 'remotes::install_version("xfun")'
RUN Rscript -e 'remotes::install_version("cloc", repos = c("https://cinc.rud.is", "https://cloud.r-project.org/"))'
RUN Rscript -e 'remotes::install_version("styler")'
RUN Rscript -e 'remotes::install_version("lintr")'

## A workaround to use extrafont
RUN Rscript -e 'remotes::install_version("Rttf2pt1", version = "1.3.8")'
## Setup Japanese fonts for R
RUN Rscript -e 'extrafont::font_import(prompt = FALSE)'

Expand Down
13 changes: 13 additions & 0 deletions Dockerfile_stan
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
FROM r30min

## Install a package that Stan and cmdstanr require
RUN apt-get install -y libtbb2 libxt-dev

## Install R Packages without specifying their versions
RUN Rscript -e 'remotes::install_version("cmdstanr", repos = c("https://mc-stan.org/r-packages/", getOption("repos")))'
RUN Rscript -e 'remotes::install_version("posterior")'
RUN Rscript -e 'remotes::install_version("bayesplot")'

## Create a mount point
ARG R_USERNAME=rstudio
RUN mkdir -p "/home/${R_USERNAME}/work"
4 changes: 4 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -143,3 +143,7 @@ NumPyがインストールされていなければ、インストールしてく
## ライセンス

本レポジトリのライセンスは、[MITライセンス](LICENSE.txt)です。

## Rプログラマが30分ではじめるStan

R から Stan を使う方法を stan_example.Rmd に書きました。 Releases に含めましたのでご覧ください。
25 changes: 25 additions & 0 deletions mixture_normals.stan
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
// Based on
// https://mc-stan.org/users/documentation/case-studies/identifying_mixture_models.html
// https://mc-stan.org/docs/2_27/stan-users-guide/vectorizing-mixtures.html
data {
int<lower=1> N;
vector[N] Y;
}

parameters {
ordered[2] mus;
real<lower=0> sigma_s;
real<lower=0> sigma_l;
real<lower=0, upper=1> theta;
}

model {
sigma_s ~ exponential(1);
sigma_l ~ exponential(1);
theta ~ beta(2, 2);
for (i in 1:N) {
target += log_mix(theta,
normal_lpdf(Y[i] | mus[1], sigma_s),
normal_lpdf(Y[i] | mus[2], sigma_l));
}
}
165 changes: 165 additions & 0 deletions stan_example.Rmd
Original file line number Diff line number Diff line change
@@ -0,0 +1,165 @@
---
title: "Rプログラマが30分ではじめるStan"
author: "Zettsu Tatsuya"
date: '`r format(Sys.time(), "%Y/%m/%d")`'
output:
html_document:
toc: true
toc_float: true
toc_collapsed: true
pdf_document:
latex_engine: xelatex
beamer_presentation:
pandoc_args:
- --latex-engine
- xelatex
header-includes:
\usepackage{float}
documentclass: bxjsarticle
classoption: xelatex,ja=standard
urlcolor: blue
---

```{r setup_r_packages, include=FALSE, cache=TRUE}
## knitrで使うパッケージ
library(kableExtra)
library(reticulate)
library(xfun)
## 必要ならpython3にパスを通す
reticulate::use_python("/usr/bin/python3.8")
```

R から Stan を使う方法について説明する。 cmdstanr を導入する方法は、[こちらの記事](https://norimune.net/3609)を参考にした。Windows で cmdstanr を使うと grep がエラーになるので、Docker上で RStudio Serverを起動して、ウェブブラウザから localhost:8787 にアクセスして使う。

## Dockerイメージを作って起動する

RStudioにログインするとき、Username は rstudio 、PasswordはDockerコンテナを起動するときに下記(-e PASSWORD=)で指定するパスワードを入力する。 -v オプションでWindowsのファイルシステムをマウントすると便利だろう。

```{bash docker, eval=FALSE, echo=TRUE, cache=FALSE}
docker build -t r30min .
docker build -f Dockerfile_stan -t stan .
docker run -e PASSWORD=yourpassword -p 8787:8787 -d -v C:/path/to/Rin30minutes:/home/rstudio/work -it stan
```

## Stanを使う準備をする

本 R Markdown ファイルをknitすると、下記の内容に従って cmdstan をインストールするので注意すること。

### Rのパッケージを使う

Stanをインストールする。

```{r setup_stan, message=FALSE, warning=FALSE, cache=FALSE}
library(tidyverse)
library(cmdstanr)
library(posterior)
library(bayesplot)
invisible(install_cmdstan())
```

### モデルをコンパイルする

2つの正規分布が混じった観測値を分ける、Stanのコードを示す。[Stan公式の例](https://mc-stan.org/users/documentation/case-studies/identifying_mixture_models.html)とほぼ同じである。より簡潔な実装例は[こちら](https://mc-stan.org/docs/2_27/stan-users-guide/vectorizing-mixtures.html)を参照。

log_mix は2分布の log_sum_exp を簡単に書く記法である。ラベルスイッチングを防ぐために、平均値パラメータ mus の大小を固定している(添え字が大きい方が平均値が大きい)。

```{r show_model, eval=TRUE, echo=FALSE, comment=NA, cache=TRUE, cache.extra=file.info("mixture_normals.stan")}
cat(xfun::read_utf8("mixture_normals.stan"), sep = "\n")
```

Stanモデルをコンパイルする。

```{r setup_model, cache=TRUE, cache.extra=file.info("mixture_normals.stan")}
model <- cmdstanr::cmdstan_model("mixture_normals.stan")
```

### データを準備する

Small, large の2つの正規分布を混ぜる。

```{r setup_data, message=FALSE, warning=FALSE, cache=TRUE}
given_seed <- 12345
set.seed(given_seed)
n <- 2000
ratio <- 0.4
n_small <- ceiling(n * ratio)
n_large <- n - n_small
mu_small <- 0.0
mu_large <- 2.5
sigma_small <- 0.75
sigma_large <- 1.25
make_data_set <- function(n, mu, sigma, label) {
df <- tibble::tibble(y = rnorm(n, mu, sigma))
df$label <- label
df
}
df_small <- make_data_set(n_small, mu_small, sigma_small, "small")
df_large <- make_data_set(n_large, mu_large, sigma_large, "large")
df <- dplyr::bind_rows(df_small, df_large) %>%
dplyr::mutate(label = factor(label)) %>%
dplyr::mutate(label = forcats::fct_relevel(label, c("small", "large"))) %>%
dplyr::sample_n(NROW(.))
input_data <- list(N = NROW(df), Y = df$y)
```

```{r draw_data, echo=FALSE, message=FALSE, warning=FALSE, cache=TRUE}
g <- ggplot(df)
g <- g + geom_histogram(aes(x = y, fill = label), position = "dodge")
g <- g + scale_fill_manual(values = c("royalblue3", "orange"))
plot(g)
df %>%
dplyr::group_by(label) %>%
dplyr::summarize_all(list(mean=mean, stddev=sd, n=length)) %>%
dplyr::ungroup() %>%
kable() %>%
kable_styling()
```

## Stanを実行する

### Stanでモデルのパラメータを求める

数分掛かるので待つ。

```{r fit_by_stan, message=FALSE, warning=FALSE, cache=TRUE}
fit <- model$sample(
data = input_data, seed = given_seed, chains = 4, parallel_chains = 2,
iter_warmup = 5000, iter_sampling = 2500, refresh = 2500
)
```

結果を見る。

```{r summarize_stan, message=FALSE, warning=FALSE, cache=TRUE}
fit$summary()
```

### 事後分布を表示する

R hatをみて、Stanモデルのパラメータが収束したことを確認する。

```{r set_stan_color, message=FALSE, warning=FALSE, cache=TRUE}
bayesplot::color_scheme_set("brightblue")
```

```{r draw_rhat, message=FALSE, warning=FALSE, cache=TRUE}
mcmc_rhat(rhat(fit))
```

Stanモデルのパラメータの事後分布をみる。元の分布と大体同じである。

```{r draw_hist, message=FALSE, warning=FALSE, cache=TRUE}
mcmc_hist(fit$draws("mus"))
mcmc_hist(fit$draws("sigma_s"))
mcmc_hist(fit$draws("sigma_l"))
mcmc_hist(fit$draws("theta"))
mcmc_trace(fit$draws("mus"))
mcmc_trace(fit$draws("sigma_s"))
mcmc_trace(fit$draws("sigma_l"))
mcmc_trace(fit$draws("theta"))
```

0 comments on commit 8089f09

Please sign in to comment.