化合物の可視化を機械学習に適用してみた
はじめに
創薬 (dry) Advent Calendar 2019 第15日目の記事です。
前回は化合物を構成する原子ごとの重みを元に、原子ハイライト+ヒートマップ/棒グラフの補助図で可視化する方法を紹介しました。本記事ではその手法を機械学習に適用してみようと思います。
前回の記事では以下の形式で算出した weights
を用いれば同様に可視化できるように定義しました。
weights = weight_fn(mol)
本記事ではこの形式に合わせて、以下の3つの手法で重みを計算します。
rdkit.Chem.Draw.SimilarityMaps.GetAtomicWeightsForModel
- Feature Importance
- SHAP values
[各手法の概説]
GetAtomicWeightsForModel
RDKitのCookbook (Using scikit-learn with RDKit)でも紹介されている GetSimilarityMapForModel
の中で読まれているメソッドです。以下の通り、Fingerprint全体 baseFP
を用いて算出した予測値baseProba
と、原子ごとで作成したnewFP
を用いて算出したnewProba
の差を以って各原子の重みとしています。
# https://github.com/rdkit/rdkit/blob/master/rdkit/Chem/Draw/SimilarityMaps.py より def GetAtomicWeightsForModel(probeMol, fpFunction, predictionFunction): """ Calculates the atomic weights for the probe molecule based on a fingerprint function and the prediction function of a ML model. Parameters: probeMol -- the probe molecule fpFunction -- the fingerprint function predictionFunction -- the prediction function of the ML model """ if hasattr(probeMol, '_fpInfo'): delattr(probeMol, '_fpInfo') probeFP = fpFunction(probeMol, -1) baseProba = predictionFunction(probeFP) # loop over atoms weights = [] for atomId in range(probeMol.GetNumAtoms()): newFP = fpFunction(probeMol, atomId) newProba = predictionFunction(newFP) weights.append(baseProba - newProba) if hasattr(probeMol, '_fpInfo'): delattr(probeMol, '_fpInfo') return weights
Feature Importance
決定木系アルゴリズムにおいて、各特徴量が判別/回帰にどれぐらい寄与したかを示す指標です。改めて解説するよりはググれば解説記事が見つかるのでそちらを挙げるに留めます。
- <Scikit-learn> ランダムフォレスト回帰のfeature_importances_の定義
- Xgboost : feature_importanceのimportance_type算出方法
- 特徴量重要度にバイアスが生じる状況ご存知ですか?
SHAP(SHapley Additive exPlanations) values
2017年末に出てきた手法ですが、同時に使いやすいライブラリーが出てきたこともあって急速に利活用が進んでいる手法です。ケムインフォマティクスの分野でも早速活用・解析した論文が出ています。よく「ゲーム理論に基づいて...」という説明がなされますが、考え方としてはシンプルに、ある特徴量がない場合の予測値の差分をその特徴量の重要度として計算しているものです。こちらも詳細は先達の記事へ。
- A Unified Approach to Interpreting Model Predictions (元論文)
- 機械学習モデルを解釈する指標SHAPについて (定式の抜粋と利用例も書いてあった)
- slundberg/shap (公式レポ、この通りに動かしたら感覚分かる)
- Interpretation of Compound Activity Predictions from Complex Machine Learning Models Using Local Approximations and Shapley Values (Fingerprint使って得られたSHAPの解析を色々実施)
- ADMET Evaluation in Drug Discovery. 19. Reliable Prediction of Human Cytochrome P450 Inhibition Using Artificial Intelligence Approaches (記述子でSHAPを見ている)
- 機械学習モデルの判断根拠の説明 (SHAPも含めた他の手法が非常に多く取り上げられている)
溶解度データセットを用いた機械学習に適用
まずはRDKitの溶解度データからランダムフォレストでモデルを作ります。
# ref. https://iwatobipen.wordpress.com/2018/11/07/visualize-important-features-of-machine-leaning-rdkit/ sol_classes = {'(A) low': 0, '(B) medium': 1, '(C) high': 2} X_train = np.array([mol2fp(m)[0] for m in train_mols]) y_train = np.array([sol_classes[m.GetProp('SOL_classification')] for m in train_mols], dtype=np.int) X_test = np.array([mol2fp(m)[0] for m in test_mols]) y_test = np.array([sol_classes[m.GetProp('SOL_classification')] for m in test_mols], dtype=np.int) # len(X_train) => 1025, len(X_test) -> 257 clf = RandomForestClassifier(random_state=123) clf.fit(X_train, y_train) print(accuracy_score(y_test, clf.predict(X_test))) # 0.7120622568093385
まずまずのモデルが出来ました。それではGetAtomicWeightsForModelから試してみましょう。最初に幾つか関数を定義します。
# 目的のクラスのprobabilityを返す関数 def get_proba(fp, proba_fn, class_id): return proba_fn((fp,))[0][class_id] # デフォルトでGetMorganFingerprintの引数が2048ビットになっているため def fp_partial(nBits): return functools.partial(SimilarityMaps.GetMorganFingerprint, nBits=nBits) # モデルの予測と正解を出力 def show_pred_results(mol, model): y_pred = model.predict(mol2fp(mol)[0].reshape((1,-1))) sol_dict = {val: key for key, val in sol_classes.items()} print(f"True: {mol.GetProp('SOL_classification')} vs Predicted: {sol_dict[y_pred[0]]}")
これらを用いて、以下のように使います。(前回定義したplot_explainable_images
を少し変えて、weightsを指定できるようにしています。)
mol = test_mols[141] show_pred_results(mol, clf) # Class 2 (High)に対する値を用いるためget_probaの class_id=2を指定 weights = SimilarityMaps.GetAtomicWeightsForModel(mol, fp_partial(1024), lambda x: get_proba(x, clf.predict_proba, 2)) plot_explainable_images(mol, weight_fn=None, weights=weights, atoms=['C', 'N', 'O', 'S', 'F', 'Cl', 'P', 'Br'])
クラスの予測は外れてますが、全体としてヘテロ原子が赤に、sp2炭素が青になっており、何となくイメージに合っていそうです。
こちらはカルボン酸とメチレン鎖で色分けできていてよりイメージ通りの結果になりました。
続いて、Feature Importanceです。各bitのFeature Importanceはclf.feature_importances_
で取得できます。一方、以下で取得できる分子のbitinfoには、各ビットにおける(原子ID, radius)がタプルで入っています。
# https://iwatobipen.wordpress.com/2018/11/07/visualize-important-features-of-machine-leaning-rdkit/ def mol2fp(mol,radius=2, nBits=1024): bitInfo={} fp = AllChem.GetMorganFingerprintAsBitVect(mol, radius=radius, nBits=nBits, bitInfo=bitInfo) arr = np.zeros((1,)) DataStructs.ConvertToNumpyArray(fp, arr) return arr, bitInfo fp, bitinfo = mol2fp(test_mols[141])
>>>bitinfo {0: ((4, 2),), 33: ((0, 0), (5, 0), (10, 0), (6, 2)), 121: ((0, 1), (5, 1), (10, 1)), 179: ((7, 2),), 234: ((1, 2),), 283: ((12, 2),), 314: ((3, 1), (13, 1)), 330: ((11, 2),), 356: ((2, 0), (6, 0), (11, 0), (12, 0)), 378: ((7, 0),), 385: ((9, 1),), 400: ((2, 2),), 416: ((11, 1),), 428: ((7, 1),), 463: ((9, 2),), 493: ((8, 2),), 504: ((12, 1),), 564: ((4, 1), (1, 1)), 650: ((3, 0), (13, 0)), 672: ((6, 1),), 771: ((2, 1),), 849: ((8, 0),), 932: ((8, 1),), 935: ((1, 0), (4, 0), (9, 0))}
各bitのimportance値を、そのbitの中心の原子の値として各原子の重みを計算しました。上記の例のようにbit collisionがあり、その扱いには考慮が必要ですが、今回は単純に各bitに入っている原子数で平均しました。
def weights_from_feat_imp(mol, feature_importance, collided_bits='mean'): fp, bitinfo = mol2fp(mol) weights = np.zeros(mol.GetNumAtoms(), ) for bit, infos in bitinfo.items(): for atom_infos in infos: if collided_bits == 'mean': # 同じbitに入っている原子数で平均 weights[atom_infos[0]] += feature_importance[bit]/ len(infos) else: weights[atom_infos[0]] += feature_importance[bit] return weights
この方法で計算した重みを可視化したのが下図です。コハク酸のカルボン酸が強くハイライトされているのはイメージに合いますが、全体的に分かりづらいですね。Feature Importanceが非負値であるため、色の濃淡のみでハイライトするこの手法とは相性が悪いようです。Feature Importanceについては、特に重要なbitに絞って(※)可視化した方が視認性が上がるかも知れません。 (※ Visualize important features of machine leaning #RDKit )
最後にSHAPです。公式のレポジトリーを参考にして、以下のようにSHAP値を算出します。後は同様に可視化します。
explainer = shap.TreeExplainer(clf, data=X_train, feature_dependence="independent") shap_values = explainer.shap_values(fp.reshape((1,-1)))[class_id]
SHAPでは正負両方の値が出てくるため、GetAtomicWeightsForModelと同様に感覚的に捉えやすい結果になったと思います。
前回紹介した可視化手法を機械学習と組み合わせ、GetAtomicWeightsForModel / Feature Importance / SHAPという3つの重み計算方法を試してみました。bitinfoを活用すれば原子ごとの重みの計算はできましたが、bit collisionやradiusの考慮等、まだ検討すべき点はありそうです。より大きいサイズのデータセットやモデルチューニングによるパフォーマンス向上の効果も見ておきたいところです。今回は時間が足りず可視化するところまでで終わりましたが、改めて時間を作って追加実験したいと思います。
(2019/12/17 07:30, コードをnbviewerに変更)
(2020/1/13, nbviewer再アップロード)