Skip to content

Commit

Permalink
[KERAS]Embedding layer (apache#5444)
Browse files Browse the repository at this point in the history
  • Loading branch information
siju-samuel authored and trevor-m committed Jun 18, 2020
1 parent be2f213 commit 98e67db
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 2 deletions.
10 changes: 9 additions & 1 deletion python/tvm/relay/frontend/keras.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,14 @@ def _convert_permute(inexpr, keras_layer, _):
return _op.transpose(inexpr, axes=(0,) + keras_layer.dims)


def _convert_embedding(inexpr, keras_layer, etab):
indices = inexpr
weightList = keras_layer.get_weights()
weight = etab.new_const(weightList[0])
out = _op.take(weight, indices.astype('int32'), axis=0)

return out

def _convert_dense(inexpr, keras_layer, etab):
weightList = keras_layer.get_weights()
weight = etab.new_const(weightList[0].transpose([1, 0]))
Expand Down Expand Up @@ -893,7 +901,7 @@ def _default_skip(inexpr, keras_layer, _): # pylint: disable=unused-argument
'Maximum' : _convert_merge,
'Dot' : _convert_merge,
'Permute' : _convert_permute,
# 'Embedding' : _convert_embedding,
'Embedding' : _convert_embedding,
# 'RepeatVector' : _convert_repeat_vector,

'InputLayer' : _default_skip,
Expand Down
20 changes: 19 additions & 1 deletion tests/python/frontend/keras/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -466,6 +466,24 @@ def test_forward_zero_padding3d(self, keras):
keras_model = keras.models.Model(data, x)
verify_keras_frontend(keras_model, layout='NDHWC')


def test_forward_embedding(self, keras):
data = keras.layers.Input(shape=(2, 4), dtype="int32")
x = keras.layers.Embedding(10, 3)(data)
keras_model = keras.models.Model(data, x)
verify_keras_frontend(keras_model, need_transpose=False)

data = keras.layers.Input(shape=(2, 3, 4), dtype="int32")
x = keras.layers.Embedding(4, 5)(data)
keras_model = keras.models.Model(data, x)
verify_keras_frontend(keras_model, need_transpose=False)

data = keras.layers.Input(shape=(6, 2, 3, 4), dtype="int32")
x = keras.layers.Embedding(4, 5)(data)
keras_model = keras.models.Model(data, x)
verify_keras_frontend(keras_model, need_transpose=False)


if __name__ == '__main__':
for k in [keras, tf_keras]:
sut = TestKeras()
Expand Down Expand Up @@ -497,4 +515,4 @@ def test_forward_zero_padding3d(self, keras):
sut.test_forward_pool3d(keras=k)
sut.test_forward_upsample3d(keras=k)
sut.test_forward_zero_padding3d(keras=k)

sut.test_forward_embedding(keras=k)

0 comments on commit 98e67db

Please sign in to comment.