化合物の可視化を機械学習に適用してみた

はじめに

 創薬 (dry) Advent Calendar 2019 第15日目の記事です。

 前回は化合物を構成する原子ごとの重みを元に、原子ハイライト+ヒートマップ/棒グラフの補助図で可視化する方法を紹介しました。本記事ではその手法を機械学習に適用してみようと思います。

 前回の記事では以下の形式で算出した weights を用いれば同様に可視化できるように定義しました。

weights = weight_fn(mol)

 本記事ではこの形式に合わせて、以下の3つの手法で重みを計算します。

  • rdkit.Chem.Draw.SimilarityMaps.GetAtomicWeightsForModel
  • Feature Importance
  • SHAP values

[各手法の概説]

GetAtomicWeightsForModel

 RDKitのCookbook (Using scikit-learn with RDKit)でも紹介されている GetSimilarityMapForModelの中で読まれているメソッドです。以下の通り、Fingerprint全体 baseFPを用いて算出した予測値baseProbaと、原子ごとで作成したnewFPを用いて算出したnewProbaの差を以って各原子の重みとしています。

# https://github.com/rdkit/rdkit/blob/master/rdkit/Chem/Draw/SimilarityMaps.py より
def GetAtomicWeightsForModel(probeMol, fpFunction, predictionFunction):
    """
    Calculates the atomic weights for the probe molecule based on
    a fingerprint function and the prediction function of a ML model.
    Parameters:
      probeMol -- the probe molecule
      fpFunction -- the fingerprint function
      predictionFunction -- the prediction function of the ML model
    """
    if hasattr(probeMol, '_fpInfo'):
        delattr(probeMol, '_fpInfo')
    probeFP = fpFunction(probeMol, -1)
    baseProba = predictionFunction(probeFP)
    # loop over atoms
    weights = []
    for atomId in range(probeMol.GetNumAtoms()):
        newFP = fpFunction(probeMol, atomId)
        newProba = predictionFunction(newFP)
        weights.append(baseProba - newProba)
    if hasattr(probeMol, '_fpInfo'):
        delattr(probeMol, '_fpInfo')
    return weights

Feature Importance

 決定木系アルゴリズムにおいて、各特徴量が判別/回帰にどれぐらい寄与したかを示す指標です。改めて解説するよりはググれば解説記事が見つかるのでそちらを挙げるに留めます。

SHAP(SHapley Additive exPlanations) values

 2017年末に出てきた手法ですが、同時に使いやすいライブラリーが出てきたこともあって急速に利活用が進んでいる手法です。ケムインフォマティクスの分野でも早速活用・解析した論文が出ています。よく「ゲーム理論に基づいて...」という説明がなされますが、考え方としてはシンプルに、ある特徴量がない場合の予測値の差分をその特徴量の重要度として計算しているものです。こちらも詳細は先達の記事へ。

溶解度データセットを用いた機械学習に適用

 まずはRDKitの溶解度データからランダムフォレストでモデルを作ります。

# ref. https://iwatobipen.wordpress.com/2018/11/07/visualize-important-features-of-machine-leaning-rdkit/

sol_classes = {'(A) low': 0, '(B) medium': 1, '(C) high': 2}
X_train = np.array([mol2fp(m)[0] for m in train_mols])
y_train = np.array([sol_classes[m.GetProp('SOL_classification')] for m in train_mols], dtype=np.int)
X_test = np.array([mol2fp(m)[0] for m in test_mols])
y_test = np.array([sol_classes[m.GetProp('SOL_classification')] for m in test_mols], dtype=np.int)
# len(X_train) => 1025, len(X_test) -> 257

clf = RandomForestClassifier(random_state=123)
clf.fit(X_train, y_train)
print(accuracy_score(y_test, clf.predict(X_test)))
#  0.7120622568093385

 まずまずのモデルが出来ました。それではGetAtomicWeightsForModelから試してみましょう。最初に幾つか関数を定義します。

# 目的のクラスのprobabilityを返す関数
def get_proba(fp, proba_fn, class_id):
    return proba_fn((fp,))[0][class_id]

# デフォルトでGetMorganFingerprintの引数が2048ビットになっているため
def fp_partial(nBits):
    return functools.partial(SimilarityMaps.GetMorganFingerprint, nBits=nBits)

# モデルの予測と正解を出力
def show_pred_results(mol, model):
    y_pred = model.predict(mol2fp(mol)[0].reshape((1,-1)))
    sol_dict = {val: key for key, val in sol_classes.items()}
    print(f"True: {mol.GetProp('SOL_classification')} vs Predicted: {sol_dict[y_pred[0]]}")

 これらを用いて、以下のように使います。(前回定義したplot_explainable_images を少し変えて、weightsを指定できるようにしています。)

mol = test_mols[141]
show_pred_results(mol, clf)
# Class 2 (High)に対する値を用いるためget_probaの class_id=2を指定
weights = SimilarityMaps.GetAtomicWeightsForModel(mol, fp_partial(1024), lambda x: get_proba(x, clf.predict_proba, 2))
plot_explainable_images(mol, weight_fn=None, weights=weights, atoms=['C', 'N', 'O', 'S', 'F', 'Cl', 'P', 'Br'])

f:id:rkakamilan:20191216022503p:plain
test_mols[141]

 クラスの予測は外れてますが、全体としてヘテロ原子が赤に、sp2炭素が青になっており、何となくイメージに合っていそうです。

f:id:rkakamilan:20191216022543p:plain
test_mols[79]

 こちらはカルボン酸とメチレン鎖で色分けできていてよりイメージ通りの結果になりました。

 続いて、Feature Importanceです。各bitのFeature Importanceはclf.feature_importances_で取得できます。一方、以下で取得できる分子のbitinfoには、各ビットにおける(原子ID, radius)がタプルで入っています。

# https://iwatobipen.wordpress.com/2018/11/07/visualize-important-features-of-machine-leaning-rdkit/

def mol2fp(mol,radius=2, nBits=1024):
    bitInfo={}
    fp = AllChem.GetMorganFingerprintAsBitVect(mol, radius=radius, nBits=nBits, bitInfo=bitInfo)
    arr = np.zeros((1,))
    DataStructs.ConvertToNumpyArray(fp, arr)
    return arr, bitInfo

fp, bitinfo = mol2fp(test_mols[141])
>>>bitinfo
{0: ((4, 2),),
 33: ((0, 0), (5, 0), (10, 0), (6, 2)),
 121: ((0, 1), (5, 1), (10, 1)),
 179: ((7, 2),),
 234: ((1, 2),),
 283: ((12, 2),),
 314: ((3, 1), (13, 1)),
 330: ((11, 2),),
 356: ((2, 0), (6, 0), (11, 0), (12, 0)),
 378: ((7, 0),),
 385: ((9, 1),),
 400: ((2, 2),),
 416: ((11, 1),),
 428: ((7, 1),),
 463: ((9, 2),),
 493: ((8, 2),),
 504: ((12, 1),),
 564: ((4, 1), (1, 1)),
 650: ((3, 0), (13, 0)),
 672: ((6, 1),),
 771: ((2, 1),),
 849: ((8, 0),),
 932: ((8, 1),),
 935: ((1, 0), (4, 0), (9, 0))}

 各bitのimportance値を、そのbitの中心の原子の値として各原子の重みを計算しました。上記の例のようにbit collisionがあり、その扱いには考慮が必要ですが、今回は単純に各bitに入っている原子数で平均しました。

def weights_from_feat_imp(mol, feature_importance, collided_bits='mean'):
    fp, bitinfo = mol2fp(mol)
    weights = np.zeros(mol.GetNumAtoms(), ) 
    for bit, infos in bitinfo.items():
        for atom_infos in infos:            
            if collided_bits == 'mean':
                # 同じbitに入っている原子数で平均
                weights[atom_infos[0]] += feature_importance[bit]/ len(infos)
            else:
                weights[atom_infos[0]] += feature_importance[bit]
    return weights

 この方法で計算した重みを可視化したのが下図です。コハク酸のカルボン酸が強くハイライトされているのはイメージに合いますが、全体的に分かりづらいですね。Feature Importanceが非負値であるため、色の濃淡のみでハイライトするこの手法とは相性が悪いようです。Feature Importanceについては、特に重要なbitに絞って(※)可視化した方が視認性が上がるかも知れません。 (※ Visualize important features of machine leaning #RDKit )

f:id:rkakamilan:20191216023100p:plain
test_mols[141]_FI

f:id:rkakamilan:20191216023203p:plain
test_mols[79]_FI

 最後にSHAPです。公式のレポジトリーを参考にして、以下のようにSHAP値を算出します。後は同様に可視化します。

explainer = shap.TreeExplainer(clf, data=X_train, feature_dependence="independent")
shap_values = explainer.shap_values(fp.reshape((1,-1)))[class_id]

 

f:id:rkakamilan:20191216023435p:plain
test_mols[141]_SHAP

f:id:rkakamilan:20191216023514p:plain
test_mols[79]_SHAP

 SHAPでは正負両方の値が出てくるため、GetAtomicWeightsForModelと同様に感覚的に捉えやすい結果になったと思います。

 前回紹介した可視化手法を機械学習と組み合わせ、GetAtomicWeightsForModel / Feature Importance / SHAPという3つの重み計算方法を試してみました。bitinfoを活用すれば原子ごとの重みの計算はできましたが、bit collisionやradiusの考慮等、まだ検討すべき点はありそうです。より大きいサイズのデータセットやモデルチューニングによるパフォーマンス向上の効果も見ておきたいところです。今回は時間が足りず可視化するところまでで終わりましたが、改めて時間を作って追加実験したいと思います。

(2019/12/17 07:30, コードをnbviewerに変更)

(2020/1/13, nbviewer再アップロード)