diff --git a/pyschlandals/src/lib.rs b/pyschlandals/src/lib.rs index 7343aaa..0cc9866 100644 --- a/pyschlandals/src/lib.rs +++ b/pyschlandals/src/lib.rs @@ -2,6 +2,8 @@ use pyo3::prelude::*; use pyo3::Python; use std::path::PathBuf; use schlandals::*; +use std::fs::File; +use std::io::Write; #[pyclass] #[derive(Clone)] @@ -58,6 +60,7 @@ fn approximate_search_function(file: String, branching: BranchingHeuristic, epsi fn compilation_submodule(py: Python<'_>, parent_module: &PyModule) -> PyResult<()> { let module = PyModule::new(py, "compiler")?; module.add_function(wrap_pyfunction!(compile_function, module)?)?; + module.add_function(wrap_pyfunction!(compiler_from_file, module)?)?; module.add_class::()?; module.add_class::()?; module.add_class::()?; @@ -194,6 +197,78 @@ impl PyDac { pub fn circuit_node_input_distribution_at(&self, node: usize, index: usize) -> (usize, usize) { self.nodes[node].distribution_input[index] } + + pub fn to_graphviz(&self, path: String) { + + let mut out = String::new(); + out.push_str("digraph {\ntranksep = 3;\n\n"); + let node_attributes = String::from("shape=circle,style=filled"); + + for node in 0..self.distributions.len() { + if !self.distributions[node].outputs.is_empty() { + out.push_str(&format!("{} [{},label=\"{}\"];\n", node, &node_attributes, &format!("d{}", node))); + } + } + + for node in 0..self.nodes.len() { + let id = node + self.distributions.len(); + if self.nodes[node].is_mul { + out.push_str(&format!("{} [{}, label=\"(X ({:.3}))\"];\n", id, &node_attributes, self.nodes[node].value)) + } else { + out.push_str(&format!("{} [{}, label=\"(+ ({:.3}))\"];\n", id, &node_attributes, self.nodes[node].value)) + } + } + + for node in 0..self.distributions.len() { + for (output, value) in self.distributions[node].outputs.iter().copied() { + let to = output + self.distributions.len(); + let label = format!("({}, {:.3})", value, self.distributions[node].probabilities[value]); + out.push_str(&format!("\t{node} -> {to} [penwidth=1,label=\"{label}\"];\n")); + } + } + + for node in 0..self.nodes.len() { + let from = node + self.distributions.len(); + let start = self.nodes[node].outputs_start; + let end = start + self.nodes[node].number_output; + for output in self.outputs[start..end].iter().copied() { + let to = output + self.distributions.len(); + out.push_str(&format!("\t{from} -> {to} [penwidth=1];\n")); + } + } + out.push_str("}\n"); + let mut outfile = File::create(path).unwrap(); + match outfile.write(out.as_bytes()) { + Ok(_) => (), + Err(e) => println!("Culd not write the circuit into the dot file: {:?}", e), + } + } +} + +fn pydac_from_dac(dac: Dac) -> PyDac { + let mut py_dac = PyDac::new(); + py_dac.outputs = dac.outputs_node_iter().map(|n| n.0).collect(); + py_dac.inputs = dac.inputs_node_iter().map(|n| n.0).collect(); + for distribution in dac.distributions_iter() { + py_dac.distributions.push(PyDistributionNode { + probabilities: dac.get_distribution_probabilities(distribution).to_vec(), + outputs: dac.get_distribution_outputs(distribution).iter().copied().map(|(node, value)| (node.0, value)).collect(), + }) + } + + for node in dac.nodes_iter() { + let is_mul = dac.is_circuit_node_mul(node); + py_dac.nodes.push(PyCircuitNode { + outputs_start: dac.get_circuit_node_out_start(node), + number_output: dac.get_circuit_node_number_output(node), + inputs_start: dac.get_circuit_node_in_start(node), + number_input: dac.get_circuit_node_number_input(node), + distribution_input: dac.get_circuit_node_input_distribution(node).map(|(n, v)| (n.0, v)).collect(), + value: dac.get_circuit_node_probability(node).to_f64(), + is_mul, + }) + } + py_dac } #[pyfunction] @@ -206,34 +281,16 @@ fn compile_function(file: String, branching: BranchingHeuristic) -> Option None, - Some(dac) => { - let mut py_dac = PyDac::new(); - py_dac.outputs = dac.outputs_node_iter().map(|n| n.0).collect(); - py_dac.inputs = dac.inputs_node_iter().map(|n| n.0).collect(); - for distribution in dac.distributions_iter() { - py_dac.distributions.push(PyDistributionNode { - probabilities: dac.get_distribution_probabilities(distribution).to_vec(), - outputs: dac.get_distribution_outputs(distribution).iter().copied().map(|(node, value)| (node.0, value)).collect(), - }) - } - - for node in dac.nodes_iter() { - let is_mul = dac.is_circuit_node_mul(node); - py_dac.nodes.push(PyCircuitNode { - outputs_start: dac.get_circuit_node_out_start(node), - number_output: dac.get_circuit_node_number_output(node), - inputs_start: dac.get_circuit_node_in_start(node), - number_input: dac.get_circuit_node_number_input(node), - distribution_input: dac.get_circuit_node_input_distribution(node).map(|(n, v)| (n.0, v)).collect(), - value: if is_mul { 1.0 } else { 0.0 }, - is_mul, - }) - } - Some(py_dac) - } + Some(dac) => Some(pydac_from_dac(dac)), } } +#[pyfunction] +#[pyo3(name = "dac_from_file")] +fn compiler_from_file(file: String) -> PyDac { + pydac_from_dac(Dac::from_file(&PathBuf::from(file))) +} + /// Base module for pyschlandals #[pymodule] diff --git a/src/compiler/circuit.rs b/src/compiler/circuit.rs index 03b7ab9..6a62237 100644 --- a/src/compiler/circuit.rs +++ b/src/compiler/circuit.rs @@ -495,6 +495,10 @@ impl Dac { pub fn is_circuit_node_mul(&self, node: CircuitNodeIndex) -> bool { self.nodes[node.0].is_mul } + + pub fn get_circuit_node_probability(&self, node: CircuitNodeIndex) -> &Float { + &self.nodes[node.0].value + } }