Skip to content

Commit

Permalink
minor improvements:
Browse files Browse the repository at this point in the history
- toy data fix
- widget div 0 safer
  • Loading branch information
RiesBen committed Apr 18, 2024
1 parent 7c25b55 commit 6d257aa
Show file tree
Hide file tree
Showing 7 changed files with 112 additions and 5 deletions.
1 change: 1 addition & 0 deletions src/konnektor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# For details, see https://github.com/OpenFreeEnergy/kartograf

from .network_planners import (MaximalNetworkPlanner,
HeuristicMaximalNetworkPlanner,
RadialLigandNetworkPlanner,
StarrySkyLigandNetworkPlanner,
MinimalSpanningTreeLigandNetworkPlanner,
Expand Down
2 changes: 2 additions & 0 deletions src/konnektor/network_planners/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
#Network Generators
from .generators.maximal_network_planner import MaximalNetworkPlanner
from .generators.heuristic_maximal_network_planner import HeuristicMaximalNetworkPlanner

## Starmap Like Networks
from .generators.radial_network_planner import StarLigandNetworkPlanner
RadialLigandNetworkPlanner = StarLigandNetworkPlanner
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ def concatenate_networks(self, ligand_networks: Iterable[LigandNetwork]) -> Liga

selected_mappings = [edge_map[k] if (k in edge_map) else edge_map[
tuple(list(k)[::-1])] for k in mg.edges]
log.info("Adding ConnectingEdges: " + str(len(selected_mappings)))

""" # prio queue kruska approach
connecting_nodes = []
Expand All @@ -69,7 +70,6 @@ def concatenate_networks(self, ligand_networks: Iterable[LigandNetwork]) -> Liga
if len(connecting_edges) >= self.n_connecting_edges:
break
"""
log.info("Adding ConnectingEdges: " + str(len(selected_mappings)))

# Constructed final Edges:
# Add all old network edges:
Expand All @@ -82,7 +82,6 @@ def concatenate_networks(self, ligand_networks: Iterable[LigandNetwork]) -> Liga

concat_LigandNetwork = LigandNetwork(edges=selected_edges, nodes=set(selected_nodes))

log.info("Total Concatenating Edges: " + str(len(selected_mappings)))
log.info("Total Concatenated Edges: " + str(len(selected_edges)))

return concat_LigandNetwork
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@
from typing import Iterable

from gufe import SmallMoleculeComponent, LigandNetwork
from gufe import AtomMapper

from .netx_netgen._abstract_network_generator import _AbstractNetworkGenerator
log = logging.getLogger(__name__)
#log.setLevel(logging.WARNING)

Expand All @@ -12,8 +14,18 @@ class LigandNetworkPlanner(abc.ABC):
progress: bool = False
nprocesses: int

def __init__(self, mapper, scorer, network_generator, nprocesses:int=1,
def __init__(self, mapper: AtomMapper, scorer, network_generator:_AbstractNetworkGenerator, nprocesses:int=1,
_initial_edge_lister=None):
"""
Parameters
----------
mapper
scorer
network_generator
nprocesses
_initial_edge_lister
"""
self.mapper = mapper
self.scorer = scorer
self.network_generator = network_generator
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
import itertools
import functools
import numpy as np

from tqdm.auto import tqdm

from typing import Iterable, Union

from gufe import SmallMoleculeComponent, LigandNetwork

from ._abstract_ligand_network_planner import LigandNetworkPlanner
from ._parallel_mapping_pattern import _parallel_map_scoring

class HeuristicMaximalNetworkPlanner(LigandNetworkPlanner):
def __init__(self, mapper, scorer, progress=False, nprocesses=1, n_samples:int=100):
super().__init__(mapper=mapper, scorer=scorer,
network_generator=None, _initial_edge_lister=self)
self.progress = progress
self.nprocesses = nprocesses
self.n_samples = n_samples

def generate_ligand_network(self, nodes: Iterable[SmallMoleculeComponent]):
"""Create a network with all possible proposed mappings.
This will attempt to create (and optionally score) all possible mappings
(up to $N(N-1)/2$ for each mapper given). There may be fewer actual
mappings that this because, when a mapper cannot return a mapping for a
given pair, there is simply no suggested mapping for that pair.
This network is typically used as the starting point for other network
generators (which then optimize based on the scores) or to debug atom
mappers (to see which mappings the mapper fails to generate).
Parameters
----------
nodes : Iterable[SmallMoleculeComponent]
the ligands to include in the LigandNetwork
mappers : Iterable[LigandAtomMapper]
the AtomMappers to use to propose mappings. At least 1 required,
but many can be given, in which case all will be tried to find the
lowest score edges
scorer : Scoring function
any callable which takes a LigandAtomMapping and returns a float
progress : Union[bool, Callable[Iterable], Iterable]
progress bar: if False, no progress bar will be shown. If True, use a
tqdm progress bar that only appears after 1.5 seconds. You can also
provide a custom progress bar wrapper as a callable.
"""
nodes = list(nodes)
total = len(nodes) * (len(nodes) - 1) // 2

# Parallel or not Parallel:
# generate combinations to be searched.
if len(nodes) > self.n_samples:
sample_combinations = []
for n in nodes:
sample_indices =np.random.choice(range(len(nodes)), size=self.n_samples, replace=False)
sample_combinations.extend([(n, nodes[i]) for i in sample_indices if n!=nodes[i]])
else:
sample_combinations = itertools.combinations(nodes,2)

if(self.nprocesses > 1):
mappings = _parallel_map_scoring(
possible_edges=sample_combinations,
scorer=self.scorer,
mapper=self.mapper,
n_processes=self.nprocesses,
show_progress=self.progress)
else: #serial variant
if self.progress is True:
progress = functools.partial(tqdm, total=total, delay=1.5,
desc="Mapping")
else:
progress = lambda x: x

mapping_generator = itertools.chain.from_iterable(
self.mapper.suggest_mappings(molA, molB)
for molA, molB in progress(sample_combinations)
)
if self.scorer:
mappings = [mapping.with_annotations({'score': self.scorer(mapping)})
for mapping in mapping_generator]
else:
mappings = list(mapping_generator)

network = LigandNetwork(mappings, nodes=nodes)
return network
1 change: 1 addition & 0 deletions src/konnektor/utils/toy_data.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from rdkit import Chem
from rdkit.Chem import AllChem
from gufe import SmallMoleculeComponent

from gufe import LigandAtomMapping
Expand Down
9 changes: 7 additions & 2 deletions src/konnektor/visualization/widget.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,13 @@ def build_cytoscape(network, layout="concentric", show_molecules=True, show_mapp
weights = [edge_map[k].annotations['score'] for k in edges]

connectivities = np.array(get_node_connectivities(network))
mixins = np.clip(connectivities / (sum(connectivities) / len(connectivities)), a_min=0, a_max=2) / 2
cs = list(map(lambda x: color_gradient(mix=x), mixins))
if(len(connectivities) == 0):
mixins=np.array([0])
cs = list(map(lambda x: color_gradient(mix=x), mixins))

else:
mixins = np.clip(connectivities / (sum(connectivities) / len(connectivities)), a_min=0, a_max=2) / 2
cs = list(map(lambda x: color_gradient(mix=x), mixins))

# build a graph
g = nx.Graph()
Expand Down

0 comments on commit 6d257aa

Please sign in to comment.