Skip to content

Commit

Permalink
uploading plotting utilities
Browse files Browse the repository at this point in the history
  • Loading branch information
lnalborczyk committed Jun 13, 2023
1 parent d7435f6 commit 1147d7a
Show file tree
Hide file tree
Showing 39 changed files with 572 additions and 215 deletions.
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,11 @@ Imports:
rlang,
geomtextpath,
ggplot2,
stringr,
tgp,
tidyr
Suggests:
knitr,
patchwork,
rmarkdown,
testthat (>= 3.0.0)
Config/testthat/edition: 3
Expand Down
2 changes: 2 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Generated by roxygen2: do not edit by hand

S3method(plot,DEoptim_momimi)
S3method(plot,momimi_full)
export(activation)
export(fitting)
Expand All @@ -8,6 +9,7 @@ export(loss)
export(model)
export(onset_offset)
export(quantiles_props)
export(simulating)
importFrom(magrittr,"%>%")
importFrom(rlang,.data)
importFrom(stats,median)
Expand Down
205 changes: 195 additions & 10 deletions R/fitting.R
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,12 @@
#' @param nstudies Numeric, number of starting values in the LHS.
#' @param initial_pop_constraints Boolean, whether to use additional constraints when sampling initial parameter values.
#' @param error_function Character, error function to be used when fitting the model.
#' @param model_version Version of the model ("TMM" or "PIM").
#' @param method Optimisation method.
#' @param maxit Maximum number of iterations.
#' @param model_version Character, version of the model ("TMM" or "PIM").
#' @param method Character, optimisation method (DEoptim seems to work best).
#' @param maxit Numeric, maximum number of iterations.
#' @param verbose Boolean, whether to print progress during fitting.
#'
#' @return The optimised parameter values and further convergence informatio
#' @return The optimised parameter values and further convergence information.
#'
#' @importFrom magrittr %>%
#'
Expand All @@ -29,16 +30,33 @@ fitting <- function (
nsims = NULL,
par_names = NULL,
lower_bounds, upper_bounds,
nstudies = 200, initial_pop_constraints = FALSE,
error_function, model_version,
nstudies = 200,
initial_pop_constraints = FALSE,
error_function = c("g2", "rmse", "sse", "wsse", "ks"),
model_version = c("TMM", "PIM"),
method = c(
"SANN", "GenSA", "pso", "hydroPSO", "DEoptim",
"Nelder-Mead", "BFGS", "L-BFGS-B", "bobyqa", "nlminb",
"all_methods", "optimParallel"
),
maxit = 100
maxit = 100, verbose = TRUE
) {

# defining parameter names according to the chosen model (if null)
if (is.null(par_names) ) {

if (model_version == "TMM") {

par_names <- c("amplitude_activ", "peak_time_activ", "curvature_activ", "exec_threshold")

} else if (model_version == "PIM") {

par_names <- c("amplitude_ratio", "peak_time", "curvature_activ", "curvature_inhib")

}

}

# some tests for variable types
stopifnot("data must be a dataframe..." = is.data.frame(data) )
stopifnot("nsims must be a numeric..." = is.numeric(nsims) )
Expand Down Expand Up @@ -143,7 +161,7 @@ fitting <- function (

# starting the optimisation
fit <- DEoptim::DEoptim(
fn = loss,
fn = momimi::loss,
data = data,
nsims = nsims,
error_function = error_function,
Expand All @@ -154,7 +172,7 @@ fitting <- function (
# maximum number of iterations
itermax = maxit,
# printing progress iteration
trace = TRUE,
trace = verbose,
# printing progress every 10 iterations
# trace = 10,
# defines the differential evolution strategy (defaults to 2)
Expand Down Expand Up @@ -194,7 +212,7 @@ fitting <- function (
)

# setting the class of the resulting object
class(fitt) <- c("DEoptim_momimi", "DEoptim", "data.frame")
class(fit) <- c("DEoptim_momimi", "DEoptim", "data.frame")

} else if (method %in% c("Nelder-Mead", "BFGS", "L-BFGS-B", "bobyqa", "nlminb") ) {

Expand Down Expand Up @@ -257,3 +275,170 @@ fitting <- function (
return (fit)

}

#' @export

plot.DEoptim_momimi <- function (
x, original_data,
method = c("ppc", "latent"),
action_mode = c("executed", "imagined"),
model_version = c("TMM", "PIM"),
...
) {

# some tests
method <- match.arg(method)
action_mode <- match.arg(action_mode)
model_version <- match.arg(model_version)

# retrieving estimated pars
estimated_pars <- as.numeric(x$optim$bestmem)

if (method == "ppc") {

# simulating data using these parameter values
simulating(
nsims = 200,
nsamples = 2000,
true_pars = estimated_pars,
action_mode = action_mode,
model_version = model_version
) %>%
# removing NAs or aberrant simulated data
stats::na.omit() %>%
dplyr::filter(.data$reaction_time < 3 & .data$movement_time < 3) %>%
tidyr::pivot_longer(cols = .data$reaction_time:.data$movement_time) %>%
ggplot2::ggplot(ggplot2::aes(x = .data$value, colour = .data$name, fill = .data$name) ) +
ggplot2::geom_density(
data = original_data %>% tidyr::pivot_longer(cols = .data$reaction_time:.data$movement_time),
color = "white",
position = "identity",
alpha = 0.5, show.legend = FALSE
) +
ggplot2::geom_density(size = 1, fill = NA, show.legend = FALSE) +
ggplot2::theme_bw(base_size = 12, base_family = "Open Sans") +
ggplot2::labs(
title = "Observed and simulated distributions of RTs/MTs",
x = "Reaction/Movement time (in seconds)", y = "Probability density"
)

} else if (method == "latent") {

if (model_version == "TMM") {

par_names <- c("amplitude_activ", "peak_time_activ", "curvature_activ", "exec_threshold")

parameters_estimates_summary <- paste(as.vector(rbind(
paste0(par_names, ": "),
paste0(as.character(round(estimated_pars, 3) ), "\n")
) ), collapse = "") %>% stringr::str_sub(end = -2)

model(
nsims = 500, nsamples = 2000,
exec_threshold = estimated_pars[4] * estimated_pars[1],
imag_threshold = 0.5 * estimated_pars[4] * estimated_pars[1],
amplitude_activ = estimated_pars[1],
peak_time_activ = log(estimated_pars[2]),
curvature_activ = estimated_pars[3],
model_version = "TMM",
full_output = TRUE
) %>%
tidyr::pivot_longer(cols = .data$activation) %>%
ggplot2::ggplot(
ggplot2::aes(
x = .data$time, y = .data$value,
group = interaction(.data$sim, .data$name)
)
) +
ggplot2::geom_hline(
yintercept = estimated_pars[4] * estimated_pars[1],
linetype = 2
) +
ggplot2::geom_hline(
yintercept = 0.5 * estimated_pars[4] * estimated_pars[1],
linetype = 2
) +
# plotting average
ggplot2::stat_summary(
ggplot2::aes(group = .data$name, colour = .data$name),
fun = "median", geom = "line",
linewidth = 1, alpha = 1,
show.legend = FALSE
) +
# displaying estimated parameter values
ggplot2::annotate(
geom = "label",
x = Inf, y = Inf,
hjust = 1, vjust = 1,
label = parameters_estimates_summary,
family = "Courier"
) +
ggplot2::theme_bw(base_size = 12, base_family = "Open Sans") +
ggplot2::labs(
title = "Latent functions",
x = "Time within a trial (in seconds)",
y = "Activation/inhibition (a.u.)",
colour = "",
fill = ""
)

} else if (model_version == "PIM") {

par_names <- c("amplitude_ratio", "peak_time", "curvature_activ", "curvature_inhib")

parameters_estimates_summary <- paste(as.vector(rbind(
paste0(par_names, ": "),
paste0(as.character(round(estimated_pars, 3) ), "\n")
) ), collapse = "") %>% stringr::str_sub(end = -2)

model(
nsims = 500, nsamples = 2000,
exec_threshold = 1, imag_threshold = 0.5,
amplitude_activ = 1.5,
peak_time_activ = log(estimated_pars[2]),
curvature_activ = estimated_pars[3],
amplitude_inhib = 1.5 / estimated_pars[1],
peak_time_inhib = log(estimated_pars[2]),
curvature_inhib = estimated_pars[4] * estimated_pars[3],
model_version = "PIM",
full_output = TRUE
) %>%
tidyr::pivot_longer(cols = .data$activation:.data$balance) %>%
ggplot2::ggplot(
ggplot2::aes(
x = .data$time, y = .data$value,
group = interaction(.data$sim, .data$name),
colour = .data$name
)
) +
ggplot2::geom_hline(yintercept = 1, linetype = 2) +
ggplot2::geom_hline(yintercept = 0.5, linetype = 2) +
# plotting average
ggplot2::stat_summary(
ggplot2::aes(group = .data$name, colour = .data$name),
fun = "median", geom = "line",
linewidth = 1, alpha = 1,
show.legend = TRUE
) +
# displaying estimated parameter values
ggplot2::annotate(
geom = "label",
x = Inf, y = Inf,
hjust = 1, vjust = 1,
label = parameters_estimates_summary,
family = "Courier"
) +
ggplot2::theme_bw(base_size = 12, base_family = "Open Sans") +
ggplot2::labs(
title = "Latent activation, inhibition, and balance functions",
x = "Time within a trial (in seconds)",
y = "Activation/inhibition (a.u.)",
colour = "",
fill = ""
)

}

}

}
27 changes: 21 additions & 6 deletions R/generating_initialpop.R
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#' @param lower_bounds Numeric, vector of lower bounds for parameters.
#' @param upper_bounds Numeric, vector of upper bounds for parameters.
#' @param model_version Character, threshold modulation model ("TMM") or parallel inhibition model ("PIM").
#' @param verbose Boolean, whether to print progress during fitting.
#'
#' @return Returns a dataframe a plausible (according to custom constraints) initial parameter values.
#'
Expand All @@ -22,7 +23,7 @@
generating_initialpop <- function (
nstudies, action_mode,
par_names, lower_bounds, upper_bounds,
model_version = c("TMM", "PIM")
model_version = c("TMM", "PIM"), verbose = TRUE
) {

# some tests for variable types
Expand Down Expand Up @@ -56,6 +57,16 @@ generating_initialpop <- function (
# setting columns names
colnames(lhs_pars) <- par_names

if (model_version == "TMM") {



} else if (model_version == "PIM") {



}

# defining the balance function
# basically a ratio of two rescaled lognormal functions
balance_function <- function (
Expand Down Expand Up @@ -121,9 +132,9 @@ generating_initialpop <- function (
amplitude_inhib = 1.5 / .data$amplitude_ratio,
peak_time_inhib = log(.data$peak_time),
curvature_inhib = .data$curvature_inhib * .data$curvature_activ
)
)
)
)

if (action_mode == "imagined") {

Expand Down Expand Up @@ -178,11 +189,15 @@ generating_initialpop <- function (

}

# updating the result_nrow variable
result_nrow <- nrow(lhs_initial_pop)
if (verbose) {

# printing progress
print(result_nrow)
# updating the result_nrow variable
result_nrow <- nrow(lhs_initial_pop)

# printing progress
print(result_nrow)

}

}

Expand Down
Loading

0 comments on commit 1147d7a

Please sign in to comment.