Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Expose model's partial shape and related types #98

Merged
merged 3 commits into from
May 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 78 additions & 0 deletions crates/openvino/src/dimension.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
use openvino_sys::{ov_dimension_is_dynamic, ov_dimension_t};

/// See [`Dimension`](https://docs.openvino.ai/2023.3/api/c_cpp_api/group__ov__dimension__c__api.html).
#[derive(Copy, Clone, Debug)]
#[repr(transparent)]
pub struct Dimension {
instance: ov_dimension_t,
}

impl PartialEq for Dimension {
fn eq(&self, other: &Self) -> bool {
self.instance.min == other.instance.min && self.instance.max == other.instance.max
}
}

impl Eq for Dimension {}

impl Dimension {
/// Get the pointer to the underlying OpenVINO dimension.
#[allow(dead_code)]
pub(crate) fn instance(&self) -> ov_dimension_t {
self.instance
}

/// Create a new dimension object from `ov_dimension_t`.
#[allow(dead_code)]
pub(crate) fn new_from_instance(instance: ov_dimension_t) -> Self {
Self { instance }
}

/// Creates a new Dimension with minimum and maximum values.
pub fn new(min: i64, max: i64) -> Self {
let instance = ov_dimension_t { min, max };
Self { instance }
}

/// Returns the minimum value.
pub fn get_min(&self) -> i64 {
self.instance.min
}

/// Returns the maximum value.
pub fn get_max(&self) -> i64 {
self.instance.max
}

/// Returns `true` if the dimension is dynamic.
pub fn is_dynamic(&self) -> bool {
unsafe { ov_dimension_is_dynamic(self.instance) }
}
}

#[cfg(test)]
mod tests {
use crate::LoadingError;

use super::Dimension;

#[test]
fn test_static() {
openvino_sys::library::load()
.map_err(LoadingError::SystemFailure)
.unwrap();

let dim = Dimension::new(1, 1);
assert!(!dim.is_dynamic());
}

#[test]
fn test_dynamic() {
openvino_sys::library::load()
.map_err(LoadingError::SystemFailure)
.unwrap();

let dim = Dimension::new(1, 2);
assert!(dim.is_dynamic());
}
}
6 changes: 6 additions & 0 deletions crates/openvino/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,23 +26,29 @@
)]

mod core;
mod dimension;
mod element_type;
mod error;
mod layout;
mod model;
mod node;
mod partial_shape;
pub mod prepostprocess;
mod rank;
mod request;
mod shape;
mod tensor;
mod util;

pub use crate::core::Core;
pub use dimension::Dimension;
pub use element_type::ElementType;
pub use error::{InferenceError, LoadingError, SetupError};
pub use layout::Layout;
pub use model::{CompiledModel, Model};
pub use node::Node;
pub use partial_shape::PartialShape;
pub use rank::Rank;
pub use request::InferRequest;
pub use shape::Shape;
pub use tensor::Tensor;
Expand Down
7 changes: 6 additions & 1 deletion crates/openvino/src/model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use crate::{drop_using_function, try_unsafe, util::Result};
use openvino_sys::{
ov_compiled_model_create_infer_request, ov_compiled_model_free, ov_compiled_model_t,
ov_model_const_input_by_index, ov_model_const_output_by_index, ov_model_free,
ov_model_inputs_size, ov_model_outputs_size, ov_model_t,
ov_model_inputs_size, ov_model_is_dynamic, ov_model_outputs_size, ov_model_t,
};

/// See [`Model`](https://docs.openvino.ai/2023.3/api/c_cpp_api/group__ov__model__c__api.html).
Expand Down Expand Up @@ -78,6 +78,11 @@ impl Model {
))?;
Ok(Node::new(node))
}

/// Returns `true` if the model contains dynamic shapes.
pub fn is_dynamic(&self) -> bool {
unsafe { ov_model_is_dynamic(self.instance) }
}
}

/// See [`CompiledModel`](https://docs.openvino.ai/2023.3/api/c_cpp_api/group__ov__compiled__model__c__api.html).
Expand Down
44 changes: 42 additions & 2 deletions crates/openvino/src/node.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
use crate::{try_unsafe, util::Result};
use openvino_sys::{ov_output_const_port_t, ov_port_get_any_name};
use crate::{try_unsafe, util::Result, ElementType, PartialShape, Shape};
use openvino_sys::{
ov_const_port_get_shape, ov_output_const_port_t, ov_partial_shape_t, ov_port_get_any_name,
ov_port_get_element_type, ov_port_get_partial_shape, ov_rank_t, ov_shape_t,
};

use std::ffi::CStr;

/// See [`Node`](https://docs.openvino.ai/2023.3/api/c_cpp_api/group__ov__node__c__api.html).
Expand All @@ -25,4 +29,40 @@ impl Node {
.into_owned();
Ok(rust_name)
}

/// Get the data type of elements of the port.
pub fn get_element_type(&self) -> Result<u32> {
let mut element_type = ElementType::Undefined as u32;
try_unsafe!(ov_port_get_element_type(
self.instance,
std::ptr::addr_of_mut!(element_type),
))?;
Ok(element_type)
}

/// Get the shape of the port.
pub fn get_shape(&self) -> Result<Shape> {
let mut instance = ov_shape_t {
rank: 0,
dims: std::ptr::null_mut(),
};
try_unsafe!(ov_const_port_get_shape(
self.instance,
std::ptr::addr_of_mut!(instance),
))?;
Ok(Shape::new_from_instance(instance))
}

/// Get the partial shape of the port.
pub fn get_partial_shape(&self) -> Result<PartialShape> {
let mut instance = ov_partial_shape_t {
rank: ov_rank_t { min: 0, max: 0 },
dims: std::ptr::null_mut(),
};
try_unsafe!(ov_port_get_partial_shape(
self.instance,
std::ptr::addr_of_mut!(instance),
))?;
Ok(PartialShape::new_from_instance(instance))
}
}
176 changes: 176 additions & 0 deletions crates/openvino/src/partial_shape.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,176 @@
use crate::{dimension::Dimension, try_unsafe, util::Result, Rank};
use openvino_sys::{
ov_dimension_t, ov_partial_shape_create, ov_partial_shape_create_dynamic,
ov_partial_shape_create_static, ov_partial_shape_free, ov_partial_shape_is_dynamic,
ov_partial_shape_t, ov_rank_t,
};

use std::convert::TryInto;

/// See [`PartialShape`](https://docs.openvino.ai/2023.3/api/c_cpp_api/group__ov__partial__shape__c__api.html).
pub struct PartialShape {
instance: ov_partial_shape_t,
}

impl Drop for PartialShape {
/// Drops the `PartialShape` instance and frees the associated memory.
fn drop(&mut self) {
unsafe { ov_partial_shape_free(std::ptr::addr_of_mut!(self.instance)) }
}
}

impl PartialShape {
/// Get the pointer to the underlying OpenVINO partial shape.
#[allow(dead_code)]
pub(crate) fn instance(&self) -> ov_partial_shape_t {
self.instance
}

/// Create a new partial shape object from `ov_partial_shape_t`.
pub(crate) fn new_from_instance(instance: ov_partial_shape_t) -> Self {
Self { instance }
}

/// Creates a new `PartialShape` instance with a static rank and dynamic dimensions.
pub fn new(rank: i64, dimensions: &[Dimension]) -> Result<Self> {
let mut partial_shape = ov_partial_shape_t {
rank: ov_rank_t { min: 0, max: 0 },
dims: std::ptr::null_mut(),
};
try_unsafe!(ov_partial_shape_create(
rank,
dimensions.as_ptr().cast::<ov_dimension_t>(),
std::ptr::addr_of_mut!(partial_shape)
))?;
Ok(Self {
instance: partial_shape,
})
}

/// Creates a new `PartialShape` instance with a dynamic rank and dynamic dimensions.
pub fn new_dynamic(rank: Rank, dimensions: &[Dimension]) -> Result<Self> {
let mut partial_shape = ov_partial_shape_t {
rank: ov_rank_t { min: 0, max: 0 },
dims: std::ptr::null_mut(),
};
try_unsafe!(ov_partial_shape_create_dynamic(
rank.instance(),
dimensions.as_ptr().cast::<ov_dimension_t>(),
std::ptr::addr_of_mut!(partial_shape)
))?;
Ok(Self {
instance: partial_shape,
})
}

/// Creates a new `PartialShape` instance with a static rank and static dimensions.
pub fn new_static(rank: i64, dimensions: &[i64]) -> Result<Self> {
let mut partial_shape = ov_partial_shape_t {
rank: ov_rank_t { min: 0, max: 0 },
dims: std::ptr::null_mut(),
};
try_unsafe!(ov_partial_shape_create_static(
rank,
dimensions.as_ptr(),
std::ptr::addr_of_mut!(partial_shape)
))?;
Ok(Self {
instance: partial_shape,
})
}

/// Returns the rank of the partial shape.
pub fn get_rank(&self) -> Rank {
let rank = self.instance.rank;
Rank::new_from_instance(rank)
}

/// Returns the dimensions of the partial shape.
pub fn get_dimensions(&self) -> &[Dimension] {
if self.instance.dims.is_null() {
&[]
} else {
unsafe {
std::slice::from_raw_parts(
self.instance.dims.cast::<Dimension>(),
self.instance.rank.max.try_into().unwrap(),
)
}
}
}

/// Returns `true` if the partial shape is dynamic.
pub fn is_dynamic(&self) -> bool {
unsafe { ov_partial_shape_is_dynamic(self.instance) }
}
}

#[cfg(test)]
mod tests {
use crate::LoadingError;

use super::*;

#[test]
fn test_new_partial_shape() {
openvino_sys::library::load()
.map_err(LoadingError::SystemFailure)
.unwrap();

let dimensions = vec![
Dimension::new(0, 1),
Dimension::new(1, 2),
Dimension::new(2, 3),
Dimension::new(3, 4),
];

let shape = PartialShape::new(4, &dimensions).unwrap();
assert_eq!(shape.get_rank().get_min(), 4);
assert_eq!(shape.get_rank().get_max(), 4);
assert!(shape.is_dynamic());
}

#[test]
fn test_new_dynamic_partial_shape() {
openvino_sys::library::load()
.map_err(LoadingError::SystemFailure)
.unwrap();

let dimensions = vec![Dimension::new(1, 1), Dimension::new(2, 2)];

let shape = PartialShape::new_dynamic(Rank::new(0, 2), &dimensions).unwrap();
assert!(shape.is_dynamic());
}

#[test]
fn test_new_static_partial_shape() {
openvino_sys::library::load()
.map_err(LoadingError::SystemFailure)
.unwrap();

let dimensions = vec![1, 2];

let shape = PartialShape::new_static(2, &dimensions).unwrap();
assert!(!shape.is_dynamic());
}

#[test]
fn test_get_dimensions() {
openvino_sys::library::load()
.map_err(LoadingError::SystemFailure)
.unwrap();

let dimensions = vec![
Dimension::new(0, 1),
Dimension::new(1, 2),
Dimension::new(2, 3),
Dimension::new(3, 4),
];

let shape = PartialShape::new(4, &dimensions).unwrap();

let dims = shape.get_dimensions();

assert_eq!(dims, &dimensions);
}
}
Loading
Loading