Skip to content

Commit

Permalink
Remove tar_cpm_main()
Browse files Browse the repository at this point in the history
Signed-off-by: Liang Zhang <psychelzh@outlook.com>
  • Loading branch information
psychelzh committed Jan 26, 2024
1 parent 92a0280 commit e6b486a
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 99 deletions.
65 changes: 0 additions & 65 deletions R/tar_factories.R
Original file line number Diff line number Diff line change
Expand Up @@ -367,71 +367,6 @@ tar_prep_files_cpm <- function(params_subset = NULL) {
)
}

tar_cpm_main <- function(scores_factor, subjs_to_keep, ...,
params_subset = NULL,
batches = 4,
reps = 5,
combine = NULL) {
cpm_branches <- tarchetypes::tar_map(
config_files({{ params_subset }}),
names = !starts_with("file"),
tarchetypes::tar_rep_raw(
"cpm_result",
substitute(
apply(
match_cases(scores_factor, subjs_to_keep), 2,
\(scores) {
cpmr::cpm(
match_cases(qs::qread(file_fc), subjs_to_keep),
scores,
confounds = match_cases(
qs::qread(file_confounds),
subjs_to_keep
),
thresh_method = thresh_method,
thresh_level = thresh_level,
kfolds = 10
)
}
)
),
batches = batches,
reps = reps,
iteration = "list"
),
tarchetypes::tar_rep2(
cpm_performance,
aggregate_performance(cpm_result),
cpm_result
)
)
c(
cpm_branches,
lapply(
intersect(names(cpm_branches), combine),
\(name) {
tarchetypes::tar_combine_raw(
name,
cpm_branches[[name]],
command = substitute(
bind_rows(!!!.x, .id = ".id") |>
zutils::separate_wider_dsv(
".id",
c(
names(params_fmri_tasks),
names(params_xcpd),
names(hypers_cpm)
),
patterns = c(rep(".+?", 2), ".+", rep(".+?", 3)),
prefix = "cpm_performance"
)
)
)
}
)
)
}

# helper functions ----
config_files <- function(params_subset = NULL) {
if (rlang::quo_is_null(rlang::enquo(params_subset))) params_subset <- TRUE
Expand Down
12 changes: 0 additions & 12 deletions R/utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -21,18 +21,6 @@ replace_as_name_cn <- function(game_index,
str_c(splitted[, 1], splitted[, 2], sep = delim)
}

match_cases <- function(data, subjs) {
data_subjs <- attr(data, "id")
matched <- match(subjs, data_subjs)
if (anyNA(matched)) {
stop("Some subjects are not found in the data.")
}
structure(
data[matched, ],
id = attr(data, "id")[matched]
)
}

match_dim_label <- function(latent) {
dimensions <- read_csv("config/dimensions.csv", show_col_types = FALSE) |>
mutate(latent = str_c("F", cluster)) |>
Expand Down
69 changes: 47 additions & 22 deletions _scripts/predict_phenotypes.R
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,35 @@ tar_option_set(
)
tar_source()

cpm_branches <- tarchetypes::tar_map(
config_files(),
names = !starts_with("file"),
tarchetypes::tar_rep(
cpm_result,
apply(
qs::qread(file_scores_factor)[subjs_to_keep, ], 2,
\(scores) {
cpmr::cpm(
qs::qread(file_fc)[subjs_to_keep, ],
scores,
confounds = qs::qread(file_confounds)[subjs_to_keep, ],
thresh_method = thresh_method,
thresh_level = thresh_level,
kfolds = 10
)
}
),
batches = 4,
reps = 5,
iteration = "list"
),
tarchetypes::tar_rep2(
cpm_performance,
aggregate_performance(cpm_result),
cpm_result
)
)

list(
tar_target(
file_scores_factor,
Expand All @@ -30,29 +59,25 @@ list(
format = "file_fast"
),
tar_target(
scores_factor, {
tbl <- qs::qread(file_scores_factor)
structure(
as.matrix(select(tbl, !user_id)),
id = tbl$user_id
)
}
),
tar_target(dim_labels, colnames(scores_factor)),
tar_target(
subjs_to_keep, {
# intersect() does not work for integer64
# https://github.com/truecluster/bit64/issues/29
subjs_neural <- qs::qread(file_subjs_keep_neural)
subjs_behav <- attr(scores_factor, "id")
matched <- match(subjs_neural, subjs_behav)
subjs_behav[matched[!is.na(matched)]]
}
subjs_to_keep,
# intersect() does not work for integer64
# https://github.com/truecluster/bit64/issues/29
intersect(
as.character(qs::qread(file_subjs_keep_neural)),
rownames(qs::qread(file_scores_factor))
)
),
tar_prep_files_cpm(),
tar_cpm_main(
scores_factor,
subjs_to_keep,
combine = "cpm_performance"
cpm_branches,
tarchetypes::tar_combine(
cpm_performance,
cpm_branches$cpm_performance,
command = bind_rows(!!!.x, .id = ".id") |>
zutils::separate_wider_dsv(
".id",
c(names(params_fmri_tasks), names(params_xcpd), names(hypers_cpm)),
patterns = c(rep(".+?", 2), ".+", rep(".+?", 3)),
prefix = "cpm_performance"
)
)
)

0 comments on commit e6b486a

Please sign in to comment.