Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added forecast visualizers #13

Merged
merged 8 commits into from
Jun 4, 2024
2 changes: 2 additions & 0 deletions src/Bonsai.ML.Visualizers/Bonsai.ML.Visualizers.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
<ItemGroup>
<PackageReference Include="Bonsai.Core" Version="2.8.1" />
<PackageReference Include="Bonsai.Design" Version="2.8.0" />
<PackageReference Include="Bonsai.Vision.Design" Version="2.8.1" />
<PackageReference Include="MathNet.Numerics" Version="5.0.0" />
<PackageReference Include="OxyPlot.Core" Version="2.1.2" />
<PackageReference Include="OxyPlot.WindowsForms" Version="2.1.2" />
</ItemGroup>
Expand Down
91 changes: 91 additions & 0 deletions src/Bonsai.ML.Visualizers/ForecastImageOverlay.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
using Bonsai.Design;
using Bonsai.Vision.Design;
using Bonsai;
using Bonsai.ML.Visualizers;
using Bonsai.ML.LinearDynamicalSystems.Kinematics;
using System;
using System.Collections.Generic;
using OpenCV.Net;
using MathNet.Numerics.LinearAlgebra;
using OxyPlot;

[assembly: TypeVisualizer(typeof(ForecastImageOverlay), Target = typeof(MashupSource<ImageMashupVisualizer, ForecastVisualizer>))]

namespace Bonsai.ML.Visualizers
{
/// <summary>
/// Provides a mashup visualizer to display the forecast of a Kalman Filter kinematics model overtime of an ImageMashupVisualizer.
/// </summary>
public class ForecastImageOverlay : DialogTypeVisualizer
{
private ImageMashupVisualizer visualizer;
private IplImage overlay;

/// <inheritdoc/>
public override void Show(object value)
{

var image = visualizer.VisualizerImage;
Size size = new Size(image.Width, image.Height);
IplDepth depth = image.Depth;
int channels = image.Channels;

overlay = new IplImage(size, depth, channels);
var alpha = 0.1;

Forecast forecast = (Forecast)value;
List<ForecastResult> forecastResults = forecast.ForecastResults;

for (int i = 0; i < forecastResults.Count; i++)
{
var forecastResult = forecastResults[i];
var kinematicState = forecastResult.KinematicState;

double xMean = kinematicState.Position.X.Mean;
double yMean = kinematicState.Position.Y.Mean;

Point center = new Point((int)Math.Round(xMean), (int)Math.Round(yMean));

double xVar = kinematicState.Position.X.Variance;
double yVar = kinematicState.Position.Y.Variance;
double xyCov = kinematicState.Position.Covariance;

var covariance = Matrix<double>.Build.DenseOfArray(new double[,] {
{ xVar, xyCov },
{ xyCov, yVar }
});

var evd = covariance.Evd();
var evals = evd.EigenValues.Real();
var evecs = evd.EigenVectors;

double angle = Math.Atan2(evecs[1, 0], evecs[0, 0]) * 180 / Math.PI;

Size axes = new Size
{
Width = (int)(2 * Math.Sqrt(evals[0])),
Height = (int)(2 * Math.Sqrt(evals[1]))
};

OxyColor color = OxyColors.Yellow;

CV.Ellipse(overlay, center, axes, angle, 0, 360, new Scalar(color.B, color.G, color.R, color.A), -1);
}

CV.AddWeighted(image, 1 - alpha, overlay, alpha, 1, image);
overlay.SetZero();
}

/// <inheritdoc/>
public override void Load(IServiceProvider provider)
{
visualizer = (ImageMashupVisualizer)provider.GetService(typeof(MashupVisualizer));
}

/// <inheritdoc/>
public override void Unload()
{
overlay.Dispose();
}
}
}
116 changes: 116 additions & 0 deletions src/Bonsai.ML.Visualizers/ForecastPlotOverlay.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
using Bonsai.Design;
using Bonsai;
using Bonsai.ML.Visualizers;
using Bonsai.ML.LinearDynamicalSystems;
using Bonsai.ML.LinearDynamicalSystems.Kinematics;
using System;
using System.Collections.Generic;
using OxyPlot.Series;
using OxyPlot;

[assembly: TypeVisualizer(typeof(ForecastPlotOverlay), Target = typeof(MashupSource<KinematicStateVisualizer, ForecastVisualizer>))]

namespace Bonsai.ML.Visualizers
{
/// <summary>
/// Provides a mashup visualizer to display the forecast of a Kalman Filter kinematics model overtime of a KinematicStateVisualizer.
/// </summary>
public class ForecastPlotOverlay : DialogTypeVisualizer
{
private List<LineSeries> lineSeriesList = new();

private List<AreaSeries> areaSeriesList = new();

private KinematicStateVisualizer visualizer;

/// <inheritdoc/>
public override void Show(object value)
{
var time = DateTime.Now;
Forecast forecast = (Forecast)value;
var componentVisualizers = visualizer.ComponentVisualizers;

for (int i = 0; i < componentVisualizers.Count; i++)
{
var plot = componentVisualizers[i].Plot;
var lineSeries = lineSeriesList[i];
var areaSeries = areaSeriesList[i];

plot.ResetLineSeries(lineSeries);
plot.ResetAreaSeries(areaSeries);

DateTime forecastTime = time;

for (int j = 0; j < forecast.ForecastResults.Count; j++)
{
var forecastResult = forecast.ForecastResults[j];
var kinematicState = forecastResult.KinematicState;
forecastTime = time + forecastResult.Timestep;

StateComponent[] stateComponents = new StateComponent[] {kinematicState.Position.X, kinematicState.Position.Y, kinematicState.Velocity.X, kinematicState.Velocity.Y, kinematicState.Acceleration.X, kinematicState.Acceleration.Y};

AddStateComponentDataToSeries(plot, stateComponents[i], lineSeries, areaSeries, forecastTime);

}

plot.SetAxes(minTime: forecastTime.AddSeconds(-plot.Capacity), maxTime: forecastTime);
}
}

private void AddStateComponentDataToSeries(TimeSeriesOxyPlotBase plot, StateComponent stateComponent, LineSeries lineSeries, AreaSeries areaSeries, DateTime time)
{
double mean = stateComponent.Mean;
double variance = stateComponent.Variance;

plot.AddToLineSeries(
lineSeries: lineSeries,
time: time,
value: mean
);

plot.AddToAreaSeries(
areaSeries: areaSeries,
time: time,
value1: mean + variance,
value2: mean - variance
);
}

/// <inheritdoc/>
public override void Load(IServiceProvider provider)
{
if (lineSeriesList.Count > 0)
{
lineSeriesList.Clear();
lineSeriesList = new();
}

if (areaSeriesList.Count > 0)
{
areaSeriesList.Clear();
areaSeriesList = new();
}

var service = provider.GetService(typeof(MashupVisualizer));
visualizer = (KinematicStateVisualizer)service;
var componentVisualizers = visualizer.ComponentVisualizers;

for (int i = 0; i < componentVisualizers.Count; i++)
{
var lineSeries = componentVisualizers[i].Plot.AddNewLineSeries($"Forecast {visualizer.Labels[i]} Mean", color: OxyColors.Yellow);
var areaSeries = componentVisualizers[i].Plot.AddNewAreaSeries($"Forecast {visualizer.Labels[i]} Variance", color: OxyColors.Yellow, opacity: 50);

componentVisualizers[i].Plot.ResetLineSeries(lineSeries);
componentVisualizers[i].Plot.ResetAreaSeries(areaSeries);

lineSeriesList.Add(lineSeries);
areaSeriesList.Add(areaSeries);
}
}

/// <inheritdoc/>
public override void Unload()
{
}
}
}
133 changes: 133 additions & 0 deletions src/Bonsai.ML.Visualizers/ForecastVisualizer.cs
glopesdev marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
using System;
using System.Windows.Forms;
using System.Collections.Generic;
using Bonsai;
using Bonsai.Design;
using Bonsai.ML.Visualizers;
using Bonsai.ML.LinearDynamicalSystems.Kinematics;
using OxyPlot;
using System.Reactive;
using System.Linq;
using System.Reactive.Linq;

[assembly: TypeVisualizer(typeof(ForecastVisualizer), Target = typeof(Forecast))]

namespace Bonsai.ML.Visualizers
{
/// <summary>
/// Provides a type visualizer to display the forecast of a Kalman Filter kinematics model.
/// </summary>
public class ForecastVisualizer : BufferedVisualizer
{

private int rowCount = 3;
private int columnCount = 2;
private string[] labels = new string[] {
"Forecast Position X",
"Forecast Position Y",
"Forecast Velocity X",
"Forecast Velocity Y",
"Forecast Acceleration X",
"Forecast Acceleration Y"
};

private List<StateComponentVisualizer> componentVisualizers = new();
private TableLayoutPanel container;

/// <inheritdoc/>
public override void Load(IServiceProvider provider)
{
container = new TableLayoutPanel
{
ColumnCount = columnCount,
RowCount = rowCount,
Dock = DockStyle.Fill
};

for (int i = 0; i < container.RowCount; i++)
{
container.RowStyles.Add(new RowStyle(SizeType.Percent, 100f / rowCount));
}

for (int i = 0; i < container.ColumnCount; i++)
{
container.ColumnStyles.Add(new ColumnStyle(SizeType.Percent, 100f / columnCount));
}

for (int i = 0 ; i < rowCount; i++)
{
for (int j = 0; j < columnCount; j++)
{
var StateComponentVisualizer = new StateComponentVisualizer() {
Label = labels[i * columnCount + j],
LineSeriesColor = OxyColors.Yellow,
AreaSeriesColor = OxyColors.Yellow
};
StateComponentVisualizer.Load(provider);
container.Controls.Add(StateComponentVisualizer.Plot, j, i);
componentVisualizers.Add(StateComponentVisualizer);
}
}

var visualizerService = (IDialogTypeVisualizerService)provider.GetService(typeof(IDialogTypeVisualizerService));

if (visualizerService != null)
{
visualizerService.AddControl(container);
}
}

/// <inheritdoc/>
public override void Show(object value)
{
}

/// <inheritdoc/>
protected override void ShowBuffer(IList<Timestamped<object>> values)
{
if (values.Count == 0) return;
var latestForecast = values.Last();
var timestamp = latestForecast.Timestamp;
var forecast = (Forecast)latestForecast.Value;
var futureTime = timestamp;

List<Timestamped<object>> positionX = new();
List<Timestamped<object>> positionY = new();
List<Timestamped<object>> velocityX = new();
List<Timestamped<object>> velocityY = new();
List<Timestamped<object>> accelerationX = new();
List<Timestamped<object>> accelerationY = new();

foreach (var forecastResult in forecast.ForecastResults)
{
futureTime = timestamp + forecastResult.Timestep;
positionX.Add(new Timestamped<object>(forecastResult.KinematicState.Position.X, futureTime));
positionY.Add(new Timestamped<object>(forecastResult.KinematicState.Position.Y, futureTime));
velocityX.Add(new Timestamped<object>(forecastResult.KinematicState.Velocity.X, futureTime));
velocityY.Add(new Timestamped<object>(forecastResult.KinematicState.Velocity.Y, futureTime));
accelerationX.Add(new Timestamped<object>(forecastResult.KinematicState.Acceleration.X, futureTime));
accelerationY.Add(new Timestamped<object>(forecastResult.KinematicState.Acceleration.Y, futureTime));
}

var dataList = new List<List<Timestamped<object>>>() { positionX, positionY, velocityX, velocityY, accelerationX, accelerationY };

var zippedData = dataList.Zip(componentVisualizers, (data, visualizer) => new { Data = data, Visualizer = visualizer });

foreach (var item in zippedData)
{
item.Visualizer.Plot.ResetLineSeries(item.Visualizer.LineSeries);
item.Visualizer.Plot.ResetAreaSeries(item.Visualizer.AreaSeries);
item.Visualizer.ShowDataBuffer(item.Data);
item.Visualizer.Plot.SetAxes(minTime: timestamp.DateTime, maxTime: futureTime.DateTime);
}
}

/// <inheritdoc/>
public override void Unload()
{
foreach (var componentVisualizer in componentVisualizers) componentVisualizer.Unload();
if (componentVisualizers.Count > 0) componentVisualizers.Clear();
if (!container.IsDisposed) container.Dispose();
}
}
}