diff --git a/src/konnektor/__init__.py b/src/konnektor/__init__.py index 86645de..6971754 100644 --- a/src/konnektor/__init__.py +++ b/src/konnektor/__init__.py @@ -2,6 +2,7 @@ # For details, see https://github.com/OpenFreeEnergy/kartograf from .network_planners import (MaximalNetworkPlanner, + HeuristicMaximalNetworkPlanner, RadialLigandNetworkPlanner, StarrySkyLigandNetworkPlanner, MinimalSpanningTreeLigandNetworkPlanner, diff --git a/src/konnektor/network_planners/__init__.py b/src/konnektor/network_planners/__init__.py index 40f8595..8fe04f7 100644 --- a/src/konnektor/network_planners/__init__.py +++ b/src/konnektor/network_planners/__init__.py @@ -1,5 +1,7 @@ #Network Generators from .generators.maximal_network_planner import MaximalNetworkPlanner +from .generators.heuristic_maximal_network_planner import HeuristicMaximalNetworkPlanner + ## Starmap Like Networks from .generators.radial_network_planner import StarLigandNetworkPlanner RadialLigandNetworkPlanner = StarLigandNetworkPlanner diff --git a/src/konnektor/network_planners/concatenator/mst_concatenator.py b/src/konnektor/network_planners/concatenator/mst_concatenator.py index 3df556b..4aaeeb5 100644 --- a/src/konnektor/network_planners/concatenator/mst_concatenator.py +++ b/src/konnektor/network_planners/concatenator/mst_concatenator.py @@ -53,6 +53,7 @@ def concatenate_networks(self, ligand_networks: Iterable[LigandNetwork]) -> Liga selected_mappings = [edge_map[k] if (k in edge_map) else edge_map[ tuple(list(k)[::-1])] for k in mg.edges] + log.info("Adding ConnectingEdges: " + str(len(selected_mappings))) """ # prio queue kruska approach connecting_nodes = [] @@ -69,7 +70,6 @@ def concatenate_networks(self, ligand_networks: Iterable[LigandNetwork]) -> Liga if len(connecting_edges) >= self.n_connecting_edges: break """ - log.info("Adding ConnectingEdges: " + str(len(selected_mappings))) # Constructed final Edges: # Add all old network edges: @@ -82,7 +82,6 @@ def concatenate_networks(self, ligand_networks: Iterable[LigandNetwork]) -> Liga concat_LigandNetwork = LigandNetwork(edges=selected_edges, nodes=set(selected_nodes)) - log.info("Total Concatenating Edges: " + str(len(selected_mappings))) log.info("Total Concatenated Edges: " + str(len(selected_edges))) return concat_LigandNetwork diff --git a/src/konnektor/network_planners/generators/_abstract_ligand_network_planner.py b/src/konnektor/network_planners/generators/_abstract_ligand_network_planner.py index 4a977da..4b0f13c 100644 --- a/src/konnektor/network_planners/generators/_abstract_ligand_network_planner.py +++ b/src/konnektor/network_planners/generators/_abstract_ligand_network_planner.py @@ -3,7 +3,9 @@ from typing import Iterable from gufe import SmallMoleculeComponent, LigandNetwork +from gufe import AtomMapper +from .netx_netgen._abstract_network_generator import _AbstractNetworkGenerator log = logging.getLogger(__name__) #log.setLevel(logging.WARNING) @@ -12,8 +14,18 @@ class LigandNetworkPlanner(abc.ABC): progress: bool = False nprocesses: int - def __init__(self, mapper, scorer, network_generator, nprocesses:int=1, + def __init__(self, mapper: AtomMapper, scorer, network_generator:_AbstractNetworkGenerator, nprocesses:int=1, _initial_edge_lister=None): + """ + + Parameters + ---------- + mapper + scorer + network_generator + nprocesses + _initial_edge_lister + """ self.mapper = mapper self.scorer = scorer self.network_generator = network_generator diff --git a/src/konnektor/network_planners/generators/heuristic_maximal_network_planner.py b/src/konnektor/network_planners/generators/heuristic_maximal_network_planner.py new file mode 100644 index 0000000..eac21a3 --- /dev/null +++ b/src/konnektor/network_planners/generators/heuristic_maximal_network_planner.py @@ -0,0 +1,87 @@ +import itertools +import functools +import numpy as np + +from tqdm.auto import tqdm + +from typing import Iterable, Union + +from gufe import SmallMoleculeComponent, LigandNetwork + +from ._abstract_ligand_network_planner import LigandNetworkPlanner +from ._parallel_mapping_pattern import _parallel_map_scoring + +class HeuristicMaximalNetworkPlanner(LigandNetworkPlanner): + def __init__(self, mapper, scorer, progress=False, nprocesses=1, n_samples:int=100): + super().__init__(mapper=mapper, scorer=scorer, + network_generator=None, _initial_edge_lister=self) + self.progress = progress + self.nprocesses = nprocesses + self.n_samples = n_samples + + def generate_ligand_network(self, nodes: Iterable[SmallMoleculeComponent]): + """Create a network with all possible proposed mappings. + + This will attempt to create (and optionally score) all possible mappings + (up to $N(N-1)/2$ for each mapper given). There may be fewer actual + mappings that this because, when a mapper cannot return a mapping for a + given pair, there is simply no suggested mapping for that pair. + This network is typically used as the starting point for other network + generators (which then optimize based on the scores) or to debug atom + mappers (to see which mappings the mapper fails to generate). + + + Parameters + ---------- + nodes : Iterable[SmallMoleculeComponent] + the ligands to include in the LigandNetwork + mappers : Iterable[LigandAtomMapper] + the AtomMappers to use to propose mappings. At least 1 required, + but many can be given, in which case all will be tried to find the + lowest score edges + scorer : Scoring function + any callable which takes a LigandAtomMapping and returns a float + progress : Union[bool, Callable[Iterable], Iterable] + progress bar: if False, no progress bar will be shown. If True, use a + tqdm progress bar that only appears after 1.5 seconds. You can also + provide a custom progress bar wrapper as a callable. + """ + nodes = list(nodes) + total = len(nodes) * (len(nodes) - 1) // 2 + + # Parallel or not Parallel: + # generate combinations to be searched. + if len(nodes) > self.n_samples: + sample_combinations = [] + for n in nodes: + sample_indices =np.random.choice(range(len(nodes)), size=self.n_samples, replace=False) + sample_combinations.extend([(n, nodes[i]) for i in sample_indices if n!=nodes[i]]) + else: + sample_combinations = itertools.combinations(nodes,2) + + if(self.nprocesses > 1): + mappings = _parallel_map_scoring( + possible_edges=sample_combinations, + scorer=self.scorer, + mapper=self.mapper, + n_processes=self.nprocesses, + show_progress=self.progress) + else: #serial variant + if self.progress is True: + progress = functools.partial(tqdm, total=total, delay=1.5, + desc="Mapping") + else: + progress = lambda x: x + + mapping_generator = itertools.chain.from_iterable( + self.mapper.suggest_mappings(molA, molB) + for molA, molB in progress(sample_combinations) + ) + if self.scorer: + mappings = [mapping.with_annotations({'score': self.scorer(mapping)}) + for mapping in mapping_generator] + else: + mappings = list(mapping_generator) + + network = LigandNetwork(mappings, nodes=nodes) + return network \ No newline at end of file diff --git a/src/konnektor/utils/toy_data.py b/src/konnektor/utils/toy_data.py index a539076..8801aa3 100644 --- a/src/konnektor/utils/toy_data.py +++ b/src/konnektor/utils/toy_data.py @@ -1,4 +1,5 @@ from rdkit import Chem +from rdkit.Chem import AllChem from gufe import SmallMoleculeComponent from gufe import LigandAtomMapping diff --git a/src/konnektor/visualization/widget.py b/src/konnektor/visualization/widget.py index ae2dac1..fc6cade 100644 --- a/src/konnektor/visualization/widget.py +++ b/src/konnektor/visualization/widget.py @@ -79,8 +79,13 @@ def build_cytoscape(network, layout="concentric", show_molecules=True, show_mapp weights = [edge_map[k].annotations['score'] for k in edges] connectivities = np.array(get_node_connectivities(network)) - mixins = np.clip(connectivities / (sum(connectivities) / len(connectivities)), a_min=0, a_max=2) / 2 - cs = list(map(lambda x: color_gradient(mix=x), mixins)) + if(len(connectivities) == 0): + mixins=np.array([0]) + cs = list(map(lambda x: color_gradient(mix=x), mixins)) + + else: + mixins = np.clip(connectivities / (sum(connectivities) / len(connectivities)), a_min=0, a_max=2) / 2 + cs = list(map(lambda x: color_gradient(mix=x), mixins)) # build a graph g = nx.Graph()