Skip to content

Commit

Permalink
Add FlexBuffers reader (#499)
Browse files Browse the repository at this point in the history
  • Loading branch information
lutzroeder committed Sep 8, 2020
1 parent 684ab68 commit c9bb7de
Showing 1 changed file with 249 additions and 2 deletions.
251 changes: 249 additions & 2 deletions source/tflite.js
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

var tflite = tflite || {};
var flatbuffers = flatbuffers || require('./flatbuffers');
var flexbuffers = {};

tflite.ModelFactory = class {

Expand Down Expand Up @@ -294,8 +295,26 @@ tflite.Node = class {
this._outputs.push(new tflite.Parameter(outputName, true, [ argument ]));
}
if (type.custom && node.custom_options.length > 0) {
const schema = metadata.attribute(this.type, 'custom');
this._attributes.push(new tflite.Attribute(schema, 'custom', Array.from(node.custom_options)));
let decoded = false;
if (node.custom_options_format === tflite.schema.CustomOptionsFormat.FLEXBUFFERS) {
try {
const reader = flexbuffers.Reader.create(node.custom_options);
const custom_options = reader.read();
for (const key of Object.keys(custom_options)) {
const schema = metadata.attribute(this.type, key);
const value = custom_options[key];
this._attributes.push(new tflite.Attribute(schema, key, value));
}
decoded = true;
}
catch (err) {
// continue regardless of error
}
}
if (!decoded) {
const schema = metadata.attribute(this.type, 'custom');
this._attributes.push(new tflite.Attribute(schema, 'custom', Array.from(node.custom_options)));
}
}
const options = node.builtin_options;
if (options) {
Expand Down Expand Up @@ -803,6 +822,234 @@ tflite.Error = class extends Error {
}
};

flexbuffers.Reader = class {

constructor(buffer) {
this._reader = new flexbuffers.BinaryReader(buffer);
}

static create(buffer) {
return new flexbuffers.Reader(buffer);
}

read() {
const length = this._reader.length;
if (length < 3) {
throw 'Invalid buffer size.';
}
const byteSize = this._reader.uint(length - 1, 0);
if (byteSize > 8) {
throw 'Invalid byte size.';
}
const bitSize = byteSize >> 2;
const packedType = this._reader.uint(length - 2, 0);
const offset = length - 2 - byteSize;
return new flexbuffers.Reference(this._reader, offset, bitSize, packedType).read();
}
};

flexbuffers.Reference = class {

constructor(reader, offset, parentBitSize, packedType) {
this._reader = reader;
this._offset = offset;
this._parentBitSize = parentBitSize;
this._bitSize = packedType & 3;
this._byteSize = 1 << this._bitSize;
this._valueType = packedType >> 2;
}

read() {
switch (this._valueType) {
case 0x00: // null
return null;
case 0x01: // int
return this._reader.int(this._offset, this._parentBitSize);
case 0x02: // uint
return this._reader.uint(this._offset, this._parentBitSize);
case 0x03: // float
return this._reader.float(this._offset, this._parentBitSize);
case 0x04: {
const offset = this._reader.indirect(this._offset, this._parentBitSize);
let size = 0;
while (this._reader.int(offset + size, 0) !== 0) {
size++;
}
return this._reader.string(offset, size);
}
case 0x05: { // string
const offset = this._reader.indirect(this._offset, this._parentBitSize);
let sizeByteSize = this._byteSize;
let size = this._reader.int(offset - sizeByteSize, this._bitSize);
while (this._reader.int(offset + size, 0) !== 0) {
sizeByteSize <<= 1;
size = this._reader.int(offset - sizeByteSize, this._bitSize);
}
return this._reader.string(offset, size);
}
case 0x06: // indirect int
return this._reader.int(this._offset, this._reader.indirect(this._offset, this._parentBitSize), this._bitSize);
case 0x07: // indirect uint
return this._reader.uint(this._offset, this._reader.indirect(this._offset, this._parentBitSize), this._bitSize);
case 0x08: // indirect float
return this._reader.float(this._reader.indirect(this._offset, this._parentBitSize), this._bitSize);
case 0x09: { // map
const length = this._reader.int(this._reader.indirect(this._offset, this._parentBitSize) - this._byteSize, this._bitSize);
const keysOffset = this._reader.indirect(this._offset, this._parentBitSize) - (this._byteSize * 3);
const keysVectorOffset = this._reader.indirect(keysOffset, this._bitSize);
const keyByteSize = this._reader.int(keysOffset + this._byteSize, this._bitSize);
let keyBitSize;
switch (keyByteSize) {
case 1: keyBitSize = 0; break;
case 2: keyBitSize = 1; break;
case 4: keyBitSize = 2; break;
case 8: keyBitSize = 3; break;
}
const valuesOffset = this._reader.indirect(this._offset, this._parentBitSize);
const obj = {};
for (let i = 0; i < length; i++) {
const keyOffset = keysVectorOffset + (i * keyByteSize);
const keyReference = new flexbuffers.Reference(this._reader, keyOffset, keyBitSize, (0x04 << 2) | keyBitSize);
const key = keyReference.read();
const valueOffset = valuesOffset + (i * this._byteSize);
const packedType = this._reader.uint(valuesOffset + (length * this._byteSize) + i, 0);
const valueReference = new flexbuffers.Reference(this._reader, valueOffset, this._bitSize, packedType);
const value = valueReference.read();
obj[key] = value;
}
return obj;
}
case 0x0a: { // vector
const length = this._reader.int(this._reader.indirect(this._offset, this._parentBitSize) - this._byteSize, this._bitSize);
const arr = new Array(length);
for (let i = 0; i < length; i++) {
const itemsOffset = this._reader.indirect(this._offset, this._parentBitSize);
const itemOffset = itemsOffset + (i * this._byteSize);
const packedType = this._reader.uint(itemsOffset + (length * this._byteSize) + i, 0);
const itemReference = new flexbuffers.Reference(this._reader, itemOffset, this._bitSize, packedType);
arr[i] = itemReference.read();
}
return arr;
}
case 0x0b: // vector int
case 0x0c: // vector uint
case 0x0d: // vector float
case 0x0e: // vector key
case 0x0f: // vector string deprecated
case 0x24: { // vector bool
const length = this._reader.int(this._reader.indirect(this._offset, this._parentBitSize) - this._byteSize, this._bitSize);
const valueType = this._valueType - 0x0b + 0x01;
const packedType = valueType << 2 | 0;
const arr = new Array(length);
for (let i = 0; i < length; i++) {
const itemsOffset = this._reader.indirect(this._offset, this._parentBitSize);
const itemOffset = itemsOffset + (i * this._byteSize);
const itemReference = new flexbuffers.Reference(this._reader, itemOffset, this._bitSize, packedType);
arr[i] = itemReference.read();
}
return arr;
}
case 0x10: // vector int2
case 0x11: // vector uint2
case 0x12: // vector float2
case 0x13: // vector int3
case 0x14: // vector uint3
case 0x15: // vector float3
case 0x16: // vector int4
case 0x17: // vector uint4
case 0x18: { // vector float4
const length = (((this._valueType - 0x10) / 3) >> 0) + 2;
const valueType = ((this._valueType - 0x10) % 3) + 0x01;
const packedType = valueType << 2 | 0;
const arr = new Array(length);
for (let i = 0; i < length; i++) {
const itemsOffset = this._reader.indirect(this._offset, this._parentBitSize);
const itemOffset = itemsOffset + (i * this._byteSize);
const itemReference = new flexbuffers.Reference(this._reader, itemOffset, this._bitSize, packedType);
arr[i] = itemReference.read();
}
return arr;
}
case 0x19: { // blob
const sizeOffset = this._reader.indirect(this._offset, this._parentBitSize) - this._byteSize;
const size = this._reader.int(sizeOffset, this._bitSize);
const offset = this._reader.indirect(this._offset, this._parentBitSize);
return this._reader.bytes(offset, size);
}
case 0x1A: { // bool
return this._reader.int(this._offset, this._parentBitSize) > 0;
}
}
return undefined;
}
};

flexbuffers.BinaryReader = class {

constructor(buffer) {
this._buffer = buffer;
this._length = buffer.length;
this._view = new DataView(buffer.buffer, buffer.byteOffset, buffer.byteLength);
this._utf8Decoder = new TextDecoder('utf-8');
}

get length() {
return this._length;
}

int(offset, size) {
switch (size) {
case 0: return this._view.getInt8(offset);
case 1: return this._view.getInt16(offset, true);
case 2: return this._view.getInt32(offset, true);
case 3: return this._view.getInt64(offset, true);
}
throw new flexbuffers.Error('Invalid int size.');
}

uint(offset, size) {
switch (size) {
case 0: return this._view.getUint8(offset);
case 1: return this._view.getUint16(offset, true);
case 2: return this._view.getUint32(offset, true);
case 3: return this._view.getUint64(offset, true);
}
throw new flexbuffers.Error('Invalid uint size.');
}

float(offset, size) {
switch (size) {
case 2:
return this._view.getFloat32(offset, true);
case 3:
return this._view.getFloat64(offset, true);
}
throw new flexbuffers.Error('Invalid float size.');
}

string(offset, size) {
const bytes = this._buffer.subarray(offset, offset + size);
return this._utf8Decoder.decode(bytes);
}

bytes(offset, size) {
return this._buffer.slice(offset, offset + size);
}

indirect(offset, size) {
return offset - this.uint(offset, size);
}
};

flexbuffers.Error = class extends Error {

constructor(message) {
super(message);
this.name = 'FlexBuffers Error';
this.message = message;
}
};

if (typeof module !== 'undefined' && typeof module.exports === 'object') {
module.exports.ModelFactory = tflite.ModelFactory;
}

0 comments on commit c9bb7de

Please sign in to comment.