Skip to content

Commit

Permalink
Expose model's partial shape and related types (#98)
Browse files Browse the repository at this point in the history
* Expose port element type and shape

* Expose model's partial shape

* Clean up casts
  • Loading branch information
pnehrer committed May 13, 2024
1 parent cfec608 commit 656ae28
Show file tree
Hide file tree
Showing 7 changed files with 398 additions and 3 deletions.
78 changes: 78 additions & 0 deletions crates/openvino/src/dimension.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
use openvino_sys::{ov_dimension_is_dynamic, ov_dimension_t};

/// See [`Dimension`](https://docs.openvino.ai/2023.3/api/c_cpp_api/group__ov__dimension__c__api.html).
#[derive(Copy, Clone, Debug)]
#[repr(transparent)]
pub struct Dimension {
instance: ov_dimension_t,
}

impl PartialEq for Dimension {
fn eq(&self, other: &Self) -> bool {
self.instance.min == other.instance.min && self.instance.max == other.instance.max
}
}

impl Eq for Dimension {}

impl Dimension {
/// Get the pointer to the underlying OpenVINO dimension.
#[allow(dead_code)]
pub(crate) fn instance(&self) -> ov_dimension_t {
self.instance
}

/// Create a new dimension object from `ov_dimension_t`.
#[allow(dead_code)]
pub(crate) fn new_from_instance(instance: ov_dimension_t) -> Self {
Self { instance }
}

/// Creates a new Dimension with minimum and maximum values.
pub fn new(min: i64, max: i64) -> Self {
let instance = ov_dimension_t { min, max };
Self { instance }
}

/// Returns the minimum value.
pub fn get_min(&self) -> i64 {
self.instance.min
}

/// Returns the maximum value.
pub fn get_max(&self) -> i64 {
self.instance.max
}

/// Returns `true` if the dimension is dynamic.
pub fn is_dynamic(&self) -> bool {
unsafe { ov_dimension_is_dynamic(self.instance) }
}
}

#[cfg(test)]
mod tests {
use crate::LoadingError;

use super::Dimension;

#[test]
fn test_static() {
openvino_sys::library::load()
.map_err(LoadingError::SystemFailure)
.unwrap();

let dim = Dimension::new(1, 1);
assert!(!dim.is_dynamic());
}

#[test]
fn test_dynamic() {
openvino_sys::library::load()
.map_err(LoadingError::SystemFailure)
.unwrap();

let dim = Dimension::new(1, 2);
assert!(dim.is_dynamic());
}
}
6 changes: 6 additions & 0 deletions crates/openvino/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,23 +26,29 @@
)]

mod core;
mod dimension;
mod element_type;
mod error;
mod layout;
mod model;
mod node;
mod partial_shape;
pub mod prepostprocess;
mod rank;
mod request;
mod shape;
mod tensor;
mod util;

pub use crate::core::Core;
pub use dimension::Dimension;
pub use element_type::ElementType;
pub use error::{InferenceError, LoadingError, SetupError};
pub use layout::Layout;
pub use model::{CompiledModel, Model};
pub use node::Node;
pub use partial_shape::PartialShape;
pub use rank::Rank;
pub use request::InferRequest;
pub use shape::Shape;
pub use tensor::Tensor;
Expand Down
7 changes: 6 additions & 1 deletion crates/openvino/src/model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use crate::{drop_using_function, try_unsafe, util::Result};
use openvino_sys::{
ov_compiled_model_create_infer_request, ov_compiled_model_free, ov_compiled_model_t,
ov_model_const_input_by_index, ov_model_const_output_by_index, ov_model_free,
ov_model_inputs_size, ov_model_outputs_size, ov_model_t,
ov_model_inputs_size, ov_model_is_dynamic, ov_model_outputs_size, ov_model_t,
};

/// See [`Model`](https://docs.openvino.ai/2023.3/api/c_cpp_api/group__ov__model__c__api.html).
Expand Down Expand Up @@ -78,6 +78,11 @@ impl Model {
))?;
Ok(Node::new(node))
}

/// Returns `true` if the model contains dynamic shapes.
pub fn is_dynamic(&self) -> bool {
unsafe { ov_model_is_dynamic(self.instance) }
}
}

/// See [`CompiledModel`](https://docs.openvino.ai/2023.3/api/c_cpp_api/group__ov__compiled__model__c__api.html).
Expand Down
44 changes: 42 additions & 2 deletions crates/openvino/src/node.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
use crate::{try_unsafe, util::Result};
use openvino_sys::{ov_output_const_port_t, ov_port_get_any_name};
use crate::{try_unsafe, util::Result, ElementType, PartialShape, Shape};
use openvino_sys::{
ov_const_port_get_shape, ov_output_const_port_t, ov_partial_shape_t, ov_port_get_any_name,
ov_port_get_element_type, ov_port_get_partial_shape, ov_rank_t, ov_shape_t,
};

use std::ffi::CStr;

/// See [`Node`](https://docs.openvino.ai/2023.3/api/c_cpp_api/group__ov__node__c__api.html).
Expand All @@ -25,4 +29,40 @@ impl Node {
.into_owned();
Ok(rust_name)
}

/// Get the data type of elements of the port.
pub fn get_element_type(&self) -> Result<u32> {
let mut element_type = ElementType::Undefined as u32;
try_unsafe!(ov_port_get_element_type(
self.instance,
std::ptr::addr_of_mut!(element_type),
))?;
Ok(element_type)
}

/// Get the shape of the port.
pub fn get_shape(&self) -> Result<Shape> {
let mut instance = ov_shape_t {
rank: 0,
dims: std::ptr::null_mut(),
};
try_unsafe!(ov_const_port_get_shape(
self.instance,
std::ptr::addr_of_mut!(instance),
))?;
Ok(Shape::new_from_instance(instance))
}

/// Get the partial shape of the port.
pub fn get_partial_shape(&self) -> Result<PartialShape> {
let mut instance = ov_partial_shape_t {
rank: ov_rank_t { min: 0, max: 0 },
dims: std::ptr::null_mut(),
};
try_unsafe!(ov_port_get_partial_shape(
self.instance,
std::ptr::addr_of_mut!(instance),
))?;
Ok(PartialShape::new_from_instance(instance))
}
}
176 changes: 176 additions & 0 deletions crates/openvino/src/partial_shape.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,176 @@
use crate::{dimension::Dimension, try_unsafe, util::Result, Rank};
use openvino_sys::{
ov_dimension_t, ov_partial_shape_create, ov_partial_shape_create_dynamic,
ov_partial_shape_create_static, ov_partial_shape_free, ov_partial_shape_is_dynamic,
ov_partial_shape_t, ov_rank_t,
};

use std::convert::TryInto;

/// See [`PartialShape`](https://docs.openvino.ai/2023.3/api/c_cpp_api/group__ov__partial__shape__c__api.html).
pub struct PartialShape {
instance: ov_partial_shape_t,
}

impl Drop for PartialShape {
/// Drops the `PartialShape` instance and frees the associated memory.
fn drop(&mut self) {
unsafe { ov_partial_shape_free(std::ptr::addr_of_mut!(self.instance)) }
}
}

impl PartialShape {
/// Get the pointer to the underlying OpenVINO partial shape.
#[allow(dead_code)]
pub(crate) fn instance(&self) -> ov_partial_shape_t {
self.instance
}

/// Create a new partial shape object from `ov_partial_shape_t`.
pub(crate) fn new_from_instance(instance: ov_partial_shape_t) -> Self {
Self { instance }
}

/// Creates a new `PartialShape` instance with a static rank and dynamic dimensions.
pub fn new(rank: i64, dimensions: &[Dimension]) -> Result<Self> {
let mut partial_shape = ov_partial_shape_t {
rank: ov_rank_t { min: 0, max: 0 },
dims: std::ptr::null_mut(),
};
try_unsafe!(ov_partial_shape_create(
rank,
dimensions.as_ptr().cast::<ov_dimension_t>(),
std::ptr::addr_of_mut!(partial_shape)
))?;
Ok(Self {
instance: partial_shape,
})
}

/// Creates a new `PartialShape` instance with a dynamic rank and dynamic dimensions.
pub fn new_dynamic(rank: Rank, dimensions: &[Dimension]) -> Result<Self> {
let mut partial_shape = ov_partial_shape_t {
rank: ov_rank_t { min: 0, max: 0 },
dims: std::ptr::null_mut(),
};
try_unsafe!(ov_partial_shape_create_dynamic(
rank.instance(),
dimensions.as_ptr().cast::<ov_dimension_t>(),
std::ptr::addr_of_mut!(partial_shape)
))?;
Ok(Self {
instance: partial_shape,
})
}

/// Creates a new `PartialShape` instance with a static rank and static dimensions.
pub fn new_static(rank: i64, dimensions: &[i64]) -> Result<Self> {
let mut partial_shape = ov_partial_shape_t {
rank: ov_rank_t { min: 0, max: 0 },
dims: std::ptr::null_mut(),
};
try_unsafe!(ov_partial_shape_create_static(
rank,
dimensions.as_ptr(),
std::ptr::addr_of_mut!(partial_shape)
))?;
Ok(Self {
instance: partial_shape,
})
}

/// Returns the rank of the partial shape.
pub fn get_rank(&self) -> Rank {
let rank = self.instance.rank;
Rank::new_from_instance(rank)
}

/// Returns the dimensions of the partial shape.
pub fn get_dimensions(&self) -> &[Dimension] {
if self.instance.dims.is_null() {
&[]
} else {
unsafe {
std::slice::from_raw_parts(
self.instance.dims.cast::<Dimension>(),
self.instance.rank.max.try_into().unwrap(),
)
}
}
}

/// Returns `true` if the partial shape is dynamic.
pub fn is_dynamic(&self) -> bool {
unsafe { ov_partial_shape_is_dynamic(self.instance) }
}
}

#[cfg(test)]
mod tests {
use crate::LoadingError;

use super::*;

#[test]
fn test_new_partial_shape() {
openvino_sys::library::load()
.map_err(LoadingError::SystemFailure)
.unwrap();

let dimensions = vec![
Dimension::new(0, 1),
Dimension::new(1, 2),
Dimension::new(2, 3),
Dimension::new(3, 4),
];

let shape = PartialShape::new(4, &dimensions).unwrap();
assert_eq!(shape.get_rank().get_min(), 4);
assert_eq!(shape.get_rank().get_max(), 4);
assert!(shape.is_dynamic());
}

#[test]
fn test_new_dynamic_partial_shape() {
openvino_sys::library::load()
.map_err(LoadingError::SystemFailure)
.unwrap();

let dimensions = vec![Dimension::new(1, 1), Dimension::new(2, 2)];

let shape = PartialShape::new_dynamic(Rank::new(0, 2), &dimensions).unwrap();
assert!(shape.is_dynamic());
}

#[test]
fn test_new_static_partial_shape() {
openvino_sys::library::load()
.map_err(LoadingError::SystemFailure)
.unwrap();

let dimensions = vec![1, 2];

let shape = PartialShape::new_static(2, &dimensions).unwrap();
assert!(!shape.is_dynamic());
}

#[test]
fn test_get_dimensions() {
openvino_sys::library::load()
.map_err(LoadingError::SystemFailure)
.unwrap();

let dimensions = vec![
Dimension::new(0, 1),
Dimension::new(1, 2),
Dimension::new(2, 3),
Dimension::new(3, 4),
];

let shape = PartialShape::new(4, &dimensions).unwrap();

let dims = shape.get_dimensions();

assert_eq!(dims, &dimensions);
}
}
Loading

0 comments on commit 656ae28

Please sign in to comment.