Skip to content

Commit

Permalink
Update observations to reflect new model structure
Browse files Browse the repository at this point in the history
  • Loading branch information
ncguilbeault committed Sep 3, 2024
1 parent b17377e commit 6505dd6
Show file tree
Hide file tree
Showing 6 changed files with 227 additions and 52 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ public class AutoRegressiveObservations : ObservationsModel
/// The lags of the observations for each state.
/// </summary>
[Description("The lags of the observations for each state.")]
[JsonProperty]
public int Lags { get; set; } = 1;

/// <summary>
Expand Down Expand Up @@ -57,45 +56,45 @@ public class AutoRegressiveObservations : ObservationsModel
/// <inheritdoc/>
[JsonProperty]
[JsonConverter(typeof(ObservationsModelTypeJsonConverter))]
[Browsable(false)]
public override ObservationsModelType ObservationsModelType => ObservationsModelType.AutoRegressive;

/// <inheritdoc/>
[JsonProperty]
public override object[] Params
{
get =>[ As, Bs, Vs, SqrtSigmas ];
set
{
As = (double[,,])value[0];
Bs = (double[,])value[1];
Vs = (double[,,])value[2];
SqrtSigmas = (double[,,])value[3];
UpdateString();
}
}

/// <inheritdoc/>
[JsonProperty]
[XmlIgnore]
public override Dictionary<string, object> Kwargs => new Dictionary<string, object>
{
["lags"] = Lags,
};

/// <summary>
/// Initializes a new instance of the <see cref="AutoRegressiveObservations"/> class.
/// </summary>
/// <inheritdoc/>
[XmlIgnore]
public static new string[] KwargsArray => [ "lags" ];

/// <inheritdoc/>
public AutoRegressiveObservations () : base()
{
}

/// <inheritdoc/>
public AutoRegressiveObservations (params object[] args) : base(args)
{
}

/// <inheritdoc/>
protected override bool CheckConstructorArgs(params object[] args)
protected override void CheckKwargs(params object[] kwargs)
{
if (args is null || args.Length != 1)
if (kwargs is null || kwargs.Length != 1)
{
throw new ArgumentException("The AutoRegressiveObservations operator requires a single argument specifying the number of lags.");
throw new ArgumentException($"The AutoRegressiveObservations operator requires exactly one constructor argument: {nameof(Lags)}.");
}
return true;
}

/// <inheritdoc/>
Expand All @@ -104,7 +103,44 @@ protected override void UpdateKwargs(params object[] args)
Lags = args[0] switch
{
int lags => lags,
var lags => Convert.ToInt32(lags),
var lags => Convert.ToInt32(lags)
};
}

/// <inheritdoc/>
protected override void CheckParams(params object[] @params)
{
if (@params is not null && @params.Length != 4)
{
throw new ArgumentException($"The {nameof(AutoRegressiveObservations)} operator requires exactly four parameters: {nameof(As)}, {nameof(Bs)}, {nameof(Vs)}, and {nameof(SqrtSigmas)}.");
}
}

/// <inheritdoc/>
protected override void UpdateParams(params object[] @params)
{
As = @params[0] switch
{
double[,,] As => As,
_ => null
};

Bs = @params[1] switch
{
double[,] Bs => Bs,
_ => null
};

Vs = @params[2] switch
{
double[,,] Vs => Vs,
_ => null
};

SqrtSigmas = @params[3] switch
{
double[,,] SqrtSigmas => SqrtSigmas,
_ => null
};
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,20 +28,50 @@ public class BernoulliObservations : ObservationsModel
/// <inheritdoc/>
[JsonProperty]
[JsonConverter(typeof(ObservationsModelTypeJsonConverter))]
[Browsable(false)]
public override ObservationsModelType ObservationsModelType => ObservationsModelType.Bernoulli;

/// <inheritdoc/>
[JsonProperty]
public override object[] Params
public override object[] Params
{
get { return [ LogitPs ]; }
set
{
LogitPs = (double[,])value[0];
UpdateString();
get => [ LogitPs ];
}

/// <inheritdoc/>
public BernoulliObservations () : base()
{
}

/// <inheritdoc/>
public BernoulliObservations (params object[] kwargs) : base(kwargs)
{
}

/// <inheritdoc/>
protected override void CheckParams(params object[] @params)
{
if (@params is not null && @params.Length != 1)
{
throw new ArgumentException($"The {nameof(BernoulliObservations)} operator requires exactly one parameter: {nameof(LogitPs)}.");
}
}

/// <inheritdoc/>
protected override void CheckKwargs(params object[] kwargs)
{
if (kwargs is null || kwargs.Length != 0)
{
throw new ArgumentException($"The {nameof(BernoulliObservations)} operator requires exactly zero constructor arguments.");
}
}

/// <inheritdoc/>
protected override void UpdateParams(params object[] @params)
{
LogitPs = (double[,])@params[0];
}

/// <summary>
/// Returns an observable sequence of <see cref="BernoulliObservations"/> objects.
/// </summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ public class CategoricalObservations : ObservationsModel
/// The number of categories in the observations.
/// </summary>
[Description("The number of categories in the observations.")]
[JsonProperty]
public int Categories { get; set; } = 2;

/// <summary>
Expand All @@ -36,54 +35,74 @@ public class CategoricalObservations : ObservationsModel
/// <inheritdoc/>
[JsonProperty]
[JsonConverter(typeof(ObservationsModelTypeJsonConverter))]
[Browsable(false)]
public override ObservationsModelType ObservationsModelType => ObservationsModelType.Categorical;

/// <inheritdoc/>
[JsonProperty]
public override object[] Params
{
get =>[ Logits ];
set
{
Logits = (double[,,])value[0];
UpdateString();
}
get => [ Logits ];
}

/// <inheritdoc/>
[JsonProperty]
[XmlIgnore]
[Browsable(false)]
public override Dictionary<string, object> Kwargs => new Dictionary<string, object>
{
["C"] = Categories,
};

/// <summary>
/// Initializes a new instance of the <see cref="CategoricalObservations"/> class.
/// </summary>
public CategoricalObservations (params object[] args) : base(args)
/// <inheritdoc/>
[XmlIgnore]
[Browsable(false)]
public static new string[] KwargsArray => [ "C" ];

/// <inheritdoc/>
public CategoricalObservations() : base()
{
}

/// <inheritdoc/>
public CategoricalObservations (params object[] kwargs) : base(kwargs)
{
}

/// <inheritdoc/>
protected override bool CheckConstructorArgs(params object[] args)
protected override void CheckKwargs(params object[] kwargs)
{
if (args is null || args.Length != 1)
if (kwargs is null || kwargs.Length != 1)
{
throw new ArgumentException("The CategoricalObservations operator requires a single argument specifying the number of categories.");
throw new ArgumentException($"The {nameof(CategoricalObservations)} operator requires exactly one keyword argument: {nameof(Categories)}.");
}
return true;
}

/// <inheritdoc/>
protected override void UpdateKwargs(params object[] args)
protected override void UpdateKwargs(params object[] kwargs)
{
Categories = args[0] switch
Categories = kwargs[0] switch
{
int c => c,
var c => Convert.ToInt32(c),
};
}

/// <inheritdoc/>
protected override void CheckParams(params object[] @params)
{
if (@params is not null && @params.Length != 1)
{
throw new ArgumentException($"The {nameof(CategoricalObservations)} operator requires exactly one parameter: {nameof(Logits)}.");
}
}

/// <inheritdoc/>
protected override void UpdateParams(params object[] @params)
{
Logits = (double[,,])@params[0];
}

/// <summary>
/// Returns an observable sequence of <see cref="CategoricalObservations"/> objects.
/// </summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,17 +29,47 @@ public class ExponentialObservations : ObservationsModel
/// <inheritdoc/>
[JsonProperty]
[JsonConverter(typeof(ObservationsModelTypeJsonConverter))]
[Browsable(false)]
public override ObservationsModelType ObservationsModelType => ObservationsModelType.Exponential;

/// <inheritdoc/>
[JsonProperty]
public override object[] Params
{
get { return [ LogLambdas ]; }
set
get => [ LogLambdas ];
}

/// <inheritdoc/>
public ExponentialObservations () : base()
{
}

/// <inheritdoc/>
public ExponentialObservations (params object[] kwargs) : base(kwargs)
{
}

/// <inheritdoc/>
protected override void CheckParams(params object[] @params)
{
if (@params is not null && @params.Length != 1)
{
throw new ArgumentException($"The {nameof(ExponentialObservations)} operator requires exactly one parameter: {nameof(LogLambdas)}.");
}
}

/// <inheritdoc/>
protected override void UpdateParams(params object[] @params)
{
LogLambdas = (double[,])@params[0];
}

/// <inheritdoc/>
protected override void CheckKwargs(params object[] kwargs)
{
if (kwargs is null || kwargs.Length != 0)
{
LogLambdas = (double[,])value[0];
UpdateString();
throw new ArgumentException($"The {nameof(ExponentialObservations)} operator requires exactly zero constructor arguments.");
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,18 +35,48 @@ public class GaussianObservations : ObservationsModel
/// <inheritdoc/>
[JsonProperty]
[JsonConverter(typeof(ObservationsModelTypeJsonConverter))]
[Browsable(false)]
public override ObservationsModelType ObservationsModelType => ObservationsModelType.Gaussian;

/// <inheritdoc/>
[JsonProperty]
public override object[] Params
{
get => [ Mus, SqrtSigmas ];
set
}

/// <inheritdoc/>
public GaussianObservations () : base()
{
}

/// <inheritdoc/>
public GaussianObservations (params object[] kwargs) : base(kwargs)
{
}

/// <inheritdoc/>
protected override void CheckParams(params object[] @params)
{
if (@params is not null && @params.Length != 2)
{
throw new ArgumentException($"The {nameof(GaussianObservations)} operator requires exactly two parameters: {nameof(Mus)} and {nameof(SqrtSigmas)}.");
}
}

/// <inheritdoc/>
protected override void UpdateParams(params object[] @params)
{
Mus = (double[,])@params[0];
SqrtSigmas = (double[,,])@params[1];
}

/// <inheritdoc/>
protected override void CheckKwargs(params object[] kwargs)
{
if (kwargs is null || kwargs.Length != 0)
{
Mus = (double[,])value[0];
SqrtSigmas = (double[,,])value[1];
UpdateString();
throw new ArgumentException($"The {nameof(GaussianObservations)} operator requires exactly zero constructor arguments.");
}
}

Expand Down
Loading

0 comments on commit 6505dd6

Please sign in to comment.