diff --git a/crates/openvino/src/request.rs b/crates/openvino/src/request.rs index 2ca3d9a..67b2ec3 100644 --- a/crates/openvino/src/request.rs +++ b/crates/openvino/src/request.rs @@ -1,9 +1,10 @@ use crate::tensor::Tensor; use crate::{cstr, drop_using_function, try_unsafe, util::Result}; use openvino_sys::{ - ov_infer_request_free, ov_infer_request_get_tensor, ov_infer_request_infer, - ov_infer_request_set_tensor, ov_infer_request_start_async, ov_infer_request_t, - ov_infer_request_wait_for, + ov_infer_request_free, ov_infer_request_get_output_tensor_by_index, + ov_infer_request_get_tensor, ov_infer_request_infer, + ov_infer_request_set_input_tensor_by_index, ov_infer_request_set_tensor, + ov_infer_request_start_async, ov_infer_request_t, ov_infer_request_wait_for, }; /// See [`InferRequest`](https://docs.openvino.ai/2023.3/api/c_cpp_api/group__ov__infer__request__c__api.html). @@ -20,6 +21,7 @@ impl InferRequest { pub(crate) fn from_ptr(ptr: *mut ov_infer_request_t) -> Self { Self { ptr } } + /// Assign a [`Tensor`] to the input on the model. pub fn set_tensor(&mut self, name: &str, tensor: &Tensor) -> Result<()> { try_unsafe!(ov_infer_request_set_tensor( @@ -41,6 +43,27 @@ impl InferRequest { Ok(Tensor::from_ptr(tensor)) } + /// Assing an input [`Tensor`] to the model by its index. + pub fn set_input_tensor_by_index(&mut self, index: usize, tensor: &Tensor) -> Result<()> { + try_unsafe!(ov_infer_request_set_input_tensor_by_index( + self.ptr, + index, + tensor.as_ptr() + ))?; + Ok(()) + } + + /// Retrieve an output [`Tensor`] from the model by its index. + pub fn get_output_tensor_by_index(&self, index: usize) -> Result { + let mut tensor = std::ptr::null_mut(); + try_unsafe!(ov_infer_request_get_output_tensor_by_index( + self.ptr, + index, + std::ptr::addr_of_mut!(tensor) + ))?; + Ok(Tensor::from_ptr(tensor)) + } + /// Execute the inference request. pub fn infer(&mut self) -> Result<()> { try_unsafe!(ov_infer_request_infer(self.ptr))