Sanity Checks for Saliency Maps of Molecules

はじめに

 これまでQSARの結果を可視化する方法を試してきました。

rkakamilan.hatenablog.com

rkakamilan.hatenablog.com

 これらはそれぞれ視覚的になんとなく上手くモデルを説明できてそうに見えます。しかし、使う側からすると"なんとなく"では困ります。しかも前回の記事でも触れたように、そもそも説明可能性の手法自体にも疑問が呈されているのが最近の状況です。このような状況を考えると、できれば定量的に説明可能性を評価したいものです。

 最近各種saliency methodが報告されている画像処理の分野でsaliency methodの正当性 (adequacy)を評価する方法が考案・報告されておりますので、今回はこの手法を化合物のQSARで試してみようと思います。

"Sanity Checks for Saliency Maps"について

 本論文では以下のrandomization testを提唱しています。

  • Model parameter randomization test

 学習したモデルのsaliency mapと、モデルの重みをランダムに初期化した場合のsaliency mapとを比較。両マップが似ているならば、それは学習したモデルの重みに依存しない (=似ていない方がsaliency methodとして妥当)。

  • Data randomization test

 学習したモデルのsaliency mapと、教師データの目的変数(ラベル)をランダムに入れ替えて学習した場合のsaliency mapとを比較。両マップが似ているならば、それは教師データに依存しない (=似ていない方がsaliency methodとして妥当)。

 "似ている"ことの評価には、saliency mapのSSIMやHOG、Spearman順位相関係数を用いています。

Sanity Check for QSAR Model

 今回はNN系ではないのでdata randomization testを適用しました。

 学習に用いたトレーニングデータについて可視化を行います。

train_path = os.path.join(RDPaths.RDDocsDir, 'Book/data/solubility.train.sdf')
test_path = os.path.join(RDPaths.RDDocsDir, 'Book/data/solubility.test.sdf')
train_mols = [m for m in Chem.SDMolSupplier(train_path) if m is not None]
test_mols = [m for m in Chem.SDMolSupplier(test_path) if m is not None]
print(len(mols))
# 1282
sol_classes = {'(A) low': 0, '(B) medium': 1, '(C) high': 2}
X_train = np.array([mol2fp(m)[0] for m in train_mols])
y_train = np.array([sol_classes[m.GetProp('SOL_classification')] for m in train_mols], dtype=np.int) 
X_test = np.array([mol2fp(m)[0] for m in test_mols])
y_test = np.array([sol_classes[m.GetProp('SOL_classification')] for m in test_mols], dtype=np.int)

 ラベルをランダムにシャッフルしたモデルも作成します。

rf_models = []
clf = RandomForestClassifier(random_state=20200119)
clf.fit(X_train, y_train)
print(f'Train: {accuracy_score(y_train, clf.predict(X_train)):.3f}, Test: {accuracy_score(y_test, clf.predict(X_test)):.3f}')
rf_models.append(clf)
for _ in range(10):
    y_perm = np.random.permutation(y_train)
    _clf = RandomForestClassifier(random_state=20200119)
    _clf.fit(X_train, y_perm)
    rf_models.append(_clf)
    print(f'Shuffuled Data: {accuracy_score(y_perm, _clf.predict(X_train)):.3f}, True Train: {accuracy_score(y_train, _clf.predict(X_train)):.3f}, Test: {accuracy_score(y_test, _clf.predict(X_test)):.3f}')
          
# Train: 0.992, Test: 0.685
# Shuffuled Data: 0.985, True Train: 0.352, Test: 0.370
# Shuffuled Data: 0.985, True Train: 0.371, Test: 0.401
# Shuffuled Data: 0.985, True Train: 0.363, Test: 0.350
# Shuffuled Data: 0.980, True Train: 0.368, Test: 0.358
# Shuffuled Data: 0.981, True Train: 0.354, Test: 0.389
# Shuffuled Data: 0.983, True Train: 0.365, Test: 0.389
# Shuffuled Data: 0.984, True Train: 0.355, Test: 0.393
# Shuffuled Data: 0.982, True Train: 0.363, Test: 0.444
# Shuffuled Data: 0.986, True Train: 0.378, Test: 0.475
# Shuffuled Data: 0.980, True Train: 0.359, Test: 0.401

 論文ではvisual inspection / assessment はあまりよろしくないとの述べられておりましたが、一応ランダムにシャッフルした結果を可視化してみましょう。rdkitのSimilarityMaps.GetAtomicWeightsForModelで描画しました。

f:id:rkakamilan:20200127203949p:plain
visual inspection

 正しいデータで学習した場合とシャッフルした場合では、全く違う重みが得られており、saliency methodとして妥当そうに見えます。

 次に構造式での描画だけでなく、グラフでも示してみます。正しいデータで学習した時の重みを折れ線グラフ (緑)と右の補助図で示し、シャッフルした時の重みを箱ひげ図で示しました。

f:id:rkakamilan:20200127204136p:plain
variation sanity check

 正しいデータで学習したモデルの重みとランダムにシャッフルした場合の重みを比較すると、シャッフルした場合の重みが元と全く違う値を示していることが分かります。saliency methodとしては妥当であると言えそうです。

 では、シャッフルの前後でsaliencyの結果がどのように変わるか、類似性で評価してみます。原子ごとの重みがarrayで得られていますので、元の重みとシャッフルした場合の重みの相関係数 (Pearson, Spearman)を見てみます。

f:id:rkakamilan:20200128000820p:plain
correlation sanity check

 クラス数が少ないためか、ランダムな試行の中で相関が高く出ている時もありますが、10回の試行全体で見ると相関係数は低く、このチェックでも妥当そうな結果です。

 では、データセット全体で重みを計算してみます。RandomForestでのSimilarityMapsに加えて、LightGBMとSHAPも計算してみました。

f:id:rkakamilan:20200128083138p:plain
correlation RF vs LGB

 RandomForest / LightGBM、SimilarityMaps / SHAPいずれも0付近の相関係数を示しており、saliency methodとしては妥当と言えそうです。

 RandomForestのトレーニングデータとテストデータでも比較してみました。

f:id:rkakamilan:20200128083402p:plain
correlation train vs test

 トレーニングとテストでAccuracyに差があったので、それが何らか相関係数の差にも反映されないかと思って見てみましたが、ほとんど違いは認められないですね。

おわりに

 画像データで提唱されているsanity checkを化合物のQSARにも適用してみました。rdkitのSimilarityMapsもSHAPも、randomization testの結果から妥当なsaliency methodであると言えそうです。一方、定量的にどの程度のばらつきや閾値であれば妥当と言えるのかの感覚は今回の実験だけでは掴めなかったので別のデータセットアルゴリズムでも試してみます。

[2020/01/29 07:53 typo修正]

[2020/02/01 0:14 原先生のスライドシェアのリンク追加]