Skip to content

Commit 4fbb047

Browse files
authored
Merge pull request #223 from datamol-io/fix/parse-args-in-lasso
Allow additional args for colors in lasso
2 parents 5d1cde1 + e56d083 commit 4fbb047

File tree

3 files changed

+82
-5
lines changed

3 files changed

+82
-5
lines changed

datamol/viz/_lasso_highlight.py

Lines changed: 52 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
# - possibility to do this for multiple target molecules at once
77
# - have the option to write to a file like to_image
88

9-
from typing import List, Iterator, Tuple, Union, Optional, Any, cast
9+
from typing import List, Dict, Iterator, Tuple, Union, Optional, Any, cast
1010

1111
from collections import defaultdict
1212
from collections import namedtuple
@@ -400,6 +400,10 @@ def lasso_highlight_image(
400400
line_width: int = 2,
401401
scale_padding: float = 1.0,
402402
verbose: bool = False,
403+
highlight_atoms: Optional[List[List[int]]] = None,
404+
highlight_bonds: Optional[List[List[int]]] = None,
405+
highlight_atom_colors: Optional[List[Dict[int, DatamolColor]]] = None,
406+
highlight_bond_colors: Optional[List[Dict[int, DatamolColor]]] = None,
403407
**kwargs: Any,
404408
):
405409
"""Create an image of a list of molecules with substructure matches using lasso-based highlighting.
@@ -408,7 +412,7 @@ def lasso_highlight_image(
408412
Args:
409413
target_molecules: One or a list of molecules to be highlighted.
410414
search_molecules: The substructure to be highlighted.
411-
atom_indices: Atom indices to be highlighted substructure.
415+
atom_indices: Atom indices to be highlighted as substructure using the lasso visualization.
412416
legends: A string or a list of strings used as the legend for each molecule.
413417
n_cols: Number of molecules per row (i.e. the number of columns in the image grid).
414418
mol_size: The size of the image to be returned
@@ -421,6 +425,10 @@ def lasso_highlight_image(
421425
line_width: width of drawn lines.
422426
scale_padding: Padding around the molecule when drawing to scale.
423427
verbose: Whether to print the verbose information.
428+
highlight_atoms: The atoms to highlight, a list for each molecule. It's the `highlightAtoms` argument of the RDKit drawer object.
429+
highlight_bonds: The bonds to highlight, a list for each molecule. It's the `highlightBonds` argument of the RDKit drawer object.
430+
highlight_atom_colors: The colors to use for highlighting atoms, a list of dict mapping atom index to color for each molecule.
431+
highlight_bond_colors: The colors to use for highlighting bonds, a list of dict mapping bond index to color for each molecule.
424432
**kwargs: Additional arguments to pass to the drawing function. See RDKit
425433
documentation related to `MolDrawOptions` for more details at
426434
https://www.rdkit.org/docs/source/rdkit.Chem.Draw.rdMolDraw2D.html.
@@ -551,9 +559,38 @@ def lasso_highlight_image(
551559
# EN: the following is edge-case free after trying 6 different logics, but may break if RDKit changes the way it draws molecules
552560
scaling_val = Point2D(scale_padding, scale_padding)
553561

562+
if isinstance(highlight_atoms, list) and isinstance(highlight_atoms[0], int):
563+
highlight_atoms = [highlight_atoms] * len(target_molecules)
564+
if isinstance(highlight_bonds, list) and isinstance(highlight_bonds[0], int):
565+
highlight_bonds = [highlight_bonds] * len(target_molecules)
566+
if isinstance(highlight_atom_colors, dict):
567+
highlight_atom_colors = [highlight_atom_colors] * len(target_molecules)
568+
if isinstance(highlight_bond_colors, dict):
569+
highlight_bond_colors = [highlight_bond_colors] * len(target_molecules)
570+
571+
# make sure we are using rdkit colors
572+
if highlight_atom_colors is not None:
573+
highlight_atom_colors = [
574+
{k: to_rdkit_color(v) for k, v in _.items()} for _ in highlight_atom_colors
575+
]
576+
if highlight_bond_colors is not None:
577+
highlight_bond_colors = [
578+
{k: to_rdkit_color(v) for k, v in _.items()} for _ in highlight_bond_colors
579+
]
580+
581+
kwargs["highlightAtoms"] = highlight_atoms
582+
kwargs["highlightBonds"] = highlight_bonds
583+
kwargs["highlightAtomColors"] = highlight_atom_colors
584+
kwargs["highlightBondColors"] = highlight_bond_colors
585+
554586
try:
555-
drawer.DrawMolecules(mols_to_draw, legends=legends, **kwargs)
556-
except Exception:
587+
drawer.DrawMolecules(
588+
mols_to_draw,
589+
legends=legends,
590+
**kwargs,
591+
)
592+
except Exception as e:
593+
logger.error(e)
557594
raise ValueError(
558595
"Failed to draw molecules. Some arguments neither match expected MolDrawOptions, nor DrawMolecule inputs. Please check the input arguments."
559596
)
@@ -567,8 +604,18 @@ def lasso_highlight_image(
567604
h_pos, w_pos = np.unravel_index(ind, (n_rows, n_cols))
568605
offset_x = int(w_pos * mol_size[0])
569606
offset_y = int(h_pos * mol_size[1])
607+
608+
ind_kwargs = kwargs.copy()
609+
if isinstance(ind_kwargs["highlightAtoms"], list):
610+
ind_kwargs["highlightAtoms"] = ind_kwargs["highlightAtoms"][ind]
611+
if isinstance(ind_kwargs["highlightAtomColors"], list):
612+
ind_kwargs["highlightAtomColors"] = ind_kwargs["highlightAtomColors"][ind]
613+
if isinstance(ind_kwargs["highlightBonds"], list):
614+
ind_kwargs["highlightBonds"] = ind_kwargs["highlightBonds"][ind]
615+
if isinstance(ind_kwargs["highlightBondColors"], list):
616+
ind_kwargs["highlightBondColors"] = ind_kwargs["highlightBondColors"][ind]
570617
drawer.SetOffset(offset_x, offset_y)
571-
drawer.DrawMolecule(mol, legend=legends[ind], **kwargs)
618+
drawer.DrawMolecule(mol, legend=legends[ind], **ind_kwargs)
572619
offset = None
573620
if draw_mols_same_scale:
574621
offset = drawer.Offset()

datamol/viz/utils.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,12 @@ def to_rdkit_color(color: Optional[DatamolColor]) -> Optional[RDKitColor]:
141141
Args:
142142
color: A datamol color: hex, rgb, rgba or None.
143143
"""
144+
if color is None:
145+
return None
146+
144147
if isinstance(color, str):
145148
return mcolors.to_rgba(color) # type: ignore
149+
if isinstance(color, (tuple, list)) and len(color) in [3, 4] and any(x > 1 for x in color):
150+
return tuple(x / 255 if i < 3 else x for i, x in enumerate(color))
151+
146152
return color

tests/test_viz_lasso_highlight.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,30 @@ def test_from_mol():
1717
assert dm.lasso_highlight_image(mol, smarts_list)
1818

1919

20+
def test_with_highlight():
21+
smi = "CO[C@@H](O)C1=C(O[C@H](F)Cl)C(C#N)=C1ONNC[NH3+]"
22+
mol = dm.to_mol(smi)
23+
smarts_list = "CONN"
24+
highlight_atoms = [4, 5, 6]
25+
highlight_bonds = [1, 2, 3, 4]
26+
highlight_atom_colors = {4: (230, 230, 250), 5: (230, 230, 250), 6: (230, 230, 250)}
27+
highlight_bond_colors = {
28+
1: (230, 230, 250),
29+
2: (230, 230, 250),
30+
3: (230, 230, 250),
31+
4: (230, 230, 250),
32+
}
33+
assert dm.lasso_highlight_image(
34+
mol,
35+
smarts_list,
36+
highlight_atoms=highlight_atoms,
37+
highlight_bonds=highlight_bonds,
38+
highlight_atom_colors=highlight_atom_colors,
39+
highlight_bond_colors=highlight_bond_colors,
40+
continuousHighlight=False,
41+
)
42+
43+
2044
def test_original_working_solution_list_single_str():
2145
smi = "CO[C@@H](O)C1=C(O[C@H](F)Cl)C(C#N)=C1ONNC[NH3+]"
2246
smarts_list = ["CONN"]

0 commit comments

Comments
 (0)