Skip to content

Commit

Permalink
add support for initializing registers and memories to the functional…
Browse files Browse the repository at this point in the history
… backend
  • Loading branch information
aiju committed Jul 24, 2024
1 parent 76d06ed commit a468cf3
Show file tree
Hide file tree
Showing 10 changed files with 416 additions and 280 deletions.
101 changes: 59 additions & 42 deletions backends/functional/cxx.cc
Original file line number Diff line number Diff line change
Expand Up @@ -69,35 +69,48 @@ struct CxxType {
using CxxWriter = FunctionalTools::Writer;

struct CxxStruct {
std::string name;
dict<IdString, CxxType> types;
CxxScope<IdString> scope;
CxxStruct(std::string name)
: name(name) {
scope.reserve("fn");
scope.reserve("visit");
}
void insert(IdString name, CxxType type) {
scope(name, name);
types.insert({name, type});
}
void print(CxxWriter &f) {
f.print("\tstruct {} {{\n", name);
for (auto p : types) {
f.print("\t\t{} {};\n", p.second.to_string(), scope(p.first, p.first));
}
f.print("\n\t\ttemplate <typename T> void visit(T &&fn) {{\n");
for (auto p : types) {
f.print("\t\t\tfn(\"{}\", {});\n", RTLIL::unescape_id(p.first), scope(p.first, p.first));
}
f.print("\t\t}}\n");
f.print("\t}};\n\n");
};
std::string operator[](IdString field) {
return scope(field, field);
}
std::string name;
dict<IdString, CxxType> types;
CxxScope<IdString> scope;
CxxStruct(std::string name) : name(name)
{
scope.reserve("fn");
scope.reserve("visit");
}
void insert(IdString name, CxxType type) {
scope(name, name);
types.insert({name, type});
}
void print(CxxWriter &f) {
f.print("\tstruct {} {{\n", name);
for (auto p : types) {
f.print("\t\t{} {};\n", p.second.to_string(), scope(p.first, p.first));
}
f.print("\n\t\ttemplate <typename T> void visit(T &&fn) {{\n");
for (auto p : types) {
f.print("\t\t\tfn(\"{}\", {});\n", RTLIL::unescape_id(p.first), scope(p.first, p.first));
}
f.print("\t\t}}\n");
f.print("\t}};\n\n");
};
std::string operator[](IdString field) {
return scope(field, field);
}
};

std::string cxx_const(RTLIL::Const const &value) {
std::stringstream ss;
ss << "Signal<" << value.size() << ">(" << std::hex << std::showbase;
if(value.size() > 32) ss << "{";
for(int i = 0; i < value.size(); i += 32) {
if(i > 0) ss << ", ";
ss << value.extract(i, 32).as_int();
}
if(value.size() > 32) ss << "}";
ss << ")";
return ss.str();
}

template<class NodePrinter> struct CxxPrintVisitor : public FunctionalIR::AbstractVisitor<void> {
using Node = FunctionalIR::Node;
CxxWriter &f;
Expand Down Expand Up @@ -136,20 +149,7 @@ template<class NodePrinter> struct CxxPrintVisitor : public FunctionalIR::Abstra
void logical_shift_right(Node, Node a, Node b) override { print("{} >> {}", a, b); }
void arithmetic_shift_right(Node, Node a, Node b) override { print("{}.arithmetic_shift_right({})", a, b); }
void mux(Node, Node a, Node b, Node s) override { print("{2}.any() ? {1} : {0}", a, b, s); }
void constant(Node, RTLIL::Const value) override {
std::stringstream ss;
bool multiple = value.size() > 32;
ss << "Signal<" << value.size() << ">(" << std::hex << std::showbase;
if(multiple) ss << "{";
while(value.size() > 32) {
ss << value.as_int() << ", ";
value = value.extract(32, value.size() - 32);
}
ss << value.as_int();
if(multiple) ss << "}";
ss << ")";
print("{}", ss.str());
}
void constant(Node, RTLIL::Const const & value) override { print("{}", cxx_const(value)); }
void input(Node, IdString name) override { print("input.{}", input_struct[name]); }
void state(Node, IdString name) override { print("current_state.{}", state_struct[name]); }
void memory_read(Node, Node mem, Node addr) override { print("{}.read({})", mem, addr); }
Expand Down Expand Up @@ -184,8 +184,24 @@ struct CxxModule {
output_struct.print(f);
state_struct.print(f);
f.print("\tstatic void eval(Inputs const &, Outputs &, State const &, State &);\n");
f.print("\tstatic void initialize(State &);\n");
f.print("}};\n\n");
}
void write_initial_def(CxxWriter &f) {
f.print("void {0}::initialize({0}::State &state)\n{{\n", module_name);
for (auto [name, sort] : ir.state()) {
if (sort.is_signal())
f.print("\tstate.{} = {};\n", state_struct[name], cxx_const(ir.get_initial_state_signal(name)));
else if (sort.is_memory()) {
const auto &contents = ir.get_initial_state_memory(name);
f.print("\tstate.{}.fill({});\n", state_struct[name], cxx_const(contents.default_value()));
for(auto range : contents)
for(auto addr = range.base(); addr < range.limit(); addr++)
f.print("\tstate.{}[{}] = {};\n", state_struct[name], addr, cxx_const(range[addr]));
}
}
f.print("}}\n\n");
}
void write_eval_def(CxxWriter &f) {
f.print("void {0}::eval({0}::Inputs const &input, {0}::Outputs &output, {0}::State const &current_state, {0}::State &next_state)\n{{\n", module_name);
CxxScope<int> locals;
Expand All @@ -204,7 +220,7 @@ struct CxxModule {
f.print("\tnext_state.{} = {};\n", state_struct[name], node_name(ir.get_state_next_node(name)));
for (auto [name, sort] : ir.outputs())
f.print("\toutput.{} = {};\n", output_struct[name], node_name(ir.get_output_node(name)));
f.print("}}\n");
f.print("}}\n\n");
}
};

Expand All @@ -225,6 +241,7 @@ struct FunctionalCxxBackend : public Backend
mod.write_header(f);
mod.write_struct_def(f);
mod.write_eval_def(f);
mod.write_initial_def(f);
}

void execute(std::ostream *&f, std::string filename, std::vector<std::string> args, RTLIL::Design *design) override
Expand Down
3 changes: 3 additions & 0 deletions backends/functional/cxx_runtime/sim.h
Original file line number Diff line number Diff line change
Expand Up @@ -410,6 +410,9 @@ class Memory {
ret._contents[addr.template as_numeric<size_t>()] = data;
return ret;
}
// mutating methods for initializing a state
void fill(Signal<d> data) { _contents.fill(data); }
Signal<d> &operator[](Signal<a> addr) { return _contents[addr.template as_numeric<size_t>()]; }
};

#endif
68 changes: 46 additions & 22 deletions backends/functional/smtlib.cc
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,13 @@ class SmtStruct {
}
};

std::string smt_const(RTLIL::Const const &c) {
std::string s = "#b";
for(int i = c.size(); i-- > 0; )
s += c[i] == State::S1 ? '1' : '0';
return s;
}

struct SmtPrintVisitor : public FunctionalIR::AbstractVisitor<SExpr> {
using Node = FunctionalIR::Node;
std::function<SExpr(Node)> n;
Expand All @@ -117,13 +124,6 @@ struct SmtPrintVisitor : public FunctionalIR::AbstractVisitor<SExpr> {

SmtPrintVisitor(SmtStruct &input_struct, SmtStruct &state_struct) : input_struct(input_struct), state_struct(state_struct) {}

std::string literal(RTLIL::Const c) {
std::string s = "#b";
for(int i = c.size(); i-- > 0; )
s += c[i] == State::S1 ? '1' : '0';
return s;
}

SExpr from_bool(SExpr &&arg) {
return list("ite", std::move(arg), "#b1", "#b0");
}
Expand All @@ -149,8 +149,8 @@ struct SmtPrintVisitor : public FunctionalIR::AbstractVisitor<SExpr> {
SExpr bitwise_xor(Node, Node a, Node b) override { return list("bvxor", n(a), n(b)); }
SExpr bitwise_not(Node, Node a) override { return list("bvnot", n(a)); }
SExpr unary_minus(Node, Node a) override { return list("bvneg", n(a)); }
SExpr reduce_and(Node, Node a) override { return from_bool(list("=", n(a), literal(RTLIL::Const(State::S1, a.width())))); }
SExpr reduce_or(Node, Node a) override { return from_bool(list("distinct", n(a), literal(RTLIL::Const(State::S0, a.width())))); }
SExpr reduce_and(Node, Node a) override { return from_bool(list("=", n(a), smt_const(RTLIL::Const(State::S1, a.width())))); }
SExpr reduce_or(Node, Node a) override { return from_bool(list("distinct", n(a), smt_const(RTLIL::Const(State::S0, a.width())))); }
SExpr reduce_xor(Node, Node a) override {
vector<SExpr> s { "bvxor" };
for(int i = 0; i < a.width(); i++)
Expand All @@ -174,7 +174,7 @@ struct SmtPrintVisitor : public FunctionalIR::AbstractVisitor<SExpr> {
SExpr logical_shift_right(Node, Node a, Node b) override { return list("bvlshr", n(a), extend(n(b), b.width(), a.width())); }
SExpr arithmetic_shift_right(Node, Node a, Node b) override { return list("bvashr", n(a), extend(n(b), b.width(), a.width())); }
SExpr mux(Node, Node a, Node b, Node s) override { return list("ite", to_bool(n(s)), n(b), n(a)); }
SExpr constant(Node, RTLIL::Const value) override { return literal(value); }
SExpr constant(Node, RTLIL::Const const &value) override { return smt_const(value); }
SExpr memory_read(Node, Node mem, Node addr) override { return list("select", n(mem), n(addr)); }
SExpr memory_write(Node, Node mem, Node addr, Node data) override { return list("store", n(mem), n(addr), n(data)); }

Expand All @@ -199,6 +199,7 @@ struct SmtModule {
, output_struct(scope.unique_name(module->name.str() + "_Outputs"), scope)
, state_struct(scope.unique_name(module->name.str() + "_State"), scope)
{
scope.reserve(name + "-initial");
for (const auto &input : ir.inputs())
input_struct.insert(input.first, input.second);
for (const auto &output : ir.outputs())
Expand All @@ -207,18 +208,8 @@ struct SmtModule {
state_struct.insert(state.first, state.second);
}

void write(std::ostream &out)
{
SExprWriter w(out);

input_struct.write_definition(w);
output_struct.write_definition(w);
state_struct.write_definition(w);

w << list("declare-datatypes",
list(list("Pair", 2)),
list(list("par", list("X", "Y"), list(list("pair", list("first", "X"), list("second", "Y"))))));

void write_eval(SExprWriter &w)
{
w.push();
w.open(list("define-fun", name,
list(list("inputs", input_struct.name),
Expand All @@ -245,6 +236,39 @@ struct SmtModule {
state_struct.write_value(w, [&](IdString name) { return node_to_sexpr(ir.get_state_next_node(name)); });
w.pop();
}

void write_initial(SExprWriter &w)
{
std::string initial = name + "-initial";
w << list("declare-const", initial, state_struct.name);
for (const auto &[name, sort] : ir.state()) {
if(sort.is_signal())
w << list("assert", list("=", state_struct.access(initial, name), smt_const(ir.get_initial_state_signal(name))));
else if(sort.is_memory()) {
auto contents = ir.get_initial_state_memory(name);
for(int i = 0; i < 1<<sort.addr_width(); i++) {
auto addr = smt_const(RTLIL::Const(i, sort.addr_width()));
w << list("assert", list("=", list("select", state_struct.access(initial, name), addr), smt_const(contents[i])));
}
}
}
}

void write(std::ostream &out)
{
SExprWriter w(out);

input_struct.write_definition(w);
output_struct.write_definition(w);
state_struct.write_definition(w);

w << list("declare-datatypes",
list(list("Pair", 2)),
list(list("par", list("X", "Y"), list(list("pair", list("first", "X"), list("second", "Y"))))));

write_eval(w);
write_initial(w);
}
};

struct FunctionalSmtBackend : public Backend {
Expand Down
99 changes: 99 additions & 0 deletions backends/functional/test_generic.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,100 @@

#include "kernel/yosys.h"
#include "kernel/functionalir.h"
#include <random>

USING_YOSYS_NAMESPACE
PRIVATE_NAMESPACE_BEGIN

struct MemContentsTest {
int addr_width, data_width;
MemContents state;
using addr_t = MemContents::addr_t;
std::map<addr_t, RTLIL::Const> reference;
MemContentsTest(int addr_width, int data_width) : addr_width(addr_width), data_width(data_width), state(addr_width, data_width, RTLIL::Const(State::S0, data_width)) {}
void check() {
state.check();
for(auto addr = 0; addr < (1<<addr_width); addr++) {
auto it = reference.find(addr);
if(it != reference.end()) {
if(state.count_range(addr, addr + 1) != 1) goto error;
if(it->second != state[addr]) goto error;
} else {
if(state.count_range(addr, addr + 1) != 0) goto error;
}
}
return;
error:
printf("FAIL\n");
int digits = (data_width + 3) / 4;

for(auto addr = 0; addr < (1<<addr_width); addr++) {
if(addr % 8 == 0) printf("%.8x ", addr);
auto it = reference.find(addr);
bool ref_def = it != reference.end();
RTLIL::Const ref_value = ref_def ? it->second : state.default_value();
std::string ref_string = stringf("%.*x", digits, ref_value.as_int());
bool sta_def = state.count_range(addr, addr + 1) == 1;
RTLIL::Const sta_value = state[addr];
std::string sta_string = stringf("%.*x", digits, sta_value.as_int());
if(ref_def && sta_def) {
if(ref_value == sta_value) printf("%s%s", ref_string.c_str(), string(digits, ' ').c_str());
else printf("%s%s", ref_string.c_str(), sta_string.c_str());
} else if(ref_def) {
printf("%s%s", ref_string.c_str(), string(digits, 'M').c_str());
} else if(sta_def) {
printf("%s%s", sta_string.c_str(), string(digits, 'X').c_str());
} else {
printf("%s", string(2*digits, ' ').c_str());
}
printf(" ");
if(addr % 8 == 7) printf("\n");
}
printf("\n");
//log_abort();
}
void clear_range(addr_t begin_addr, addr_t end_addr) {
for(auto addr = begin_addr; addr != end_addr; addr++)
reference.erase(addr);
state.clear_range(begin_addr, end_addr);
check();
}
void insert_concatenated(addr_t addr, RTLIL::Const const &values) {
addr_t words = ((addr_t) values.size() + data_width - 1) / data_width;
for(addr_t i = 0; i < words; i++) {
reference.erase(addr + i);
reference.emplace(addr + i, values.extract(i * data_width, data_width));
}
state.insert_concatenated(addr, values);
check();
}
template<typename Rnd> void run(Rnd &rnd, int n) {
std::uniform_int_distribution<addr_t> addr_dist(0, (1<<addr_width) - 1);
std::poisson_distribution<addr_t> length_dist(10);
std::uniform_int_distribution<uint64_t> data_dist(0, ((uint64_t)1<<data_width) - 1);
while(n-- > 0) {
addr_t low = addr_dist(rnd);
//addr_t length = std::min((1<<addr_width) - low, length_dist(rnd));
//addr_t high = low + length - 1;
addr_t high = addr_dist(rnd);
if(low > high) std::swap(low, high);
if((rnd() & 7) == 0) {
log_debug("clear %.2x to %.2x\n", (int)low, (int)high);
clear_range(low, high + 1);
} else {
log_debug("insert %.2x to %.2x\n", (int)low, (int)high);
RTLIL::Const values;
for(addr_t addr = low; addr <= high; addr++) {
RTLIL::Const word(data_dist(rnd), data_width);
values.bits.insert(values.bits.end(), word.bits.begin(), word.bits.end());
}
insert_concatenated(low, values);
}
}
}

};

struct FunctionalTestGeneric : public Pass
{
FunctionalTestGeneric() : Pass("test_generic", "test the generic compute graph") {}
Expand All @@ -40,6 +130,14 @@ struct FunctionalTestGeneric : public Pass
size_t argidx = 1;
extra_args(args, argidx, design);

MemContentsTest test(8, 16);

std::random_device seed_dev;
std::mt19937 rnd(23); //seed_dev());
test.run(rnd, 1000);

/*
for (auto module : design->selected_modules()) {
log("Dumping module `%s'.\n", module->name.c_str());
auto fir = FunctionalIR::from_module(module);
Expand All @@ -50,6 +148,7 @@ struct FunctionalTestGeneric : public Pass
for(auto [name, sort] : fir.state())
std::cout << RTLIL::unescape_id(name) << " = " << RTLIL::unescape_id(fir.get_state_next_node(name).name()) << "\n";
}
*/
}
} FunctionalCxxBackend;

Expand Down
Loading

0 comments on commit a468cf3

Please sign in to comment.