Skip to content

Commit

Permalink
Merge pull request #77 from thibaultvarin-r/molecular_transformer_dataframe_input
Browse files Browse the repository at this point in the history

Molecular transformer dataframe input
  • Loading branch information
maclandrol authored Sep 5, 2023
2 parents 345c19d + f0641d5 commit 4bb511b
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 1 deletion.
7 changes: 6 additions & 1 deletion molfeat/trans/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import datamol as dm
import numpy as np

from sklearn import utils
from sklearn.base import TransformerMixin
from sklearn.base import BaseEstimator
from loguru import logger
Expand Down Expand Up @@ -295,6 +296,8 @@ def transform(
features: a list of features for each molecule in the input set
"""
# Convert single mol to iterable format
if isinstance(mols, pd.DataFrame):
mols = mols[mols.columns[0]]
if isinstance(mols, (str, dm.Mol)) or not isinstance(mols, Iterable):
mols = [mols]

Expand Down Expand Up @@ -326,7 +329,9 @@ def _to_mol(x):
f"Cannot transform molecule at index {ind}. Please check logs (set verbose to True) to see errors!"
)

return features
# sklearn feature validation
return utils.check_array(features)


def __len__(self):
"""Compute featurizer length"""
Expand Down
68 changes: 68 additions & 0 deletions tests/test_molecule_transformer_dataframe.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
import pytest

import numpy as np
import pandas as pd

from sklearn.pipeline import Pipeline
from sklearn.naive_bayes import GaussianNB
from sklearn.compose import ColumnTransformer
from molfeat.trans.base import MoleculeTransformer


@pytest.fixture
def smiles():
    """Provide the four SMILES strings used as molecule inputs by these tests.

    Returns:
        list[str]: the SMILES strings, always in the same order.
    """
    return [
        'CC1CC2C3CCC4=CC(=O)C=CC4(C3(C(CC2(C(=O)CO1)O)C)O)C',
        'CN(CCOC(c1ccccc1)c1ccccc1)C',
        # Raw literal: this SMILES contains a backslash followed by "C", which
        # is an invalid escape sequence in a normal string literal and makes
        # recent Python versions emit a warning. The string value is unchanged.
        r'O/N=C(/c1csc(n1)N)\C(=O)N[C@@H]1C(=O)N2[C@@H]1SCC(=C2C(=O)O)C=C',
        'CC(C)(C)NCC(C1=CC(=C(C=C1)O)CO)O',
    ]


@pytest.fixture(params=['list', 'series', 'dataframe'])
def mols(request, smiles):
    """Wrap the SMILES fixture in the container named by the current param.

    Args:
        request: pytest fixture request; ``request.param`` selects the
            container kind ('list', 'series' or 'dataframe').
        smiles: the list of SMILES strings supplied by the ``smiles`` fixture.

    Returns:
        The plain list, a ``pd.Series`` named 'smiles', or a two-column
        ``pd.DataFrame`` ('smiles' plus an integer 'column_2'). ``None`` if
        the param matches none of the known kinds.
    """
    # Lazy builders so that only the requested container is constructed.
    builders = {
        'list': lambda: smiles,
        'series': lambda: pd.Series(smiles, name='smiles'),
        'dataframe': lambda: pd.DataFrame(
            {'smiles': smiles, 'column_2': [1, 0, 1, 1]}
        ),
    }
    build = builders.get(request.param)
    if build is None:
        return None
    return build()


def test_list_series_dataframe(mols):
    """Check that fit_transform accepts each container built by ``mols``.

    The 'ecfp' featurizer output is expected to be a numpy array of shape
    (4, 2048): one row per input SMILES string.
    """
    features = MoleculeTransformer(featurizer='ecfp').fit_transform(mols)

    expected_shape = (4, 2048)
    assert features.shape == expected_shape
    assert isinstance(features, np.ndarray)

def test_with_pipeline_column_transformer(smiles):
    """Run MoleculeTransformer inside a ColumnTransformer + Pipeline.

    The 'smiles' column is routed through the 'ecfp' MoleculeTransformer and
    'column_2' is passed through unchanged; the combined output feeds a
    GaussianNB classifier. The test only checks that predictions come back
    as a numpy array with one entry per input row.
    """
    # Input data: one molecule column plus one plain numeric column.
    frame = pd.DataFrame({'smiles': smiles, 'column_2': [1, 0, 1, 1]})
    labels = [1, 0, 1, 0]

    # Per-column preprocessing followed by the classifier.
    preprocess_steps = [
        ('ecfp_trans', MoleculeTransformer(featurizer='ecfp'), ['smiles']),
        ('column_2', 'passthrough', ['column_2']),
    ]
    model = Pipeline([
        ('preprocess', ColumnTransformer(transformers=preprocess_steps)),
        ('classifier', GaussianNB()),
    ])

    # Fit and predict on the same frame.
    model.fit(frame, labels)
    predictions = model.predict(frame)

    assert isinstance(predictions, np.ndarray)
    assert predictions.shape == (4,)

0 comments on commit 4bb511b

Please sign in to comment.