Skip to content

Commit

Permalink
[Pass] Check in infershape, move indexedgraph to graph.h (#15)
Browse files Browse the repository at this point in the history
  • Loading branch information
tqchen committed May 29, 2018
1 parent 5629330 commit 9135bc0
Show file tree
Hide file tree
Showing 9 changed files with 307 additions and 112 deletions.
149 changes: 149 additions & 0 deletions nnvm/include/nnvm/graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@

namespace nnvm {

class IndexedGraph;

/*!
* \brief Symbolic computation graph.
* This is the intermediate representation for optimization pass.
Expand All @@ -32,6 +34,145 @@ class Graph {
* and can be shared across multiple Instance of graph
*/
std::unordered_map<std::string, std::shared_ptr<const any> > attrs;
/*!
* \brief Get the attribute from attrs.
* \param attr_name the name of the attribute
* \return the reference to corresponding attribute
* \tparam T the type of the attribute.
*/
template<typename T>
inline const T& GetAttr(const std::string& attr_name);
/*!
* \brief get a indexed graph of current graph, if not exist, create it on demand
* \return The indexed graph.
* \sa IndexedGraph
*/
const IndexedGraph& indexed_graph();

private:
// internal structure of indexed graph
std::shared_ptr<const IndexedGraph> indexed_graph_;
};

/*!
* \brief Auxililary data structure to index a graph.
* It maps Nodes in the graph to consecutive integers node_id.
* It also maps IndexedGraph::NodeEntry to consecutive integer entry_id.
* This allows storing properties of Node and NodeEntry into
* compact vector and quickly access them without resorting to hashmap.
*
* The node_id and entry_rptr are the same as the JSON graph produced by SaveJSON Pass.
*/
class IndexedGraph {
public:
/*! \brief represents a data in the graph */
struct NodeEntry {
/*! \brief the source node id in the computation graph */
uint32_t node_id;
/*! \brief index of output from the source. */
uint32_t index;
/*!
* \brief compare equality
* \param other the other entry to compare
* \return whether two entries equals to each other
*/
inline bool operator==(const NodeEntry& other) const {
return node_id == other.node_id && index == other.index;
}
};
/*! \brief Node data structure in IndexedGraph */
struct Node {
/*! \brief pointer to the source node */
const nnvm::Node* source;
/*! \brief inputs to the node */
array_view<NodeEntry> inputs;
/*! \brief control flow dependencies to the node */
array_view<uint32_t> control_deps;
};
/*! \return number of nodes in the graph */
inline size_t num_nodes() const {
return nodes_.size();
}
/*! \return total number of NodeEntry in the graph */
inline size_t num_node_entries() const {
return entry_rptr_.back();
}
/*!
* \brief Get a unique entry id between 0 to num_node_entries()
* for a given IndexedGraph::NodeEntry
* \param node_id The node index
* \param index the output index
* \return the unique index.
*/
inline uint32_t entry_id(uint32_t node_id, uint32_t index) const {
return entry_rptr_[node_id] + index;
}
/*!
* \brief Get a unique entry id between 0 to num_node_entries()
* for a given IndexedGraph::NodeEntry
* \param e The entry to query for index.
* \return the unique index.
*/
inline uint32_t entry_id(const NodeEntry& e) const {
return entry_rptr_[e.node_id] + e.index;
}
/*!
* \brief Get a unique entry id between 0 to num_node_entries()
* for a given NodeEntry.
* \param e The entry to query for index.
* \return the unique index.
*/
inline uint32_t entry_id(const nnvm::NodeEntry& e) const {
return entry_rptr_[node_id(e.node.get())] + e.index;
}
/*!
* \brief Get the corresponding node id for a given Node in the IndexedGraph.
* \param node The Node to query for index.
* \return the node index.
*/
inline uint32_t node_id(const nnvm::Node* node) const {
return node2index_.at(node);
}
/*!
* \brief Get the corresponding Node structure for a given node_id.
* \param node_id The node id
* \return const reference to the corresponding IndexedGraph::Node
*/
inline const Node& operator[](uint32_t node_id) const {
return nodes_[node_id];
}
/*!
* \brief Get the corresponding Node structure
* \param node The pointer to the Node structure
* \return const reference to the corresponding IndexedGraph::Node
*/
inline const Node& operator[](const nnvm::Node* node) const {
return nodes_[node_id(node)];
}
/*! \return list of argument nodes */
inline const std::vector<uint32_t>& arg_nodes() const {
return arg_nodes_;
}

private:
friend class Graph;
/*!
* \brief Constructor an IndexedGraph from normal Graph
* \param other The source graph.
*/
explicit IndexedGraph(const Graph& other);
// node pointers in CSR structure.
std::vector<Node> nodes_;
// index to argument nodes
std::vector<uint32_t> arg_nodes_;
// mapping from node to index.
std::unordered_map<const nnvm::Node*, uint32_t> node2index_;
// CSR pointer of node entries
std::vector<size_t> entry_rptr_;
// space to store input entries of each
std::vector<NodeEntry> input_entries_;
// control flow dependencies
std::vector<uint32_t> control_deps_;
};

/*!
Expand All @@ -45,6 +186,14 @@ template<typename FVisit>
inline void DFSVisit(const std::vector<NodeEntry>& heads, FVisit fvisit);

// inline function implementations
template<typename T>
inline const T& Graph::GetAttr(const std::string& attr_name) {
auto it = attrs.find(attr_name);
CHECK(it != attrs.end())
<< "Cannot find attribute " << attr_name << " in the graph";
return nnvm::get<T>(*it->second);
}

template <typename GNode, typename HashType,
typename FVisit, typename HashFunc,
typename InDegree, typename GetInput>
Expand Down
133 changes: 25 additions & 108 deletions nnvm/include/nnvm/graph_attr_types.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,120 +7,37 @@
#define NNVM_GRAPH_ATTR_TYPES_H_

#include <vector>
#include <unordered_map>
#include "./graph.h"
#include <string>
#include "./tuple.h"

namespace nnvm {

/*!
* \brief Auxililary data structure to index a graph.
* It maps Nodes in the graph to consecutive integers node_id.
* It also maps IndexedGraph::NodeEntry to consecutive integer entry_id.
* This allows storing properties of Node and NodeEntry into
* compact vector and quickly access them without resorting to hashmap.
* \brief The result holder of JSON serializer
*
* \note Stored under ret.attrs["json"], provided by Pass "SaveJSON"
* \code
* Graph ret = ApplyPass(src_graph, {"SaveJSON"});
* const JSONString& json = ret.GetAttr<JSONString>("shape");
* \endcode
*/
struct IndexedGraph {
public:
/*! \brief represents a data in the graph */
struct NodeEntry {
/*! \brief the source node id in the computation graph */
uint32_t node_id;
/*! \brief index of output from the source. */
uint32_t index;
/*!
* \brief compare equality
* \param other the other entry to compare
* \return whether two entries equals to each other
*/
inline bool operator==(const NodeEntry& other) const {
return node_id == other.node_id && index == other.index;
}
};
/*! \brief Node data structure in IndexedGraph */
struct Node {
/*! \brief pointer to the source node */
const nnvm::Node* source;
/*! \brief inputs to the node */
array_view<NodeEntry> inputs;
/*! \brief control flow dependencies to the node */
array_view<uint32_t> control_deps;
};
/*! \return number of nodes in the graph */
inline size_t num_nodes() const {
return nodes_.size();
}
/*! \return total number of NodeEntry in the graph */
inline size_t num_node_entries() const {
return entry_rptr_.back();
}
/*!
* \brief Get a unique entry id between 0 to num_node_entries()
* for a given IndexedGraph::NodeEntry
* \param e The entry to query for index.
* \return the unique index.
*/
inline uint32_t entry_id(const NodeEntry& e) const {
return entry_rptr_[e.node_id] + e.index;
}
/*!
* \brief Get a unique entry id between 0 to num_node_entries()
* for a given NodeEntry.
* \param e The entry to query for index.
* \return the unique index.
*/
inline uint32_t entry_id(const nnvm::NodeEntry& e) const {
return entry_rptr_[node_id(e.node.get())] + e.index;
}
/*!
* \brief Get the corresponding node id for a given Node in the IndexedGraph.
* \param node The Node to query for index.
* \return the node index.
*/
inline uint32_t node_id(const nnvm::Node* node) const {
return node2index_.at(node);
}
/*!
* \brief Get the corresponding Node structure for a given node_id.
* \param node_id The node id
* \return const reference to the corresponding IndexedGraph::Node
*/
inline const Node& operator[](uint32_t node_id) const {
return nodes_[node_id];
}
/*!
* \brief Get the corresponding Node structure
* \param node The pointer to the Node structure
* \return const reference to the corresponding IndexedGraph::Node
*/
inline const Node& operator[](const nnvm::Node* node) const {
return nodes_[node_id(node)];
}
/*! \return list of argument nodes */
inline const std::vector<uint32_t>& arg_nodes() const {
return arg_nodes_;
}
/*!
* \brief Constructor an IndexedGraph from normal Graph
* \param other The source graph.
*/
explicit IndexedGraph(const Graph& other);
// disallow copy assign
IndexedGraph(const IndexedGraph& other) = delete;
using JSONString = std::string;

private:
// node pointers in CSR structure.
std::vector<Node> nodes_;
// index to argument nodes
std::vector<uint32_t> arg_nodes_;
// mapping from node to index.
std::unordered_map<const nnvm::Node*, uint32_t> node2index_;
// CSR pointer of node entries
std::vector<size_t> entry_rptr_;
// space to store input entries of each
std::vector<NodeEntry> input_entries_;
// control flow dependencies
std::vector<uint32_t> control_deps_;
};
/*!
* \brief The result holder of shape of each NodeEntry in the graph.
* \note Stored under graph.attrs["shape"], provided by Pass "InferShape"
*
* \code
* Graph g = ApplyPass(src_graph, {"InferShape"});
* const ShapeVector& shapes = g.GetAttr<ShapeVector>("shape");
* // get shape by entry id
* TShape entry_shape = shapes[g.indexed_graph().entry_id(my_entry)];
* \endcode
*
* \sa FInferShape
*/
using ShapeVector = std::vector<TShape>;

} // namespace nnvm

Expand Down
23 changes: 23 additions & 0 deletions nnvm/include/nnvm/op_attr_types.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
#include <vector>
#include <string>
#include <functional>
#include "./base.h"
#include "./tuple.h"

namespace nnvm {

Expand Down Expand Up @@ -39,6 +41,7 @@ using FListOutputNames = std::function<std::vector<std::string> (const NodeAttrs

/*!
* \brief Check whether operator will mutate k-th input.
* \param attrs The attributes of the node.
* \param index The input index
* \return Whether this operator will mutate index-th input.
*
Expand All @@ -47,6 +50,26 @@ using FListOutputNames = std::function<std::vector<std::string> (const NodeAttrs
*/
using FMutateInput = std::function<bool (const NodeAttrs& attrs, uint32_t index)>;

/*!
* \brief Shape inference function.
* Update the shapes given the input shape information.
* TShape.ndim() == 0 means the shape is still unknown.
*
* \param attrs The attributes of the node.
* \param in_shapes Array of shapes from the inputs.
* \param out_shapes Array of shapes from the outputs.
*
* \return Whether all the shapes are known.
*
* \note Register under "FInferShape",
* by default do not update any shapes.
*
* FInferShape is needed by shape inference
*/
using FInferShape = std::function<bool (const NodeAttrs& attrs,
array_view<TShape*> in_shapes,
array_view<TShape*> out_shapes)>;

} // namespace nnvm

#endif // NNVM_OP_ATTR_TYPES_H_
Loading

0 comments on commit 9135bc0

Please sign in to comment.