Skip to content

Commit

Permalink
Fix tabular preprocessing fail during web-client prediction (without …
Browse files Browse the repository at this point in the history
…labels)
  • Loading branch information
JulienVig committed Apr 11, 2024
1 parent adb3d3f commit 1edf14c
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,27 @@ interface TabularEntry extends tf.TensorContainerObject {
const sanitize: PreprocessingFunction = {
type: TabularPreprocessing.Sanitize,
apply: async (entry: Promise<tf.TensorContainer>) => {
const entryContainer = await entry
// if preprocessing a dataset without labels, then the entry is an array of numbers
if (Array.isArray(entry)) {
if (Array.isArray(entryContainer)) {
const entry = entryContainer as number[]
return entry.map((i: number) => i ?? 0)
// otherwise it is an object with feature and labels
} else {
const { xs, ys } = await entry as TabularEntry
return {
xs: xs.map(i => i ?? 0),
ys
// if it is an object
} else if (typeof entryContainer === 'object' && entry !== null) {
// if the object is a tensor container with features xs and labels ys
if (Object.hasOwn(entryContainer, 'xs')) {
const { xs, ys } = entryContainer as TabularEntry
return {
xs: xs.map(i => i ?? 0),
ys
}
// if the object contains features as a dict of feature names-values
} else {
const entry = Object.values(entryContainer)
return entry.map((i: number) => i ?? 0)
}
} else {
throw new Error('Unrecognized format during tabular preprocessing')
}
}
}
Expand Down
2 changes: 1 addition & 1 deletion discojs/discojs-core/src/validation/validator.ts
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ export class Validator {

let hits = 0
// Get model predictions per batch and flatten the result
// Also build the features and groudTruth arrays
// Also build the features and ground truth arrays
const predictions: number[] = (await data.preprocess().dataset.batch(batchSize)
.mapAsync(async e => {
if (typeof e === 'object' && 'xs' in e && 'ys' in e) {
Expand Down

0 comments on commit 1edf14c

Please sign in to comment.