Skip to content

Commit

Permalink
Change TF graph proto to TB graph proto (#1730)
Browse files Browse the repository at this point in the history
TF 2.0 is removing symbols for GraphDef and MetaGraphDef protos.
We are converting it to TB proto definitions instead.

Relates to #1718.
  • Loading branch information
stephanwlee authored Dec 28, 2018
1 parent 5e484f8 commit 5328bd9
Show file tree
Hide file tree
Showing 7 changed files with 71 additions and 17 deletions.
8 changes: 5 additions & 3 deletions tensorboard/backend/event_processing/event_accumulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@
from tensorboard.compat import tf
from tensorboard.compat.proto import config_pb2
from tensorboard.compat.proto import event_pb2
from tensorboard.compat.proto import graph_pb2
from tensorboard.compat.proto import meta_graph_pb2
from tensorboard.plugins.distribution import compressor
from tensorboard.util import tb_logging

Expand Down Expand Up @@ -353,7 +355,7 @@ def _ProcessEvent(self, event):
if self._graph is None or self._graph_from_metagraph:
# We may have a graph_def in the metagraph. If so, and no
# graph_def is directly available, use this one instead.
meta_graph = tf.compat.v1.MetaGraphDef()
meta_graph = meta_graph_pb2.MetaGraphDef()
meta_graph.ParseFromString(self._meta_graph)
if meta_graph.graph_def:
if self._graph is not None:
Expand Down Expand Up @@ -447,7 +449,7 @@ def Graph(self):
Returns:
The `graph_def` proto.
"""
graph = tf.compat.v1.GraphDef()
graph = graph_pb2.GraphDef()
if self._graph is not None:
graph.ParseFromString(self._graph)
return graph
Expand All @@ -464,7 +466,7 @@ def MetaGraph(self):
"""
if self._meta_graph is None:
raise ValueError('There is no metagraph in this EventAccumulator')
meta_graph = tf.compat.v1.MetaGraphDef()
meta_graph = meta_graph_pb2.MetaGraphDef()
meta_graph.ParseFromString(self._meta_graph)
return meta_graph

Expand Down
22 changes: 18 additions & 4 deletions tensorboard/backend/event_processing/event_accumulator_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@
from tensorboard.backend.event_processing import event_accumulator as ea
from tensorboard.compat.proto import config_pb2
from tensorboard.compat.proto import event_pb2
from tensorboard.compat.proto import graph_pb2
from tensorboard.compat.proto import meta_graph_pb2
from tensorboard.compat.proto import summary_pb2
from tensorboard.plugins.distribution import compressor
from tensorboard.util import tensor_util
Expand Down Expand Up @@ -829,8 +831,14 @@ def FakeScalarSummary(tag, value):
self.assertEqual(i * 5, sq_events[i].step)
self.assertEqual(i, id_events[i].value)
self.assertEqual(i * i, sq_events[i].value)
self.assertProtoEquals(graph.as_graph_def(add_shapes=True), acc.Graph())
self.assertProtoEquals(meta_graph_def, acc.MetaGraph())

expected_graph_def = graph_pb2.GraphDef.FromString(
graph.as_graph_def(add_shapes=True).SerializeToString())
self.assertProtoEquals(expected_graph_def, acc.Graph())

expected_meta_graph = meta_graph_pb2.MetaGraphDef.FromString(
meta_graph_def.SerializeToString())
self.assertProtoEquals(expected_meta_graph, acc.MetaGraph())

def testGraphFromMetaGraphBecomesAvailable(self):
"""Test accumulator by writing values and then reading them."""
Expand Down Expand Up @@ -858,8 +866,14 @@ def testGraphFromMetaGraphBecomesAvailable(self):
ea.GRAPH: True,
ea.META_GRAPH: True,
})
self.assertProtoEquals(graph.as_graph_def(add_shapes=True), acc.Graph())
self.assertProtoEquals(meta_graph_def, acc.MetaGraph())

expected_graph_def = graph_pb2.GraphDef.FromString(
graph.as_graph_def(add_shapes=True).SerializeToString())
self.assertProtoEquals(expected_graph_def, acc.Graph())

expected_meta_graph = meta_graph_pb2.MetaGraphDef.FromString(
meta_graph_def.SerializeToString())
self.assertProtoEquals(expected_meta_graph, acc.MetaGraph())

def _writeMetadata(self, logdir, summary_metadata, nonce=''):
"""Write to disk a summary with the given metadata.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@
from tensorboard.compat import tf
from tensorboard.compat.proto import config_pb2
from tensorboard.compat.proto import event_pb2
from tensorboard.compat.proto import graph_pb2
from tensorboard.compat.proto import meta_graph_pb2
from tensorboard.util import tb_logging


Expand Down Expand Up @@ -303,7 +305,7 @@ def _ProcessEvent(self, event):
if self._graph is None or self._graph_from_metagraph:
# We may have a graph_def in the metagraph. If so, and no
# graph_def is directly available, use this one instead.
meta_graph = tf.compat.v1.MetaGraphDef()
meta_graph = meta_graph_pb2.MetaGraphDef()
meta_graph.ParseFromString(self._meta_graph)
if meta_graph.graph_def:
if self._graph is not None:
Expand Down Expand Up @@ -380,7 +382,7 @@ def Graph(self):
Returns:
The `graph_def` proto.
"""
graph = tf.compat.v1.GraphDef()
graph = graph_pb2.GraphDef()
if self._graph is not None:
graph.ParseFromString(self._graph)
return graph
Expand All @@ -397,7 +399,7 @@ def MetaGraph(self):
"""
if self._meta_graph is None:
raise ValueError('There is no metagraph in this EventAccumulator')
meta_graph = tf.compat.v1.MetaGraphDef()
meta_graph = meta_graph_pb2.MetaGraphDef()
meta_graph.ParseFromString(self._meta_graph)
return meta_graph

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@
from tensorboard.backend.event_processing import plugin_event_accumulator as ea
from tensorboard.compat.proto import config_pb2
from tensorboard.compat.proto import event_pb2
from tensorboard.compat.proto import graph_pb2
from tensorboard.compat.proto import meta_graph_pb2
from tensorboard.compat.proto import summary_pb2
from tensorboard.plugins.audio import summary as audio_summary
from tensorboard.plugins.image import summary as image_summary
Expand Down Expand Up @@ -620,8 +622,14 @@ def FakeScalarSummary(tag, value):
self.assertEqual(i * 5, sq_events[i].step)
self.assertEqual(i, tensor_util.make_ndarray(id_events[i].tensor_proto).item())
self.assertEqual(i * i, tensor_util.make_ndarray(sq_events[i].tensor_proto).item())
self.assertProtoEquals(graph.as_graph_def(add_shapes=True), acc.Graph())
self.assertProtoEquals(meta_graph_def, acc.MetaGraph())

expected_graph_def = graph_pb2.GraphDef.FromString(
graph.as_graph_def(add_shapes=True).SerializeToString())
self.assertProtoEquals(expected_graph_def, acc.Graph())

expected_meta_graph = meta_graph_pb2.MetaGraphDef.FromString(
meta_graph_def.SerializeToString())
self.assertProtoEquals(expected_meta_graph, acc.MetaGraph())

def testGraphFromMetaGraphBecomesAvailable(self):
"""Test accumulator by writing values and then reading them."""
Expand Down Expand Up @@ -649,8 +657,14 @@ def testGraphFromMetaGraphBecomesAvailable(self):
ea.GRAPH: True,
ea.META_GRAPH: True,
})
self.assertProtoEquals(graph.as_graph_def(add_shapes=True), acc.Graph())
self.assertProtoEquals(meta_graph_def, acc.MetaGraph())

expected_graph_def = graph_pb2.GraphDef.FromString(
graph.as_graph_def(add_shapes=True).SerializeToString())
self.assertProtoEquals(expected_graph_def, acc.Graph())

expected_meta_graph = meta_graph_pb2.MetaGraphDef.FromString(
meta_graph_def.SerializeToString())
self.assertProtoEquals(expected_meta_graph, acc.MetaGraph())

def _writeMetadata(self, logdir, summary_metadata, nonce=''):
"""Write to disk a summary with the given metadata.
Expand Down
1 change: 1 addition & 0 deletions tensorboard/plugins/core/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ py_test(
"//tensorboard:expect_tensorflow_installed",
"//tensorboard/backend:application",
"//tensorboard/backend/event_processing:event_multiplexer",
"//tensorboard/compat/proto:protos_all_py_pb2",
"//tensorboard/plugins:base_plugin",
"//tensorboard/util:test_util",
"@org_pocoo_werkzeug",
Expand Down
8 changes: 5 additions & 3 deletions tensorboard/plugins/core/core_plugin_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@

from tensorboard.backend import application
from tensorboard.backend.event_processing import plugin_event_multiplexer as event_multiplexer # pylint: disable=line-too-long
from tensorboard.compat.proto import graph_pb2
from tensorboard.compat.proto import meta_graph_pb2
from tensorboard.plugins import base_plugin
from tensorboard.plugins.core import core_plugin
from tensorboard.util import test_util
Expand Down Expand Up @@ -345,19 +347,19 @@ def _generate_test_data(self, run_name, experiment_name):
with test_util.FileWriterCache.get(run_path) as writer:

# Add a simple graph event.
graph_def = tf.compat.v1.GraphDef()
graph_def = graph_pb2.GraphDef()
node1 = graph_def.node.add()
node1.name = 'a'
node2 = graph_def.node.add()
node2.name = 'b'
node2.attr['very_large_attr'].s = b'a' * 2048 # 2 KB attribute

meta_graph_def = tf.compat.v1.MetaGraphDef(graph_def=graph_def)
meta_graph_def = meta_graph_pb2.MetaGraphDef(graph_def=graph_def)

if self._only_use_meta_graph:
writer.add_meta_graph(meta_graph_def)
else:
writer.add_graph(graph_def)
writer.add_graph(graph=None, graph_def=graph_def)

# Write data for the run to the database.
# TODO(nickfelt): Figure out why reseting the graph is necessary.
Expand Down
19 changes: 19 additions & 0 deletions tensorboard/util/test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@

from tensorboard import db
from tensorboard.compat.proto import event_pb2
from tensorboard.compat.proto import graph_pb2
from tensorboard.compat.proto import meta_graph_pb2
from tensorboard.compat.proto import summary_pb2
from tensorboard.util import tb_logging
from tensorboard.util import util
Expand Down Expand Up @@ -217,6 +219,23 @@ def add_session_log(self, session_log, global_step=None):
tf_session_log = session_log
super(FileWriter, self).add_session_log(tf_session_log, global_step)

def add_graph(self, graph, global_step=None, graph_def=None):
if isinstance(graph_def, graph_pb2.GraphDef):
tf_graph_def = tf.compat.v1.GraphDef.FromString(graph_def.SerializeToString())
else:
tf_graph_def = graph_def

super(FileWriter, self).add_graph(graph, global_step=global_step, graph_def=tf_graph_def)

def add_meta_graph(self, meta_graph_def, global_step=None):
if isinstance(meta_graph_def, meta_graph_pb2.MetaGraphDef):
tf_meta_graph_def = tf.compat.v1.MetaGraphDef.FromString(meta_graph_def.SerializeToString())
else:
tf_meta_graph_def = meta_graph_def

super(FileWriter, self).add_meta_graph(meta_graph_def=tf_meta_graph_def, global_step=global_step)


class FileWriterCache(object):
"""Cache for TensorBoard test file writers.
"""
Expand Down

0 comments on commit 5328bd9

Please sign in to comment.