Skip to content

Commit

Permalink
Add tar_cpm_main()
Browse files Browse the repository at this point in the history
Signed-off-by: Liang Zhang <psychelzh@outlook.com>
  • Loading branch information
psychelzh committed Jan 25, 2024
1 parent 9548c24 commit bd0c112
Show file tree
Hide file tree
Showing 3 changed files with 136 additions and 67 deletions.
10 changes: 5 additions & 5 deletions R/cpm_aggregate.R
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
aggregate_performance <- function(cpm_result, dim_labels) {
aggregate_performance <- function(cpm_result, names_to = "latent") {
lapply(
cpm_result,
# targets will append batching information to the list
zutils::select_list(cpm_result, !starts_with("tar")),
\(result) {
apply(result$pred, 2, cor.test, result$real) |>
lapply(broom::tidy) |>
list_rbind(names_to = "network")
list_rbind(names_to = "include")
}
) |>
set_names(dim_labels) |>
list_rbind(names_to = "dim_label")
list_rbind(names_to = names_to)
}
126 changes: 126 additions & 0 deletions R/tar_factories.R
Original file line number Diff line number Diff line change
Expand Up @@ -330,3 +330,129 @@ tar_fit_cfa <- function(config, data, theory,
}
)
}

tar_prep_files_cpm <- function(params_subset = NULL) {
values <- config_files({{ params_subset }})
c(
tarchetypes::tar_eval(
tar_target(
file_confounds,
path_obj_from_proj(
paste(
"confounds_cpm",
session, task,
sep = "_"
),
"preproc_neural"
),
format = "file_fast"
),
dplyr::distinct(values, session, task, file_confounds)
),
tarchetypes::tar_eval(
tar_target(
file_fc,
path_obj_from_proj(
paste(
"fc_orig_full",
session, task, config, atlas,
sep = "_"
),
"preproc_neural"
),
format = "file_fast"
),
dplyr::distinct(values, session, task, config, atlas, file_fc)
)
)
}

tar_cpm_main <- function(scores_factor, subjs_to_keep, ...,
params_subset = NULL,
batches = 4,
reps = 5,
combine = NULL) {
cpm_branches <- tarchetypes::tar_map(
config_files({{ params_subset }}),
names = !starts_with("file"),
tarchetypes::tar_rep_raw(
"cpm_result",
substitute(
apply(
match_cases(scores_factor, subjs_to_keep), 2,
\(scores) {
cpmr::cpm(
match_cases(qs::qread(file_fc), subjs_to_keep),
scores,
confounds = match_cases(
qs::qread(file_confounds),
subjs_to_keep
),
thresh_method = thresh_method,
thresh_level = thresh_level,
kfolds = 10
)
}
)
),
batches = batches,
reps = reps,
iteration = "list"
),
tarchetypes::tar_rep2(
cpm_performance,
aggregate_performance(cpm_result),
cpm_result
)
)
c(
cpm_branches,
lapply(
intersect(names(cpm_branches), combine),
\(name) {
tarchetypes::tar_combine_raw(
name,
cpm_branches[[name]],
command = substitute(
bind_rows(!!!.x, .id = ".id") |>
zutils::separate_wider_dsv(
".id",
c(
names(params_fmri_tasks),
names(params_xcpd),
names(hypers_cpm)
),
patterns = c(rep(".+?", 2), ".+", rep(".+?", 3)),
prefix = "cpm_performance"
)
)
)
}
)
)
}

# helper functions ----
config_files <- function(params_subset = NULL) {
if (rlang::quo_is_null(rlang::enquo(params_subset))) params_subset <- TRUE
tidyr::expand_grid(
params_fmri_tasks,
params_xcpd,
hypers_cpm
) |>
dplyr::filter({{ params_subset }}) |>
dplyr::mutate(
file_fc = rlang::syms(
sprintf(
"file_fc_%s_%s_%s_%s",
session, task, config, atlas
)
),
file_confounds = rlang::syms(
sprintf(
"file_confounds_%s_%s",
session, task
)
)
)
}
67 changes: 5 additions & 62 deletions _scripts/predict_phenotypes.R
Original file line number Diff line number Diff line change
Expand Up @@ -18,57 +18,6 @@ tar_option_set(
)
tar_source()

targets_cpm <- tarchetypes::tar_map(
params_fmri_tasks,
tar_target(
file_confounds,
path_obj_from_proj(
paste(
"confounds_cpm",
session, task,
sep = "_"
),
"preproc_neural"
),
format = "file_fast"
),
tarchetypes::tar_map(
params_xcpd,
tar_target(
file_fc,
path_obj_from_proj(
paste(
"fc_orig_full",
config, session, task, atlas,
sep = "_"
),
"preproc_neural"
),
format = "file_fast"
),
tarchetypes::tar_map(
hypers_cpm,
tar_target(
cpm_result,
cpmr::cpm(
match_cases(qs::qread(file_fc), subjs_to_keep),
match_cases(scores_factor, subjs_to_keep)[, dim_labels],
confounds = match_cases(qs::qread(file_confounds), subjs_to_keep),
thresh_method = thresh_method,
thresh_level = thresh_level,
kfolds = 10
),
pattern = map(dim_labels),
iteration = "list"
),
tar_target(
cpm_performance,
aggregate_performance(cpm_result, dim_labels)
)
)
)
)

list(
tar_target(
file_scores_factor,
Expand Down Expand Up @@ -100,16 +49,10 @@ list(
subjs_behav[matched[!is.na(matched)]]
}
),
targets_cpm,
tarchetypes::tar_combine(
cpm_performance,
zutils::select_list(targets_cpm, starts_with("cpm_performance")),
command = bind_rows(!!!.x, .id = ".id") |>
zutils::separate_wider_dsv(
".id",
c(names(hypers_cpm), names(params_xcpd), names(params_fmri_tasks)),
patterns = c(rep(".+?", 2), ".+_?.+", rep(".+?", 3)),
prefix = "cpm_performance"
)
tar_prep_files_cpm(),
tar_cpm_main(
scores_factor,
subjs_to_keep,
combine = "cpm_performance"
)
)

0 comments on commit bd0c112

Please sign in to comment.