Skip to content

Commit

Permalink
Expose setting intput/getting output tensors on the model (#106)
Browse files Browse the repository at this point in the history
Expose setting input/getting output tensors by index
  • Loading branch information
pnehrer committed May 16, 2024
1 parent 836dd87 commit 3a404ee
Showing 1 changed file with 26 additions and 3 deletions.
29 changes: 26 additions & 3 deletions crates/openvino/src/request.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
use crate::tensor::Tensor;
use crate::{cstr, drop_using_function, try_unsafe, util::Result};
use openvino_sys::{
ov_infer_request_free, ov_infer_request_get_tensor, ov_infer_request_infer,
ov_infer_request_set_tensor, ov_infer_request_start_async, ov_infer_request_t,
ov_infer_request_wait_for,
ov_infer_request_free, ov_infer_request_get_output_tensor_by_index,
ov_infer_request_get_tensor, ov_infer_request_infer,
ov_infer_request_set_input_tensor_by_index, ov_infer_request_set_tensor,
ov_infer_request_start_async, ov_infer_request_t, ov_infer_request_wait_for,
};

/// See [`InferRequest`](https://docs.openvino.ai/2023.3/api/c_cpp_api/group__ov__infer__request__c__api.html).
Expand All @@ -20,6 +21,7 @@ impl InferRequest {
pub(crate) fn from_ptr(ptr: *mut ov_infer_request_t) -> Self {
Self { ptr }
}

/// Assign a [`Tensor`] to the input on the model.
pub fn set_tensor(&mut self, name: &str, tensor: &Tensor) -> Result<()> {
try_unsafe!(ov_infer_request_set_tensor(
Expand All @@ -41,6 +43,27 @@ impl InferRequest {
Ok(Tensor::from_ptr(tensor))
}

/// Assing an input [`Tensor`] to the model by its index.
pub fn set_input_tensor_by_index(&mut self, index: usize, tensor: &Tensor) -> Result<()> {
try_unsafe!(ov_infer_request_set_input_tensor_by_index(
self.ptr,
index,
tensor.as_ptr()
))?;
Ok(())
}

/// Retrieve an output [`Tensor`] from the model by its index.
pub fn get_output_tensor_by_index(&self, index: usize) -> Result<Tensor> {
let mut tensor = std::ptr::null_mut();
try_unsafe!(ov_infer_request_get_output_tensor_by_index(
self.ptr,
index,
std::ptr::addr_of_mut!(tensor)
))?;
Ok(Tensor::from_ptr(tensor))
}

/// Execute the inference request.
pub fn infer(&mut self) -> Result<()> {
try_unsafe!(ov_infer_request_infer(self.ptr))
Expand Down

0 comments on commit 3a404ee

Please sign in to comment.