diff --git a/ext/r2inference/gstinference.c b/ext/r2inference/gstinference.c
index 19cf543..bd150fb 100644
--- a/ext/r2inference/gstinference.c
+++ b/ext/r2inference/gstinference.c
@@ -32,6 +32,7 @@
 #include "gstresnet50v1.h"
 #include "gstmobilenetv2.h"
 #include "gstmobilenetv2ssd.h"
+#include "gstrosetta.h"
 
 static gboolean
 plugin_init (GstPlugin * plugin)
@@ -95,6 +96,12 @@ plugin_init (GstPlugin * plugin)
     goto out;
   }
 
+  ret = gst_element_register (plugin, "rosetta", GST_RANK_NONE,
+      GST_TYPE_ROSETTA);
+  if (!ret) {
+    goto out;
+  }
+
 out:
   return ret;
 }
diff --git a/ext/r2inference/gstrosetta.c b/ext/r2inference/gstrosetta.c
new file mode 100644
index 0000000..24f6296
--- /dev/null
+++ b/ext/r2inference/gstrosetta.c
@@ -0,0 +1,275 @@
+/*
+ * GStreamer
+ * Copyright (C) 2018-2021 RidgeRun
+ *
+ * This library is free software; you can redistribute it and/or
+ * modify it under the terms of the GNU Library General Public
+ * License as published by the Free Software Foundation; either
+ * version 2 of the License, or (at your option) any later version.
+ *
+ * This library is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+ * Library General Public License for more details.
+ *
+ * You should have received a copy of the GNU Library General Public
+ * License along with this library; if not, write to the
+ * Free Software Foundation, Inc., 59 Temple Place - Suite 330,
+ * Boston, MA 02111-1307, USA.
+ *
+ */
+
+ /**
+ * SECTION:element-rosetta
+ *
+ * The rosetta element allows the user to infer/execute a pretrained model
+ * based on the ResNet architecture on incoming image frames and extract
+ * the characters from it.
+ *
+ *
+ * Source
+ * This element is based on the TensorFlow Lite Hub Rosetta Google
+ * Colaboratory script:
+ * https://tfhub.dev/tulasiram58827/lite-model/rosetta/dr/1
+ *
+ */
+
+#include "gstrosetta.h"
+
+#include "gst/r2inference/gstinferencedebug.h"
+#include "gst/r2inference/gstinferencemeta.h"
+#include "gst/r2inference/gstinferencepostprocess.h"
+#include "gst/r2inference/gstinferencepreprocess.h"
+
+GST_DEBUG_CATEGORY_STATIC (gst_rosetta_debug_category);
+#define GST_CAT_DEFAULT gst_rosetta_debug_category
+
+#define BLANK 0
+#define DEFAULT_MODEL_CHANNELS 1
+#define DEFAULT_DATA_MEAN 127.5
+#define DEFAULT_DATA_OFFSET -1
+#define MODEL_OUTPUT_ROWS 26
+#define MODEL_OUTPUT_COLS 37
+
+/* prototypes */
+static gboolean gst_rosetta_preprocess (GstVideoInference * vi,
+    GstVideoFrame * inframe, GstVideoFrame * outframe);
+
+static gboolean
+gst_rosetta_postprocess (GstVideoInference * vi,
+    const gpointer prediction, gsize predsize, GstMeta * meta_model,
+    GstVideoInfo * info_model, gboolean * valid_prediction,
+    gchar ** labels_list, gint num_labels);
+
+static gint get_max_indices (gfloat row[MODEL_OUTPUT_COLS]);
+
+static gchar *concatenate_chars (gint max_indices[MODEL_OUTPUT_ROWS]);
+static gboolean gst_rosetta_start (GstVideoInference * vi);
+static gboolean gst_rosetta_stop (GstVideoInference * vi);
+
+#define CAPS \
+  "video/x-raw, " \
+  "width=100, " \
+  "height=32, " \
+  "format={GRAY8}"
+
+static GstStaticPadTemplate sink_model_factory =
+GST_STATIC_PAD_TEMPLATE ("sink_model",
+    GST_PAD_SINK,
+    GST_PAD_REQUEST,
+    GST_STATIC_CAPS (CAPS)
+    );
+
+static GstStaticPadTemplate src_model_factory =
+GST_STATIC_PAD_TEMPLATE ("src_model",
+    GST_PAD_SRC,
+    GST_PAD_REQUEST,
+    GST_STATIC_CAPS (CAPS)
+    );
+
+struct _GstRosetta
+{
+  GstVideoInference parent;
+};
+
+struct _GstRosettaClass
+{
+  GstVideoInferenceClass parent;
+};
+
+/* class initialization */
+
+G_DEFINE_TYPE_WITH_CODE (GstRosetta, gst_rosetta,
+    GST_TYPE_VIDEO_INFERENCE,
+    GST_DEBUG_CATEGORY_INIT (gst_rosetta_debug_category,
+        "rosetta", 0, "debug category for rosetta element"));
+
+static void
+gst_rosetta_class_init (GstRosettaClass * klass)
+{
+  GstElementClass *element_class = GST_ELEMENT_CLASS (klass);
+  GstVideoInferenceClass *vi_class = GST_VIDEO_INFERENCE_CLASS (klass);
+  gst_element_class_add_static_pad_template (element_class,
+      &sink_model_factory);
+  gst_element_class_add_static_pad_template (element_class, &src_model_factory);
+
+  gst_element_class_set_static_metadata (GST_ELEMENT_CLASS (klass),
+      "Rosetta", "Filter",
+      "Infers characters from an incoming image",
+      "Edgar Chaves \n\t\t\t"
+      " Luis Leon ");
+
+  vi_class->preprocess = GST_DEBUG_FUNCPTR (gst_rosetta_preprocess);
+  vi_class->postprocess = GST_DEBUG_FUNCPTR (gst_rosetta_postprocess);
+  vi_class->start = GST_DEBUG_FUNCPTR (gst_rosetta_start);
+  vi_class->stop = GST_DEBUG_FUNCPTR (gst_rosetta_stop);
+}
+
+
+static void
+gst_rosetta_init (GstRosetta * rosetta)
+{
+}
+
+static gboolean
+gst_rosetta_preprocess (GstVideoInference * vi,
+    GstVideoFrame * inframe, GstVideoFrame * outframe)
+{
+  GstRosetta *rosetta = NULL;
+  gint width = 0, height = 0;
+  g_return_val_if_fail (vi, FALSE);
+  g_return_val_if_fail (inframe, FALSE);
+  g_return_val_if_fail (outframe, FALSE);
+
+  rosetta = GST_ROSETTA (vi);
+
+  GST_LOG_OBJECT (rosetta, "Rosetta Preprocess");
+
+  width = GST_VIDEO_FRAME_WIDTH (inframe);
+  height = GST_VIDEO_FRAME_HEIGHT (inframe);
+
+  GST_LOG_OBJECT (rosetta, "Input frame dimensions w = %i , h = %i", width,
+      height);
+  return gst_normalize_gray_image (inframe, outframe, DEFAULT_DATA_MEAN,
+      DEFAULT_DATA_OFFSET, DEFAULT_MODEL_CHANNELS);
+}
+
+/* Returns the index of the largest value in row; the first one wins on ties. */
+static gint
+get_max_indices (gfloat row[MODEL_OUTPUT_COLS])
+{
+  gfloat largest_probability = row[0];
+  gint temp_max_index = 0;
+  for (int i = 0; i < MODEL_OUTPUT_COLS; ++i) {
+    if (largest_probability < row[i]) {
+      largest_probability = row[i];
+      temp_max_index = i;
+    }
+  }
+  return temp_max_index;
+}
+
+/**
+ * NOTE: After using this function,
please free the returned
+ * gchar with g_free(): ownership of the allocated string is transferred
+ * to the caller.
+ **/
+gchar *
+concatenate_chars (int max_indices[MODEL_OUTPUT_ROWS])
+{
+  gint i = 0;
+  gint counter = 0;
+  gchar *final_phrase = NULL;
+  const gchar chars[MODEL_OUTPUT_COLS] =
+      { '_', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'a', 'b', 'c',
+    'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q',
+    'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z'
+  };
+  /* g_strnfill() allocates the buffer and NUL-terminates it in one call;
+   * its size covers the case where every row contributes a character.
+   */
+  final_phrase = g_strnfill (MODEL_OUTPUT_ROWS + 1, ' ');
+
+  for (i = 0; i < MODEL_OUTPUT_ROWS; ++i) {
+    /* Keep a character only when its index is not the blank symbol '_'
+     * and differs from the index chosen for the previous row, so a run
+     * of identical predictions collapses into a single character.
+     */
+    if (BLANK != max_indices[i] && !(0 < i
+            && (max_indices[i - 1] == max_indices[i]))) {
+      final_phrase[counter] = chars[max_indices[i]];
+      ++counter;
+    }
+  }
+
+  /* End the string after the last kept character, not at the buffer end. */
+  final_phrase[counter] = '\0';
+
+  return final_phrase;
+}
+
+static gboolean
+gst_rosetta_postprocess (GstVideoInference * vi,
+    const gpointer prediction, gsize predsize, GstMeta * meta_model,
+    GstVideoInfo * info_model, gboolean * valid_prediction,
+    gchar ** labels_list, gint num_labels)
+{
+  GstRosetta *rosetta = NULL;
+  gint max_indices[MODEL_OUTPUT_ROWS];
+  gfloat row[MODEL_OUTPUT_COLS];
+  gint index = 0;
+  const gfloat *pred = NULL;
+  gchar *output = NULL;
+  GstInferenceMeta *imeta = NULL;
+  GstInferencePrediction *root = NULL;
+
+  g_return_val_if_fail (vi, FALSE);
+  g_return_val_if_fail (prediction, FALSE);
+  g_return_val_if_fail (meta_model, FALSE);
+  g_return_val_if_fail (info_model, FALSE);
+
+  /* Resolve the instance first: it is the object the next log refers to. */
+  rosetta = GST_ROSETTA (vi);
+  GST_LOG_OBJECT (rosetta, "Rosetta Postprocess");
+
+  imeta = (GstInferenceMeta *) meta_model;
+  root = imeta->prediction;
+  if (!root) {
+    GST_ERROR_OBJECT (vi, "Prediction is not part of the Inference Meta");
+    return FALSE;
+  }
+  pred = (const gfloat *) prediction;
+  GST_LOG_OBJECT (vi, "Predicting...");
+  /* NOTE(review): predsize is not checked before ROWS x COLS floats are read */
+  for (int j = 0; j < MODEL_OUTPUT_ROWS; ++j) {
+    for (int i = 0; i < MODEL_OUTPUT_COLS; ++i) {
+      row[i] = pred[index];
+      ++index;
+    }
+    max_indices[j] = get_max_indices (row);
+  }
+  GST_LOG_OBJECT (vi, "Rosetta prediction is done");
+
+  /* The decoded text is only logged; it is owned here and freed below. */
+  output = concatenate_chars (max_indices);
+  GST_LOG_OBJECT (vi, "The phrase is %s", output);
+
+  g_free (output);
+  return TRUE;
+}
+
+static gboolean
+gst_rosetta_start (GstVideoInference * vi)
+{
+  GST_INFO_OBJECT (vi, "Starting Rosetta");
+
+  return TRUE;
+}
+
+static gboolean
+gst_rosetta_stop (GstVideoInference * vi)
+{
+  GST_INFO_OBJECT (vi, "Stopping Rosetta");
+
+  return TRUE;
+}
diff --git a/ext/r2inference/gstrosetta.h b/ext/r2inference/gstrosetta.h
new file mode 100644
index 0000000..6ac6fed
--- /dev/null
+++ b/ext/r2inference/gstrosetta.h
@@ -0,0 +1,32 @@
+/*
+ * GStreamer
+ * Copyright (C) 2021 RidgeRun
+ *
+ * This library is free software; you can redistribute it and/or
+ * modify it under the terms of the GNU Library General Public
+ * License as published by the Free Software Foundation; either
+ * version 2 of the License, or (at your option) any later version.
+ *
+ * This library is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+ * Library General Public License for more details.
+ *
+ * You should have received a copy of the GNU Library General Public
+ * License along with this library; if not, write to the
+ * Free Software Foundation, Inc., 59 Temple Place - Suite 330,
+ * Boston, MA 02111-1307, USA.
+ *
+ */
+
+#ifndef _GST_ROSETTA_H_
+#define _GST_ROSETTA_H_
+
+#include <gst/r2inference/gstvideoinference.h>
+
+G_BEGIN_DECLS
+#define GST_TYPE_ROSETTA gst_rosetta_get_type()
+G_DECLARE_FINAL_TYPE (GstRosetta, gst_rosetta, GST, ROSETTA, GstVideoInference)
+
+G_END_DECLS
+#endif /* _GST_ROSETTA_H_ */
diff --git a/ext/r2inference/gsttinyyolov3.c b/ext/r2inference/gsttinyyolov3.c
index f7faee7..808833c 100644
--- a/ext/r2inference/gsttinyyolov3.c
+++ b/ext/r2inference/gsttinyyolov3.c
@@ -62,8 +62,11 @@ GST_DEBUG_CATEGORY_STATIC (gst_tinyyolov3_debug_category);
 #define MAX_IOU_THRESH 1
 #define MIN_IOU_THRESH 0
 #define DEFAULT_IOU_THRESH 0.40
+/* Number of classes detected by the model*/
+#define MAX_NUM_CLASSES G_MAXUINT
+#define MIN_NUM_CLASSES 1
+#define DEFAULT_NUM_CLASSES 80
 
-#define TOTAL_CLASSES 80
 #define TOTAL_BOXES 2535
 
 /* prototypes */
@@ -87,6 +90,7 @@ enum
   PROP_OBJ_THRESH,
   PROP_PROB_THRESH,
   PROP_IOU_THRESH,
+  PROP_NUM_CLASSES,
 };
 
 /* pad templates */
@@ -117,6 +121,7 @@ struct _GstTinyyolov3
   gdouble obj_thresh;
   gdouble prob_thresh;
   gdouble iou_thresh;
+  guint num_classes;
 };
 
 struct _GstTinyyolov3Class
@@ -150,7 +155,9 @@ gst_tinyyolov3_class_init (GstTinyyolov3Class * klass)
       " Michael Gruner \n\t\t\t"
       " Carlos Aguero \n\t\t\t"
       " Miguel Taylor \n\t\t\t"
-      " Greivin Fallas ");
+      " Greivin Fallas \n\t\t\t"
+      " Edgar Chaves \n\t\t\t"
+      " Luis Leon ");
 
   gobject_class->set_property = gst_tinyyolov3_set_property;
   gobject_class->get_property = gst_tinyyolov3_get_property;
@@ -168,6 +175,11 @@ gst_tinyyolov3_class_init (GstTinyyolov3Class * klass)
           "Intersection over union threshold to merge similar boxes",
           MIN_IOU_THRESH, MAX_IOU_THRESH, DEFAULT_IOU_THRESH,
           G_PARAM_READWRITE));
+  g_object_class_install_property (gobject_class, PROP_NUM_CLASSES,
+      g_param_spec_uint ("number-of-classes", "num-classes",
+          "Number of classes detected by the TinyYOLOv3 model",
+          MIN_NUM_CLASSES, MAX_NUM_CLASSES, DEFAULT_NUM_CLASSES,
+          G_PARAM_READWRITE));
 
   vi_class->start = GST_DEBUG_FUNCPTR (gst_tinyyolov3_start);
   vi_class->stop = GST_DEBUG_FUNCPTR (gst_tinyyolov3_stop);
@@ -181,6 +193,7 @@ gst_tinyyolov3_init (GstTinyyolov3 * tinyyolov3)
   tinyyolov3->obj_thresh = DEFAULT_OBJ_THRESH;
   tinyyolov3->prob_thresh = DEFAULT_PROB_THRESH;
   tinyyolov3->iou_thresh = DEFAULT_IOU_THRESH;
+  tinyyolov3->num_classes = DEFAULT_NUM_CLASSES;
 }
 
 static void
@@ -208,6 +221,18 @@ gst_tinyyolov3_set_property (GObject * object, guint property_id,
           "Changed intersection over union threshold to %lf",
           tinyyolov3->iou_thresh);
       break;
+    case PROP_NUM_CLASSES:
+      /* The class count is only accepted while the element is in NULL state */
+      if (GST_STATE (tinyyolov3) != GST_STATE_NULL) {
+        GST_ERROR_OBJECT (tinyyolov3,
+            "Can't set property if not on NULL state");
+        return;
+      } else {
+        tinyyolov3->num_classes = g_value_get_uint (value);
+        GST_DEBUG_OBJECT (tinyyolov3,
+            "Changed the number of clases to %u", tinyyolov3->num_classes);
+      }
+      break;
     default:
       G_OBJECT_WARN_INVALID_PROPERTY_ID (object, property_id, pspec);
       break;
@@ -232,6 +257,9 @@ gst_tinyyolov3_get_property (GObject * object, guint property_id,
     case PROP_IOU_THRESH:
       g_value_set_double (value, tinyyolov3->iou_thresh);
       break;
+    case PROP_NUM_CLASSES:
+      g_value_set_uint (value, tinyyolov3->num_classes);
+      break;
     default:
       G_OBJECT_WARN_INVALID_PROPERTY_ID (object, property_id, pspec);
       break;
@@ -274,7 +302,7 @@ gst_tinyyolov3_postprocess (GstVideoInference * vi, const gpointer prediction,
   gst_create_boxes_float (vi, prediction, valid_prediction, &boxes, &num_boxes,
       tinyyolov3->obj_thresh, tinyyolov3->prob_thresh,
       tinyyolov3->iou_thresh, probabilities,
-      TOTAL_CLASSES);
+      tinyyolov3->num_classes);
 
   GST_LOG_OBJECT (tinyyolov3, "Number of predictions: %d", num_boxes);
 
diff --git a/ext/r2inference/meson.build b/ext/r2inference/meson.build
index c708020..8fe00bf 100644
--- a/ext/r2inference/meson.build
+++ b/ext/r2inference/meson.build
@@ -8,7 +8,8 @@ gstinference_sources = [
   'gstmobilenetv2ssd.c',
   'gstresnet50v1.c',
   'gsttinyyolov2.c',
-  'gsttinyyolov3.c'
+  'gsttinyyolov3.c',
+  'gstrosetta.c'
 ]
 
 gstinference = library('gstinference',
diff
--git a/gst-libs/gst/r2inference/gstinferencepreprocess.c b/gst-libs/gst/r2inference/gstinferencepreprocess.c
index 6345c7e..8ed7498 100644
--- a/gst-libs/gst/r2inference/gstinferencepreprocess.c
+++ b/gst-libs/gst/r2inference/gstinferencepreprocess.c
@@ -31,6 +31,9 @@ static void gst_apply_means_std (GstVideoFrame * inframe,
     const gdouble std_r, const gdouble std_g, const gdouble std_b,
     const gint model_channels);
 
+static void gst_apply_gray_normalization (GstVideoFrame * inframe,
+    GstVideoFrame * outframe, gdouble mean, gdouble offset);
+
 static void
 gst_apply_means_std (GstVideoFrame * inframe, GstVideoFrame * outframe,
     gint first_index, gint last_index, gint offset, gint channels,
@@ -168,3 +171,54 @@ gst_pixel_to_float (GstVideoFrame * inframe, GstVideoFrame * outframe,
       channels, mean, mean, mean, std, std, std, model_channels);
   return TRUE;
 }
+
+gboolean
+gst_normalize_gray_image (GstVideoFrame * inframe, GstVideoFrame * outframe,
+    gdouble mean, gint offset, gint model_channels)
+{
+  gint first_index = 0, last_index = 0;
+  g_return_val_if_fail (inframe != NULL, FALSE);
+  g_return_val_if_fail (outframe != NULL, FALSE);
+  /* Every pixel is scaled by 1 / mean, so a zero mean cannot be accepted. */
+  g_return_val_if_fail (mean != 0, FALSE);
+
+  /* NOTE(review): offset is handed to gst_configure_format_values() by
+   * address; if that helper stores a format offset in it, the caller's
+   * normalization offset never reaches the loop below - confirm. */
+  if (gst_configure_format_values (inframe, &first_index, &last_index, &offset,
+          &model_channels) == FALSE) {
+    return FALSE;
+  }
+
+  gst_apply_gray_normalization (inframe, outframe, mean, offset);
+
+  return TRUE;
+}
+
+/* Stores pixel * (1 / mean) - offset as a gfloat for every pixel of plane 0.
+ * Input rows are addressed through the component stride, output rows are
+ * written back to back with `width` floats each. */
+static void
+gst_apply_gray_normalization (GstVideoFrame * inframe, GstVideoFrame * outframe,
+    gdouble mean, gdouble offset)
+{
+  gint i = 0, j = 0, row_stride = 0, width = 0, height = 0;
+  const gdouble rcp_mean = 1. / mean;
+
+  g_return_if_fail (inframe != NULL);
+  g_return_if_fail (outframe != NULL);
+
+  /* Used as the distance between the first pixels of two consecutive
+   * input rows, which is why it is multiplied by the row index below. */
+  row_stride = GST_VIDEO_FRAME_COMP_STRIDE (inframe, 0);
+  width = GST_VIDEO_FRAME_WIDTH (inframe);
+  height = GST_VIDEO_FRAME_HEIGHT (inframe);
+
+  for (i = 0; i < height; ++i) {
+    for (j = 0; j < width; ++j) {
+      ((gfloat *) outframe->data[0])[(i * width + j)] =
+          (((guchar *) inframe->data[0])[(i * row_stride +
+                  j)] * rcp_mean - offset);
+    }
+  }
+}
diff --git a/gst-libs/gst/r2inference/gstinferencepreprocess.h b/gst-libs/gst/r2inference/gstinferencepreprocess.h
index 1ed4218..f7c1312 100644
--- a/gst-libs/gst/r2inference/gstinferencepreprocess.h
+++ b/gst-libs/gst/r2inference/gstinferencepreprocess.h
@@ -61,6 +61,19 @@ gboolean gst_subtract_mean(GstVideoFrame * inframe, GstVideoFrame * outframe, gd
 
 gboolean gst_pixel_to_float(GstVideoFrame * inframe, GstVideoFrame * outframe, gint model_channels);
 
+/**
+ * \brief Normalize a grayscale image with a given mean and offset
+ *
+ * \param inframe The input frame
+ * \param outframe The output frame after preprocess
+ * \param mean The value every pixel is divided by; must not be zero
+ * \param offset The value that will be subtracted from every scaled pixel
+ * \param model_channels The number of channels of the model
+ * \return TRUE on success, FALSE otherwise
+ */
+gboolean
+gst_normalize_gray_image (GstVideoFrame * inframe, GstVideoFrame * outframe,
+    gdouble mean, gint offset, gint model_channels);
 G_END_DECLS
 
 #endif
diff --git a/meson.build b/meson.build
index fd6e47a..b4c8eff 100644
--- a/meson.build
+++ b/meson.build
@@ -1,5 +1,5 @@
 project('GStreamer Inference', ['c', 'cpp'], default_options : ['cpp_std=c++11'],
-  version : '0.11.0.1',
+  version : '0.12.0.1',
   meson_version : '>= 0.50',)
 
 project_name = meson.project_name()
@@ -21,7 +21,7 @@ api_version = '1.0'
 
 # Required versions
 gst_req = '>= 1.8.0.1'
-r2i_req = '>= 0.8.0'
+r2i_req = '>= 0.12.0'
 
 # Find external dependencies
 gst_dep = dependency('gstreamer-1.0', version : gst_req)