Skip to content

Commit

Permalink
Merge pull request #1487 from SpiNNakerManchester/wta_connector
Browse files Browse the repository at this point in the history
Wta connector
  • Loading branch information
rowleya authored Sep 24, 2024
2 parents 3a4996c + 8852ee2 commit 5f8528d
Show file tree
Hide file tree
Showing 8 changed files with 684 additions and 5 deletions.
8 changes: 7 additions & 1 deletion neural_modelling/src/synapse_expander/connection_generator.c
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
#include "connection_generators/connection_generator_fixed_pre.h"
#include "connection_generators/connection_generator_fixed_post.h"
#include "connection_generators/connection_generator_kernel.h"
#include "connection_generators/connection_generator_all_but_me.h"

//! \brief Known "hashes" of connection generators
//!
Expand All @@ -43,6 +44,7 @@ enum {
FIXED_PRE, //!< Fixed pre-size connection generator
FIXED_POST, //!< Fixed post-size connection generator
KERNEL, //!< Convolution kernel connection generator
ALL_BUT_ME, //!< AllButMe connection generator
N_CONNECTION_GENERATORS//!< The number of known generators
};

Expand Down Expand Up @@ -96,7 +98,11 @@ static const connection_generator_info connection_generators[] = {
{KERNEL,
connection_generator_kernel_initialise,
connection_generator_kernel_generate,
connection_generator_kernel_free}
connection_generator_kernel_free},
{ALL_BUT_ME,
connection_generator_all_but_me_initialise,
connection_generator_all_but_me_generate,
connection_generator_all_but_me_free}
};

connection_generator_t connection_generator_init(
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,219 @@
/*
* Copyright (c) 2024 The University of Manchester
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

/**
* \dir
* \brief Connection generators
* \file
* \brief All But Me connection generator implementation
*/

#include <stdbool.h>
#include <synapse_expander/generator_types.h>

//! \brief The parameters to be passed around for this connector
struct all_but_me_conf {
// How many values there are in each WTA group
uint32_t n_values;

// Whether there are weight values specified or not
uint32_t has_weights;

// The weight values if specified.
// If so, there must be (n_values * n_values - 1) weights
accum weights[];
};

//! \brief The parameters to be passed around for this connector
struct all_but_me {
// How many neurons there are in each WTA group
uint32_t n_neurons_per_group;

// The weight values if specified.
// If so, there must be (n_values * n_values - 1) weights
accum *weights;
};


/**
* \brief Initialise the all but me connection generator
* \param[in,out] region: Region to read parameters from. Should be updated
* to position just after parameters after calling.
* \return A data item to be passed in to other functions later on
*/
static void *connection_generator_all_but_me_initialise(void **region) {
// Get the SDRAM params
struct all_but_me_conf *params_sdram = *region;

// Allocate the data structure for parameters
struct all_but_me *params = spin1_malloc(sizeof(struct all_but_me));

// Copy the parameters
params->n_neurons_per_group = params_sdram->n_values;
if (params_sdram->has_weights) {
uint32_t n_per_group = params->n_neurons_per_group;
uint32_t weight_size = n_per_group * (n_per_group - 1) * sizeof(accum);
params->weights = spin1_malloc(weight_size);
if (params->weights == NULL) {
// If we can't copy, just reference the SDRAM
params->weights = &params_sdram->weights[0];
} else {
spin1_memcpy(&params->weights[0], &params_sdram->weights[0], weight_size);
}
*region = &params_sdram->weights[n_per_group * (n_per_group - 1)];
} else {
params->weights = NULL;
*region = &params_sdram->weights[0];
}

log_info("allButMe connector, n_values = %u, has_weights = %u", params->n_neurons_per_group,
params_sdram->has_weights);

return params;
}

/**
* \brief Free the All But Me connection generator
* \param[in] generator: The generator to free
*/
static void connection_generator_all_but_me_free(void *generator) {
sark_free(generator);
}

static inline bool make_all_but_me_conn(accum weight,
param_generator_t delay_generator, matrix_generator_t matrix_generator,
uint32_t pre, uint32_t post, unsigned long accum weight_scale,
accum timestep_per_delay) {
uint16_t delay = rescale_delay(
param_generator_generate(delay_generator), timestep_per_delay);
if (!matrix_generator_write_synapse(matrix_generator, pre, post,
weight, delay, weight_scale)) {
log_error("Matrix not sized correctly!");
return false;
}
return true;
}

static inline void div_mod(uint32_t dividend, uint32_t divisor, uint32_t *div,
uint32_t *mod) {
uint32_t remainder = dividend;
uint32_t count = 0;
while (remainder >= divisor) {
remainder -= divisor;
count++;
}
*div = count;
*mod = remainder;
}

/**
* Get the weight for a given pre *value* and post *value*.
*/
static inline accum get_weight(struct all_but_me *obj,
param_generator_t weight_generator, uint32_t pre_value,
uint32_t post_value) {
// Get the post position rather than the post value. Because each "row" in
// the table has the diagonal removed, we need to adjust where we get the
// value from depending on the relative pre and post values (which must not
// be the same - this isn't checked here though).
uint32_t post_pos = post_value;
if (post_value >= pre_value) {
post_pos -= 1;
}
if (obj->weights != NULL) {
uint32_t weight_index = (pre_value * (obj->n_neurons_per_group - 1)) + post_pos;
return obj->weights[weight_index];
} else {
return param_generator_generate(weight_generator);
}
}

/**
* \brief Generate connections with the all but me connection generator
* \param[in] generator: The generator to use to generate connections
* \param[in] pre_slice_start: The start of the slice of the pre-population
* being generated
* \param[in] pre_slice_count: The number of neurons in the slice of the
* pre-population being generated
* \param[in] post_slice_start: The start of the slice of the post-population
* being generated
* \param[in] post_slice_count: The number of neurons in the slice of the
* post-population being generated
*/
static bool connection_generator_all_but_me_generate(
void *generator, uint32_t pre_lo, uint32_t pre_hi,
uint32_t post_lo, uint32_t post_hi, UNUSED uint32_t post_index,
uint32_t post_slice_start, uint32_t post_slice_count,
unsigned long accum weight_scale, accum timestep_per_delay,
param_generator_t weight_generator, param_generator_t delay_generator,
matrix_generator_t matrix_generator) {
struct all_but_me *obj = generator;

// Get the actual ranges to generate within
uint32_t post_start = max(post_slice_start, post_lo);
uint32_t post_end = min(post_slice_start + post_slice_count - 1, post_hi);

// Work out where we are in the generation
// We need to connect each pre-neuron to each post-neuron in each group
// (but not to itself). We are currently generating a subset of the post
// neurons, so we need to work out which group we are in within that subset,
// and which is the first post-neuron in the group that we are generating
// for now.
uint32_t post_group;
uint32_t post_value;
div_mod(post_start, obj->n_neurons_per_group, &post_group, &post_value);

// Work out where the pre-neurons start and end for the group that we are
// in at the start of the post-neurons. The group might not have enough
// neurons in it, so we check just in case.
uint32_t pre_start = pre_lo + post_group * obj->n_neurons_per_group;
uint32_t pre_end = min(pre_start + obj->n_neurons_per_group, pre_hi + 1);
uint32_t n_values = pre_end - pre_start;

// Go through the post neurons in this slice
for (uint32_t post = post_start; post <= post_end; post++) {
uint32_t local_post = post - post_slice_start;

// Go through each of the "values" in this group that can target this
// post neuron (each of which is a pre-neuron)
for (uint32_t pre_value = 0; pre_value < n_values; pre_value++) {
if (pre_value != post_value) {
uint32_t pre = pre_start + pre_value;
accum weight = get_weight(obj, weight_generator, pre_value, post_value);
if (!make_all_but_me_conn(weight, delay_generator,
matrix_generator, pre, local_post, weight_scale,
timestep_per_delay)) {
return false;
}
}
}

// Work out next loop iteration. If we have reached the end of a group
// of values, we need to move onto the next group.
post_value += 1;
if (post_value == obj->n_neurons_per_group) {
post_value = 0;
pre_start += obj->n_neurons_per_group;
pre_end = min(pre_start + obj->n_neurons_per_group, pre_hi + 1);
if (pre_start >= pre_hi) {
break;
}
n_values = pre_end - pre_start;
}
}

return true;
}
2 changes: 1 addition & 1 deletion neural_modelling/src/synapse_expander/type_writers.h
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ static type_info type_writers[] = {
};

static type_info *get_type_writer(type t) {
if (t < 0 || t >= sizeof(type_writers) / sizeof(*type_writers)) {
if (t >= sizeof(type_writers) / sizeof(*type_writers)) {
// Bogus index is bad! And otherwise hard to debug!
log_error("type id=%u is outside sane range", t);
rt_error(RTE_SWERR);
Expand Down
7 changes: 6 additions & 1 deletion spynnaker/pyNN/extra_models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@
IFCurrExpSEMDBase as IF_curr_exp_sEMD,
IFCurrDeltaCa2Adaptive, StocExp, StocExpStable, StocSigma, IFTruncDelta,
IFCurrDeltaFixedProb)
from spynnaker.pyNN.models.neural_projections.connectors import (
AllButMeConnector)

# Variable rate poisson
from spynnaker.pyNN.models.spike_source import SpikeSourcePoissonVariable
Expand All @@ -57,5 +59,8 @@
'StocExp', 'StocExpStable', 'StocSigma', 'IFCurrDeltaFixedProb',

# Special
'IFTruncDelta'
'IFTruncDelta',

# Connectors
'AllButMeConnector'
]
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
AbstractGenerateConnectorOnMachine)
from .abstract_generate_connector_on_host import (
AbstractGenerateConnectorOnHost)
from .all_but_me_connector import AllButMeConnector
from .all_to_all_connector import AllToAllConnector
from .array_connector import ArrayConnector
from .csa_connector import CSAConnector
Expand All @@ -36,8 +37,8 @@
from .pool_dense_connector import PoolDenseConnector

__all__ = ["AbstractConnector", "AbstractGenerateConnectorOnMachine",
"AbstractGenerateConnectorOnHost", "AllToAllConnector",
"ArrayConnector", "CSAConnector",
"AbstractGenerateConnectorOnHost", "AllButMeConnector",
"AllToAllConnector", "ArrayConnector", "CSAConnector",
"DistanceDependentProbabilityConnector", "FixedNumberPostConnector",
"FixedNumberPreConnector", "FixedProbabilityConnector",
"FromFileConnector",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ class ConnectorIDs(Enum):
FIXED_NUMBER_PRE_CONNECTOR = 4
FIXED_NUMBER_POST_CONNECTOR = 5
KERNEL_CONNECTOR = 6
WTA_CONNECTOR = 7


class AbstractGenerateConnectorOnMachine(
Expand Down
Loading

0 comments on commit 5f8528d

Please sign in to comment.