Skip to content

Commit

Permalink
autodiff: no_std support (switch std:: to core::)
Browse files Browse the repository at this point in the history
I can now do this no a device function and the IR looks okay by eyeball.

argo +enzyme rustc --release --target=nvptx64-nvidia-cuda -Zbuild-std -- --emit=llvm-ir
  • Loading branch information
jedbrown authored and ZuseZ4 committed Dec 22, 2023
1 parent efad9fd commit d9e9c9c
Show file tree
Hide file tree
Showing 6 changed files with 14 additions and 14 deletions.
8 changes: 4 additions & 4 deletions library/autodiff/src/gen.rs
Original file line number Diff line number Diff line change
Expand Up @@ -107,11 +107,11 @@ pub(crate) fn adjoint_fnc(item: &DiffItem) -> TokenStream {
res_inputs.push(input.clone());

match (item.header.mode, activity, is_ref_mut(&input)) {
(Mode::Forward, Activity::Duplicated|Activity::DuplicatedNoNeed, Some(true)) => {
(Mode::Forward, Activity::Duplicated | Activity::DuplicatedNoNeed, Some(true)) => {
res_inputs.push(as_ref_mut(&input, "grad", true));
add_inputs.push(as_ref_mut(&input, "grad", true));
}
(Mode::Forward, Activity::Duplicated|Activity::DuplicatedNoNeed, Some(false)) => {
(Mode::Forward, Activity::Duplicated | Activity::DuplicatedNoNeed, Some(false)) => {
res_inputs.push(as_ref_mut(&input, "dual", false));
add_inputs.push(as_ref_mut(&input, "dual", false));
out_type.clone().map(|x| outputs.push(x));
Expand Down Expand Up @@ -203,9 +203,9 @@ pub(crate) fn adjoint_fnc(item: &DiffItem) -> TokenStream {
};

let body = quote!({
std::hint::black_box((#call_ident(#(#inputs,)*), #(#add_inputs,)*));
core::hint::black_box((#call_ident(#(#inputs,)*), #(#add_inputs,)*));

std::hint::black_box(unsafe { std::mem::zeroed() })
core::hint::black_box(unsafe { core::mem::zeroed() })
});
let header = generate_header(&item);

Expand Down
4 changes: 2 additions & 2 deletions library/autodiff/tests/expand/forward_duplicated.expanded.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,6 @@ fn square(a: &Vec<f32>, b: &mut f32) {
}
#[autodiff_into(Forward, Const, Duplicated, Duplicated)]
fn d_square(a: &Vec<f32>, dual_a: &Vec<f32>, b: &mut f32, grad_b: &mut f32) {
std::hint::black_box((square(a, b), dual_a, grad_b));
std::hint::black_box(unsafe { std::mem::zeroed() })
core::hint::black_box((square(a, b), dual_a, grad_b));
core::hint::black_box(unsafe { core::mem::zeroed() })
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,6 @@ fn d_square2(
b: &Vec<f32>,
dual_b: &Vec<f32>,
) -> (f32, f32, f32) {
std::hint::black_box((square2(a, b), dual_a, dual_b));
std::hint::black_box(unsafe { std::mem::zeroed() })
core::hint::black_box((square2(a, b), dual_a, dual_b));
core::hint::black_box(unsafe { core::mem::zeroed() })
}
4 changes: 2 additions & 2 deletions library/autodiff/tests/expand/reverse_duplicated.expanded.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,6 @@ fn square(a: &Vec<f32>, b: &mut f32) {
}
#[autodiff_into(Reverse, Const, Duplicated, Duplicated)]
fn d_square(a: &Vec<f32>, grad_a: &mut Vec<f32>, b: &mut f32, grad_b: &f32) {
std::hint::black_box((square(a, b), grad_a, grad_b));
std::hint::black_box(unsafe { std::mem::zeroed() })
core::hint::black_box((square(a, b), grad_a, grad_b));
core::hint::black_box(unsafe { core::mem::zeroed() })
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,6 @@ fn array(arr: &[[[f32; 2]; 2]; 2]) -> f32 {
}
#[autodiff_into(Reverse, Active, Duplicated)]
fn d_array(arr: &[[[f32; 2]; 2]; 2], grad_arr: &mut [[[f32; 2]; 2]; 2], tang_y: f32) {
std::hint::black_box((array(arr), grad_arr, tang_y));
std::hint::black_box(unsafe { std::mem::zeroed() })
core::hint::black_box((array(arr), grad_arr, tang_y));
core::hint::black_box(unsafe { core::mem::zeroed() })
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,6 @@ fn d_sqrt(
d: f32,
tang_y: f32,
) -> (f32, f32) {
std::hint::black_box((sqrt(a, b, c, d), grad_b, tang_y));
std::hint::black_box(unsafe { std::mem::zeroed() })
core::hint::black_box((sqrt(a, b, c, d), grad_b, tang_y));
core::hint::black_box(unsafe { core::mem::zeroed() })
}

0 comments on commit d9e9c9c

Please sign in to comment.