import logging
import numpy as np
from rdkit import Chem
from rdkit import RDLogger
from rdkit.Chem.Scaffolds import MurckoScaffold
from federatedscope.core.splitters import BaseSplitter
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)
# Disable RDKit's own log output for every channel matching 'rdApp.*'.
# This runs at import time, so it applies as soon as this module is imported.
RDLogger.DisableLog('rdApp.*')
def generate_scaffold(smiles, include_chirality=False):
    """Build the Murcko scaffold SMILES for a molecule given as SMILES.

    Arguments:
        smiles (str): SMILES string of the target molecule.
        include_chirality (bool): Forwarded to RDKit as ``includeChirality``.

    Returns:
        Whatever ``MurckoScaffold.MurckoScaffoldSmiles`` returns for the
        parsed molecule (the scaffold string).
    """
    # Parse first, then hand the molecule object (not the raw string) to RDKit.
    parsed = Chem.MolFromSmiles(smiles)
    return MurckoScaffold.MurckoScaffoldSmiles(
        mol=parsed, includeChirality=include_chirality)
def gen_scaffold_split(dataset, client_num=5):
    r"""Partition dataset indices into ``client_num`` chunks by scaffold.

    Every sample's ``smiles`` attribute is mapped to a scaffold string with
    ``generate_scaffold`` and samples sharing a scaffold are grouped. The
    groups are ordered from largest to smallest (equal-sized groups are
    ordered by their first index, larger first), flattened into one index
    sequence, and that sequence is cut into ``client_num`` contiguous pieces
    with ``np.array_split``. Because the cut is positional, a scaffold group
    can straddle the boundary between two pieces.

    Arguments:
        dataset (iterable): Samples, each exposing a ``smiles`` attribute.
        client_num (int): Number of pieces to produce.

    Returns:
        list: ``client_num`` NumPy arrays of dataset indices; entry ``ID``
        holds the indices assigned to client ``ID``.
    """
    logger.info('Scaffold split might take minutes, please wait...')

    # scaffold string -> indices of the samples that share it. ``enumerate``
    # yields ascending indices, so every list is already sorted and
    # ``idxs[0]`` is the smallest index of its group.
    groups = {}
    for idx, data in enumerate(dataset):
        groups.setdefault(generate_scaffold(data.smiles), []).append(idx)

    # Sort from largest to smallest scaffold sets. The first index is unique
    # per group, so the ordering is fully deterministic.
    ordered = sorted(groups.values(),
                     key=lambda idxs: (len(idxs), idxs[0]),
                     reverse=True)

    # Linear-time flatten (``sum(list_of_lists, [])`` is quadratic).
    scaffold_idxs = [idx for idxs in ordered for idx in idxs]

    # ``np.array_split`` already returns one array per requested piece.
    return list(np.array_split(scaffold_idxs, client_num))
class ScaffoldSplitter(BaseSplitter):
    """
    Split molecules via scaffold. This splitter orders all molecules by \
    scaffold group and splits them into several parts.

    Arguments:
        client_num (int): Split data into client_num of pieces.
    """
    def __init__(self, client_num):
        super(ScaffoldSplitter, self).__init__(client_num)
        # Kept on the instance explicitly so ``__call__`` does not rely on
        # how the base class stores it.
        # NOTE(review): BaseSplitter presumably sets the same attribute;
        # confirm and drop this line if it is redundant.
        self.client_num = client_num

    def __call__(self, dataset, **kwargs):
        """Split ``dataset`` into ``client_num`` lists of samples.

        Arguments:
            dataset (iterable): Samples, each exposing a ``smiles``
                attribute (required by ``gen_scaffold_split``).
            **kwargs: Accepted and ignored.

        Returns:
            list: One list of samples per client, in the order produced by
            ``gen_scaffold_split``.
        """
        # Materialize once so samples can be looked up by index below.
        dataset = list(dataset)
        # Pass the configured client count; without it the helper's default
        # of 5 would be used regardless of the constructor argument.
        idx_slice = gen_scaffold_split(dataset, client_num=self.client_num)
        return [[dataset[idx] for idx in idxs] for idxs in idx_slice]