Skip to content

Commit

Permalink
removes state vector-wrapper in favor of vanilla array
Browse files Browse the repository at this point in the history
  • Loading branch information
elstehle committed Jul 13, 2022
1 parent 485a1c6 commit 5f1c4b5
Showing 1 changed file with 51 additions and 101 deletions.
152 changes: 51 additions & 101 deletions cpp/src/io/fst/agent_dfa.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -21,54 +21,30 @@

namespace cudf::io::fst::detail {

//-----------------------------------------------------------------------------
// STATE VECTOR
//-----------------------------------------------------------------------------
/**
* @brief A vector is able to hold multiple state indices (e.g., to represent multiple DFA
* instances, where the i-th item would represent the i-th DFA instance).
*
* @tparam StateIndexT Signed or unsigned type used to index items inside the vector
* @tparam NUM_ITEMS The number of items to be allocated for a vector
*/
template <typename StateIndexT, int32_t NUM_ITEMS>
class MultiItemStateVector {
public:
template <typename IndexT>
__host__ __device__ __forceinline__ void Set(IndexT index, StateIndexT value) noexcept
{
state_[index] = value;
}

template <typename IndexT>
__host__ __device__ __forceinline__ StateIndexT Get(IndexT index) const noexcept
{
return state_[index];
}

private:
StateIndexT state_[NUM_ITEMS];
};

//-----------------------------------------------------------------------------
// DFA-SIMULATION STATE COMPOSITION FUNCTORS
//-----------------------------------------------------------------------------
/**
* @brief Implements an associative composition operation for state transition vectors and
* offset-to-overap vectors to be used with a prefix scan.
*
* Read the following table as follows: c = op(l,r), where op is the composition operator.
* For row 0: l maps 0 to 2. r maps 2 to 2. Hence, the result for 0 is 2.
* For row 1: l maps 1 to 1. r maps 1 to 2. Hence, the result for 1 is 2.
* For row 2: l maps 2 to 0. r maps 0 to 1. Hence, the result for 2 is 1.
*
* l r = c ( s->l->r)
* 0: [2] [1] [2] (i.e. 0->2->2)
* 1: [1] [2] [2] (i.e. 1->1->2)
* 2: [0] [2] [1] (i.e. 2->0->2)
* 2: [0] [2] [1] (i.e. 2->0->1)
* @tparam NUM_ITEMS The number of items stored within a vector
*/
template <int32_t NUM_ITEMS>
struct VectorCompositeOp {
template <typename VectorT>
__host__ __device__ __forceinline__ VectorT operator()(VectorT const& lhs, VectorT const& rhs)
{
VectorT res;
VectorT res{};
for (int32_t i = 0; i < NUM_ITEMS; ++i) {
res.Set(i, rhs.Get(lhs.Get(i)));
}
Expand All @@ -95,16 +71,16 @@ class DFASimulationCallbackWrapper {
if (!write) out_count = 0;
}

template <typename CharIndexT, typename StateVectorT, typename SymbolIndexT>
__host__ __device__ __forceinline__ void ReadSymbol(CharIndexT const& character_index,
StateVectorT const& old_state,
StateVectorT const& new_state,
SymbolIndexT const& symbol_id)
template <typename CharIndexT, typename StateIndexT, typename SymbolIndexT>
__host__ __device__ __forceinline__ void ReadSymbol(CharIndexT const character_index,
StateIndexT const old_state,
StateIndexT const new_state,
SymbolIndexT const symbol_id)
{
uint32_t const count = transducer_table(old_state.Get(0), symbol_id);
uint32_t const count = transducer_table(old_state, symbol_id);
if (write) {
for (uint32_t out_char = 0; out_char < count; out_char++) {
out_it[out_count + out_char] = transducer_table(old_state.Get(0), symbol_id, out_char);
out_it[out_count + out_char] = transducer_table(old_state, symbol_id, out_char);
out_idx_it[out_count + out_char] = offset + character_index;
}
}
Expand All @@ -125,22 +101,11 @@ class DFASimulationCallbackWrapper {
//-----------------------------------------------------------------------------
// STATE-TRANSITION CALLBACKS
//-----------------------------------------------------------------------------
class StateTransitionCallbackOp {
template <int32_t NUM_INSTANCES, typename TransitionTableT>
class StateVectorTransitionOp {
public:
template <typename CharIndexT, typename SymbolIndexT>
__host__ __device__ __forceinline__ void ReadSymbol(CharIndexT const& character_index,
SymbolIndexT const& read_symbol_id) const
{
}
};
/// Type alias for a state transition callback class that performs no operation on any callback
using NoOpStateTransitionOp = StateTransitionCallbackOp;

template <int32_t NUM_INSTANCES, typename StateVectorT, typename TransitionTableT>
class StateVectorTransitionOp : public StateTransitionCallbackOp {
public:
__host__ __device__ __forceinline__
StateVectorTransitionOp(TransitionTableT const& transition_table, StateVectorT& state_vector)
__host__ __device__ __forceinline__ StateVectorTransitionOp(
TransitionTableT const& transition_table, std::array<int32_t, NUM_INSTANCES>& state_vector)
: transition_table(transition_table), state_vector(state_vector)
{
}
Expand All @@ -150,39 +115,37 @@ class StateVectorTransitionOp : public StateTransitionCallbackOp {
SymbolIndexT const read_symbol_id) const
{
for (int32_t i = 0; i < NUM_INSTANCES; ++i) {
state_vector.Set(i, transition_table(state_vector.Get(i), read_symbol_id));
state_vector[i] = transition_table(state_vector[i], read_symbol_id);
}
}

public:
StateVectorT& state_vector;
std::array<int32_t, NUM_INSTANCES>& state_vector;
const TransitionTableT& transition_table;
};

template <typename CallbackOpT, typename StateVectorT, typename TransitionTableT>
template <typename CallbackOpT, typename TransitionTableT>
struct StateTransitionOp {
StateVectorT old_state_vector;
StateVectorT state_vector;
int32_t state;
const TransitionTableT& transition_table;
CallbackOpT& callback_op;

__host__ __device__ __forceinline__ StateTransitionOp(const TransitionTableT& transition_table,
StateVectorT state_vector,
__host__ __device__ __forceinline__ StateTransitionOp(TransitionTableT const& transition_table,
int32_t state,
CallbackOpT& callback_op)
: transition_table(transition_table),
state_vector(state_vector),
old_state_vector(state_vector),
callback_op(callback_op)
: transition_table(transition_table), state(state), callback_op(callback_op)
{
}

template <typename CharIndexT, typename SymbolIndexT>
__host__ __device__ __forceinline__ void ReadSymbol(const CharIndexT& character_index,
const SymbolIndexT& read_symbol_id)
__host__ __device__ __forceinline__ void ReadSymbol(CharIndexT const& character_index,
SymbolIndexT const& read_symbol_id)
{
old_state_vector = state_vector;
state_vector.Set(0, transition_table(state_vector.Get(0), read_symbol_id));
callback_op.ReadSymbol(character_index, old_state_vector, state_vector, read_symbol_id);
// Remember what state we were in before we made the transition
int32_t previous_state = state;

state = transition_table(state, read_symbol_id);
callback_op.ReadSymbol(character_index, previous_state, state, read_symbol_id);
}
};

Expand Down Expand Up @@ -237,9 +200,6 @@ struct AgentDFA {
{
}

//---------------------------------------------------------------------
// STATIC PARSING PRIMITIVES
//---------------------------------------------------------------------
template <int32_t NUM_SYMBOLS, // The net (excluding overlap) number of characters to be parsed
typename SymbolMatcherT, // The symbol matcher returning the matched symbol and its
// length
Expand All @@ -251,13 +211,11 @@ struct AgentDFA {
CallbackOpT callback_op,
cub::Int2Type<IS_FULL_BLOCK> /*IS_FULL_BLOCK*/)
{
uint32_t matched_id;

// Iterate over symbols
#pragma unroll
for (int32_t i = 0; i < NUM_SYMBOLS; ++i) {
if (IS_FULL_BLOCK || threadIdx.x * SYMBOLS_PER_THREAD + i < max_num_chars) {
matched_id = symbol_matcher(chars[i]);
uint32_t matched_id = symbol_matcher(chars[i]);
callback_op.ReadSymbol(i, matched_id);
}
}
Expand Down Expand Up @@ -400,20 +358,16 @@ struct AgentDFA {
}
}

template <int32_t NUM_STATES,
typename SymbolMatcherT,
typename TransitionTableT,
typename StateVectorT>
template <int32_t NUM_STATES, typename SymbolMatcherT, typename TransitionTableT>
__device__ __forceinline__ void GetThreadStateTransitionVector(
const SymbolMatcherT& symbol_matcher,
const TransitionTableT& transition_table,
const CharT* d_chars,
const OffsetT block_offset,
const OffsetT num_total_symbols,
StateVectorT& state_vector)
std::array<int32_t, NUM_STATES>& state_vector)
{
using StateVectorTransitionOpT =
StateVectorTransitionOp<NUM_STATES, StateVectorT, TransitionTableT>;
using StateVectorTransitionOpT = StateVectorTransitionOp<NUM_STATES, TransitionTableT>;

// Start parsing and to transition states
StateVectorTransitionOpT transition_op(transition_table, state_vector);
Expand All @@ -439,29 +393,26 @@ struct AgentDFA {
GetThreadStateTransitions<SYMBOLS_PER_THREAD>(
symbol_matcher, t_chars, num_block_chars, transition_op, cub::Int2Type<false>());
}

// transition_op.TearDown();
}

template <int32_t BYPASS_LOAD,
typename SymbolMatcherT,
typename TransitionTableT,
typename StateVectorT,
typename CallbackOpT>
__device__ __forceinline__ void GetThreadStateTransitions(
SymbolMatcherT const& symbol_matcher,
TransitionTableT const& transition_table,
CharT const* d_chars,
OffsetT const block_offset,
OffsetT const num_total_symbols,
StateVectorT& state_vector,
int32_t& state,
CallbackOpT& callback_op,
cub::Int2Type<BYPASS_LOAD> /**/)
{
using StateTransitionOpT = StateTransitionOp<CallbackOpT, StateVectorT, TransitionTableT>;
using StateTransitionOpT = StateTransitionOp<CallbackOpT, TransitionTableT>;

// Start parsing and to transition states
StateTransitionOpT transition_op(transition_table, state_vector, callback_op);
StateTransitionOpT transition_op(transition_table, state, callback_op);

// Load characters into shared memory
if (!BYPASS_LOAD) LoadBlock(d_chars, block_offset, num_total_symbols);
Expand Down Expand Up @@ -528,7 +479,7 @@ __launch_bounds__(int32_t(AgentDFAPolicy::BLOCK_THREADS)) __global__
SYMBOLS_PER_BLOCK = AgentDfaSimT::SYMBOLS_PER_BLOCK
};

// Shared memory required by the DFA simulator
// Shared memory required by the DFA simulation algorithm
__shared__ typename AgentDfaSimT::TempStorage dfa_storage;

// Shared memory required by the symbol group lookup table
Expand All @@ -552,18 +503,18 @@ __launch_bounds__(int32_t(AgentDFAPolicy::BLOCK_THREADS)) __global__
// Set up DFA
AgentDfaSimT agent_dfa(dfa_storage);

// Memory is the state transition vector passed on to the second stage of the algorithm
// The state transition vector passed on to the second stage of the algorithm
StateVectorT out_state_vector;

// Stage 1: Compute the state-transition vector
if (IS_TRANS_VECTOR_PASS || IS_SINGLE_PASS) {
// StateVectorT state_vector;
MultiItemStateVector<int32_t, NUM_STATES> state_vector;
// Keeping track of the state for each of the <NUM_STATES> state machines
std::array<int32_t, NUM_STATES> state_vector;

// Initialize the seed state transition vector with the identity vector
#pragma unroll
for (int32_t i = 0; i < NUM_STATES; ++i) {
state_vector.Set(i, i);
state_vector[i] = i;
}

// Compute the state transition vector
Expand All @@ -577,18 +528,18 @@ __launch_bounds__(int32_t(AgentDFAPolicy::BLOCK_THREADS)) __global__
// Initialize the state transition vector passed on to the second stage
#pragma unroll
for (int32_t i = 0; i < NUM_STATES; ++i) {
out_state_vector.Set(i, state_vector.Get(i));
out_state_vector.Set(i, state_vector[i]);
}

// Write out state-transition vector
if (!IS_SINGLE_PASS) {
d_thread_state_transition[blockIdx.x * BLOCK_THREADS + threadIdx.x] = out_state_vector;
}
}

// Stage 2: Perform FSM simulation
if ((!IS_TRANS_VECTOR_PASS) || IS_SINGLE_PASS) {
constexpr uint32_t SINGLE_ITEM_COUNT = 1;
MultiItemStateVector<int32_t, SINGLE_ITEM_COUNT> state;
int32_t state = 0;

//------------------------------------------------------------------------------
// SINGLE-PASS:
Expand Down Expand Up @@ -637,10 +588,9 @@ __launch_bounds__(int32_t(AgentDFAPolicy::BLOCK_THREADS)) __global__
.ExclusiveScan(out_state_vector, out_state_vector, state_vector_scan_op, prefix_op);
}
__syncthreads();
state.Set(0, out_state_vector.Get(seed_state));
state = out_state_vector.Get(seed_state);
} else {
state.Set(
0, d_thread_state_transition[blockIdx.x * BLOCK_THREADS + threadIdx.x].Get(seed_state));
state = d_thread_state_transition[blockIdx.x * BLOCK_THREADS + threadIdx.x].Get(seed_state);
}

// Perform finite-state machine simulation, computing size of transduced output
Expand All @@ -649,8 +599,7 @@ __launch_bounds__(int32_t(AgentDFAPolicy::BLOCK_THREADS)) __global__
TransducedIndexOutItT>
callback_wrapper(transducer_table, transduced_out_it, transduced_out_idx_it);

MultiItemStateVector<int32_t, SINGLE_ITEM_COUNT> t_start_state;
t_start_state.Set(0, state.Get(seed_state));
int32_t t_start_state = state;
agent_dfa.GetThreadStateTransitions(symbol_matcher,
transition_table,
d_chars,
Expand All @@ -661,6 +610,7 @@ __launch_bounds__(int32_t(AgentDFAPolicy::BLOCK_THREADS)) __global__
cub::Int2Type<IS_SINGLE_PASS>());

__syncthreads();

using OffsetPrefixScanCallbackOpT_ =
cub::TilePrefixCallbackOp<OffsetT, cub::Sum, OutOffsetScanTileState>;

Expand Down

0 comments on commit 5f1c4b5

Please sign in to comment.