From 5328bd919772853264b892c8eb8fb66f67744a52 Mon Sep 17 00:00:00 2001 From: Stephan Lee Date: Fri, 28 Dec 2018 15:43:10 -0800 Subject: [PATCH] Change TF graph proto to TB graph proto (#1730) TF 2.0 is removing symbols for GraphDef and MetaGraphDef protos. We are converting it to TB proto definitions instead. Relates to #1718. --- .../event_processing/event_accumulator.py | 8 ++++--- .../event_accumulator_test.py | 22 +++++++++++++++---- .../plugin_event_accumulator.py | 8 ++++--- .../plugin_event_accumulator_test.py | 22 +++++++++++++++---- tensorboard/plugins/core/BUILD | 1 + tensorboard/plugins/core/core_plugin_test.py | 8 ++++--- tensorboard/util/test_util.py | 19 ++++++++++++++++ 7 files changed, 71 insertions(+), 17 deletions(-) diff --git a/tensorboard/backend/event_processing/event_accumulator.py b/tensorboard/backend/event_processing/event_accumulator.py index 600cf4c69d7..f2c89e23561 100644 --- a/tensorboard/backend/event_processing/event_accumulator.py +++ b/tensorboard/backend/event_processing/event_accumulator.py @@ -28,6 +28,8 @@ from tensorboard.compat import tf from tensorboard.compat.proto import config_pb2 from tensorboard.compat.proto import event_pb2 +from tensorboard.compat.proto import graph_pb2 +from tensorboard.compat.proto import meta_graph_pb2 from tensorboard.plugins.distribution import compressor from tensorboard.util import tb_logging @@ -353,7 +355,7 @@ def _ProcessEvent(self, event): if self._graph is None or self._graph_from_metagraph: # We may have a graph_def in the metagraph. If so, and no # graph_def is directly available, use this one instead. - meta_graph = tf.compat.v1.MetaGraphDef() + meta_graph = meta_graph_pb2.MetaGraphDef() meta_graph.ParseFromString(self._meta_graph) if meta_graph.graph_def: if self._graph is not None: @@ -447,7 +449,7 @@ def Graph(self): Returns: The `graph_def` proto. """ - graph = tf.compat.v1.GraphDef() + graph = graph_pb2.GraphDef() if self._graph is not None: graph.ParseFromString(self._graph) return graph @@ -464,7 +466,7 @@ def MetaGraph(self): """ if self._meta_graph is None: raise ValueError('There is no metagraph in this EventAccumulator') - meta_graph = tf.compat.v1.MetaGraphDef() + meta_graph = meta_graph_pb2.MetaGraphDef() meta_graph.ParseFromString(self._meta_graph) return meta_graph diff --git a/tensorboard/backend/event_processing/event_accumulator_test.py b/tensorboard/backend/event_processing/event_accumulator_test.py index e03f2acb6d8..7a857e78e8f 100644 --- a/tensorboard/backend/event_processing/event_accumulator_test.py +++ b/tensorboard/backend/event_processing/event_accumulator_test.py @@ -27,6 +27,8 @@ from tensorboard.backend.event_processing import event_accumulator as ea from tensorboard.compat.proto import config_pb2 from tensorboard.compat.proto import event_pb2 +from tensorboard.compat.proto import graph_pb2 +from tensorboard.compat.proto import meta_graph_pb2 from tensorboard.compat.proto import summary_pb2 from tensorboard.plugins.distribution import compressor from tensorboard.util import tensor_util @@ -829,8 +831,14 @@ def FakeScalarSummary(tag, value): self.assertEqual(i * 5, sq_events[i].step) self.assertEqual(i, id_events[i].value) self.assertEqual(i * i, sq_events[i].value) - self.assertProtoEquals(graph.as_graph_def(add_shapes=True), acc.Graph()) - self.assertProtoEquals(meta_graph_def, acc.MetaGraph()) + + expected_graph_def = graph_pb2.GraphDef.FromString( + graph.as_graph_def(add_shapes=True).SerializeToString()) + self.assertProtoEquals(expected_graph_def, acc.Graph()) + + expected_meta_graph = meta_graph_pb2.MetaGraphDef.FromString( + meta_graph_def.SerializeToString()) + self.assertProtoEquals(expected_meta_graph, acc.MetaGraph()) def testGraphFromMetaGraphBecomesAvailable(self): """Test accumulator by writing values and then reading them.""" @@ -858,8 +866,14 @@ def testGraphFromMetaGraphBecomesAvailable(self): ea.GRAPH: True, ea.META_GRAPH: True, }) - self.assertProtoEquals(graph.as_graph_def(add_shapes=True), acc.Graph()) - self.assertProtoEquals(meta_graph_def, acc.MetaGraph()) + + expected_graph_def = graph_pb2.GraphDef.FromString( + graph.as_graph_def(add_shapes=True).SerializeToString()) + self.assertProtoEquals(expected_graph_def, acc.Graph()) + + expected_meta_graph = meta_graph_pb2.MetaGraphDef.FromString( + meta_graph_def.SerializeToString()) + self.assertProtoEquals(expected_meta_graph, acc.MetaGraph()) def _writeMetadata(self, logdir, summary_metadata, nonce=''): """Write to disk a summary with the given metadata. diff --git a/tensorboard/backend/event_processing/plugin_event_accumulator.py b/tensorboard/backend/event_processing/plugin_event_accumulator.py index e554e7e593a..08c5461dad8 100644 --- a/tensorboard/backend/event_processing/plugin_event_accumulator.py +++ b/tensorboard/backend/event_processing/plugin_event_accumulator.py @@ -31,6 +31,8 @@ from tensorboard.compat import tf from tensorboard.compat.proto import config_pb2 from tensorboard.compat.proto import event_pb2 +from tensorboard.compat.proto import graph_pb2 +from tensorboard.compat.proto import meta_graph_pb2 from tensorboard.util import tb_logging @@ -303,7 +305,7 @@ def _ProcessEvent(self, event): if self._graph is None or self._graph_from_metagraph: # We may have a graph_def in the metagraph. If so, and no # graph_def is directly available, use this one instead. - meta_graph = tf.compat.v1.MetaGraphDef() + meta_graph = meta_graph_pb2.MetaGraphDef() meta_graph.ParseFromString(self._meta_graph) if meta_graph.graph_def: if self._graph is not None: @@ -380,7 +382,7 @@ def Graph(self): Returns: The `graph_def` proto. """ - graph = tf.compat.v1.GraphDef() + graph = graph_pb2.GraphDef() if self._graph is not None: graph.ParseFromString(self._graph) return graph @@ -397,7 +399,7 @@ def MetaGraph(self): """ if self._meta_graph is None: raise ValueError('There is no metagraph in this EventAccumulator') - meta_graph = tf.compat.v1.MetaGraphDef() + meta_graph = meta_graph_pb2.MetaGraphDef() meta_graph.ParseFromString(self._meta_graph) return meta_graph diff --git a/tensorboard/backend/event_processing/plugin_event_accumulator_test.py b/tensorboard/backend/event_processing/plugin_event_accumulator_test.py index ad398dc59e6..20b388281e4 100644 --- a/tensorboard/backend/event_processing/plugin_event_accumulator_test.py +++ b/tensorboard/backend/event_processing/plugin_event_accumulator_test.py @@ -27,6 +27,8 @@ from tensorboard.backend.event_processing import plugin_event_accumulator as ea from tensorboard.compat.proto import config_pb2 from tensorboard.compat.proto import event_pb2 +from tensorboard.compat.proto import graph_pb2 +from tensorboard.compat.proto import meta_graph_pb2 from tensorboard.compat.proto import summary_pb2 from tensorboard.plugins.audio import summary as audio_summary from tensorboard.plugins.image import summary as image_summary @@ -620,8 +622,14 @@ def FakeScalarSummary(tag, value): self.assertEqual(i * 5, sq_events[i].step) self.assertEqual(i, tensor_util.make_ndarray(id_events[i].tensor_proto).item()) self.assertEqual(i * i, tensor_util.make_ndarray(sq_events[i].tensor_proto).item()) - self.assertProtoEquals(graph.as_graph_def(add_shapes=True), acc.Graph()) - self.assertProtoEquals(meta_graph_def, acc.MetaGraph()) + + expected_graph_def = graph_pb2.GraphDef.FromString( + graph.as_graph_def(add_shapes=True).SerializeToString()) + self.assertProtoEquals(expected_graph_def, acc.Graph()) + + expected_meta_graph = meta_graph_pb2.MetaGraphDef.FromString( + meta_graph_def.SerializeToString()) + self.assertProtoEquals(expected_meta_graph, acc.MetaGraph()) def testGraphFromMetaGraphBecomesAvailable(self): """Test accumulator by writing values and then reading them.""" @@ -649,8 +657,14 @@ def testGraphFromMetaGraphBecomesAvailable(self): ea.GRAPH: True, ea.META_GRAPH: True, }) - self.assertProtoEquals(graph.as_graph_def(add_shapes=True), acc.Graph()) - self.assertProtoEquals(meta_graph_def, acc.MetaGraph()) + + expected_graph_def = graph_pb2.GraphDef.FromString( + graph.as_graph_def(add_shapes=True).SerializeToString()) + self.assertProtoEquals(expected_graph_def, acc.Graph()) + + expected_meta_graph = meta_graph_pb2.MetaGraphDef.FromString( + meta_graph_def.SerializeToString()) + self.assertProtoEquals(expected_meta_graph, acc.MetaGraph()) def _writeMetadata(self, logdir, summary_metadata, nonce=''): """Write to disk a summary with the given metadata. diff --git a/tensorboard/plugins/core/BUILD b/tensorboard/plugins/core/BUILD index a954ce1afa7..39e5dbb5067 100644 --- a/tensorboard/plugins/core/BUILD +++ b/tensorboard/plugins/core/BUILD @@ -35,6 +35,7 @@ py_test( "//tensorboard:expect_tensorflow_installed", "//tensorboard/backend:application", "//tensorboard/backend/event_processing:event_multiplexer", + "//tensorboard/compat/proto:protos_all_py_pb2", "//tensorboard/plugins:base_plugin", "//tensorboard/util:test_util", "@org_pocoo_werkzeug", diff --git a/tensorboard/plugins/core/core_plugin_test.py b/tensorboard/plugins/core/core_plugin_test.py index f1246acb0e7..1907a0f6227 100644 --- a/tensorboard/plugins/core/core_plugin_test.py +++ b/tensorboard/plugins/core/core_plugin_test.py @@ -34,6 +34,8 @@ from tensorboard.backend import application from tensorboard.backend.event_processing import plugin_event_multiplexer as event_multiplexer # pylint: disable=line-too-long +from tensorboard.compat.proto import graph_pb2 +from tensorboard.compat.proto import meta_graph_pb2 from tensorboard.plugins import base_plugin from tensorboard.plugins.core import core_plugin from tensorboard.util import test_util @@ -345,19 +347,19 @@ def _generate_test_data(self, run_name, experiment_name): with test_util.FileWriterCache.get(run_path) as writer: # Add a simple graph event. - graph_def = tf.compat.v1.GraphDef() + graph_def = graph_pb2.GraphDef() node1 = graph_def.node.add() node1.name = 'a' node2 = graph_def.node.add() node2.name = 'b' node2.attr['very_large_attr'].s = b'a' * 2048 # 2 KB attribute - meta_graph_def = tf.compat.v1.MetaGraphDef(graph_def=graph_def) + meta_graph_def = meta_graph_pb2.MetaGraphDef(graph_def=graph_def) if self._only_use_meta_graph: writer.add_meta_graph(meta_graph_def) else: - writer.add_graph(graph_def) + writer.add_graph(graph=None, graph_def=graph_def) # Write data for the run to the database. # TODO(nickfelt): Figure out why reseting the graph is necessary. diff --git a/tensorboard/util/test_util.py b/tensorboard/util/test_util.py index f4562d7106f..36fd70545ae 100644 --- a/tensorboard/util/test_util.py +++ b/tensorboard/util/test_util.py @@ -33,6 +33,8 @@ from tensorboard import db from tensorboard.compat.proto import event_pb2 +from tensorboard.compat.proto import graph_pb2 +from tensorboard.compat.proto import meta_graph_pb2 from tensorboard.compat.proto import summary_pb2 from tensorboard.util import tb_logging from tensorboard.util import util @@ -217,6 +219,23 @@ def add_session_log(self, session_log, global_step=None): tf_session_log = session_log super(FileWriter, self).add_session_log(tf_session_log, global_step) + def add_graph(self, graph, global_step=None, graph_def=None): + if isinstance(graph_def, graph_pb2.GraphDef): + tf_graph_def = tf.compat.v1.GraphDef.FromString(graph_def.SerializeToString()) + else: + tf_graph_def = graph_def + + super(FileWriter, self).add_graph(graph, global_step=global_step, graph_def=tf_graph_def) + + def add_meta_graph(self, meta_graph_def, global_step=None): + if isinstance(meta_graph_def, meta_graph_pb2.MetaGraphDef): + tf_meta_graph_def = tf.compat.v1.MetaGraphDef.FromString(meta_graph_def.SerializeToString()) + else: + tf_meta_graph_def = meta_graph_def + + super(FileWriter, self).add_meta_graph(meta_graph_def=tf_meta_graph_def, global_step=global_step) + + class FileWriterCache(object): """Cache for TensorBoard test file writers. """