Visualization with chainer-saliency
はじめに
以下の2回で補助図を用いた化合物の可視化を試してみました。今回は補助図描画をchainer-saliency の算出結果に適用してみます。
chainer-saliencyについて
以下の論文(1)をchainerで実装したもので、記事(2)にて紹介されています。記事(2)が公開された時点では、chainer-saliencyのレポジトリにありましたが、現在はchainer-chemistry内に入っています。
Saliency計算と可視化
chainer-saliencyにあるexampleのjupyter notebookを参考にし、一部書き方を修正しながら進めます。
notebookでは GraphConvPredictor
から定義してますが、chainer-chemistryにはモデルを簡単に定義出来るchainer_chemistry.models.prediction.set_up_predictor
が入ってますのでこれを使います。データセットはexampleと同じdelaneyの溶解度データセットです。
# パラメータはexampleのそのまま n_unit = 32 conv_layers = 4 class_num = 1 predictor = set_up_predictor('ggnn', n_unit=n_unit, conv_layers=conv_layers, class_num=class_num, postprocess_fn=activation_relu_dropout) regressor = Regressor(predictor, device=device) # fitはexampleのまま fit(regressor, train, batchsize=16, epoch=40, device=device)
学習したモデルを用いてsaliencyを計算します。chainer-chemistry にはVanillaGrad, SmoothGrad, BayesGradの3種類 (文献1,2参照)が実装され、example notebookで紹介されています。
calculator = IntegratedGradientsCalculator( predictor, steps=5, eval_fun=eval_fun, target_extractor=VariableMonitorLinkHook(predictor.graph_conv.embed, timing='post'), device=device) M = 5 # サンプリング数 saliency_samples_vanilla = calculator.compute( train, M=1, converter=concat_mols) saliency_samples_smooth = calculator.compute( train, M=M, converter=concat_mols, noise_sampler=GaussianNoiseSampler()) saliency_samples_bayes = calculator.compute( train, M=M, converter=concat_mols, train=True)
得られたsaliencyはBaceCalculator
クラスのaggregate
で集計します。
method = 'raw' saliency_agg_vanilla = calculator.aggregate( saliency_samples_vanilla, ch_axis=3, method=method) print(f'Before aggregation: {saliency_samples_vanilla.shape} -> After: {saliency_agg_vanilla.shape}') # Before aggregation: (1, 902, 55, 32) -> After: (902, 55) saliency_agg_smooth = calculator.aggregate( saliency_samples_smooth, ch_axis=3, method=method) print(f'Before aggregation: {saliency_samples_smooth.shape} -> After: {saliency_agg_smooth.shape}') # Before aggregation: (5, 902, 55, 32) -> After: (902, 55) saliency_agg_bayes = calculator.aggregate( saliency_samples_bayes, ch_axis=3, method=method) print(f'Before aggregation: {saliency_samples_bayes.shape} -> After: {saliency_agg_bayes.shape}') # Before aggregation: (5, 902, 55, 32) -> After: (902, 55)
これらsaliency_agg_xxx
には学習に用いた902化合物ごとに原子の重みが入っており、各化合物のインデックスにアクセスすることで重み/saliencyが取り出せます。このsaliencyを用いてプロットする関数は、これまで用いていた関数を少し変更してます。
重みのscalingにrdkit.Chem.Draw.SimilarityMaps.GetStandardizedWeights
を用いていましたが、chainer-chemistryにも幾つかscalingのメソッドが入っているためこれを指定できる形式にしました。
def plot_mol_saliency(mol, saliency, scaler=None, atoms=['C', 'N', 'O', 'S', 'F', 'Cl', 'P', 'Br']): n_atom = mol.GetNumAtoms() symbols = [f'{mol.GetAtomWithIdx(i).GetSymbol()}_{i}' for i in range(mol.GetNumAtoms())] df = pd.DataFrame(columns=atoms) saliency = saliency[:n_atom] num_atoms = mol.GetNumAtoms() if scaler is not None: vmax = np.max(np.abs(saliency)) weights = scaler(saliency) else: weights, vmax = SimilarityMaps.GetStandardizedWeights(saliency) arr = np.zeros((num_atoms, len(atoms))) for i in range(mol.GetNumAtoms()): _a = mol.GetAtomWithIdx(i).GetSymbol() arr[i,atoms.index(_a)] = weights[i] df = pd.DataFrame(arr, index=symbols, columns=atoms) # 以下同じ
試しに、ID699の化合物を図示してみます。
i = 699 mol = Chem.MolFromSmiles(smiles[i]) plot_mol_saliency(mol, saliency=saliency_agg_bayes[i], scaler=abs_max_scaler, atoms=['C', 'N', 'O', 'S', 'F', 'Cl', 'P', 'Br'])
元々化合物の描画の際のカラーマップ等はchainer-chemistryから拝借していたのでexampleとほぼ同じ描画が出来ました。(今回の記事では補助図は除いてます。)
スケーラーによる描画の違い
前回の記事までは化合物の描画の際、化合物ごとに原子の重み/saliencyはrdkitのGetStandardizedWeights
を用いて、絶対値最大の値でスケーリングしていました。chainer-chemistryでもabs_max_scaler
で同じスケーリングを施しています。一方、chainer-chemistryにはabs_max_scaler
以外にも、最小値と最大値を用いて0-1にスケーリングするmin_max_scaler
が実装されています。(他にもnormalize_scaler
がありますが、負の値を含まないsaliencyのためのメソッドなので今回は割愛します。)
min_max_scaler
を使って先のID699を可視化してみます。
i = 699 mol = Chem.MolFromSmiles(smiles[i]) plot_mol_saliency(mol, saliency=saliency_agg_bayes[i], scaler=min_max_scaler, atoms=['C', 'N', 'O', 'S', 'F', 'Cl', 'P', 'Br'])
溶解度にプラスに寄与しているとされる原子が濃く描画されているものの、色の濃淡のみで分かりづらくなってしまいました。原子ごとの色はchainer_chemistry.saliency.visualizer.visualizer_utils.red_blue_cmap
で決定していましたが、これは元々-1~1にノーマライズされた重み/saliencyに対応するために定義されたものであるため、min_max_scaler
との相性は良くないようです。
では、abs_max_scaler
が最適なのか。これを考えるために新しくscalerを考えます。与えた重み/saliencyの絶対値最大ではなく、指定した値でスケーリングするfixed_val_scaler
を定義しました。
def _fixed_val_scaler(saliency, max_val, logger=None): try: xp = cuda.get_array_module(saliency) except: xp = saliency maxv = max_val if maxv <= 0: logger = logger or getLogger(__name__) logger.info('All saliency value is 0') return xp.zeros_like(saliency) else: return saliency / maxv def fixed_val_scaler(max_val): return functools.partial(_fixed_val_scaler, max_val=max_val)
np.unravel_index(np.argmax(np.abs(saliency_agg_bayes)), saliency_agg_bayes.shape), np.max(np.abs(saliency_agg_bayes)) # ((455, 10), 1.1827562)
ID455の化合物がデータセット中で最大の重み (1.1827562)を有することがわかりました。この値をfixed_val_scaler
のmax_valとして指定します。ID455の場合は、当然ですがabs_max_scaler
とfixed_val_scaler
で違いはありません。
次に重み/saliencyの絶対値最大の値が小さい化合物で描画してみます。下記の通り、ID762が2番目に小さい重み (0.114224076)を有するようです。(一番小さい化合物はメタンだったので例には不適でした)
cpid_2nd_min = np.argsort(np.max(np.abs(saliency_agg_bayes), axis=1))[1] cpid_2nd_min, np.max(np.abs(saliency_agg_bayes[cpid_2nd_min])) # (762, 0.114224076)
fixed_val_scaler
を用いた方が色が薄くなっています。
このように、abs_max_scaler
を用いた場合は、一つの化合物"内"の原子の寄与度を比較する場合には良いですが、化合物"間"で比較する場合には注意が必要そうです。一方、fixed_val_scaler
で単純にデータセット内の重み最大値を用いると、各化合物の描画が分かりづらくなってしまいます。下記の通り、個々の化合物の重み/saliencyの和は目的変数とよく相関しており、実際の目的変数/予測値を併記したり、目的変数/予測値で重みを補正する等の対処も可能かもしれません。
scipy.stats.pearsonr(y_train, np.sum(saliency_agg_bayes, axis=1)) # (0.8418738520866434, 2.617349583427535e-243)
おわりに
今回はchainer-chemistryで計算できるsaliencyを元に、スケーリングによる違いを試してみました。説明可能性は機械学習において注目されており、化合物関連でも色々な手法が出てきています。一方、下記のTJOさんの記事でも紹介されていましたが、そもそも機械学習のモデルの説明可能性に対して批判的な意見も出てきているようです。
自分としては、これら出てきた手法をそのまま実務での活用にあたってはまだまだ研究・考察が必要であると言うのが個人的な感想です。単純にケミストのintuitionとの比較だけで良し悪しを論じる前に幾つかのポイントを検証していく必要があると思います。ひとまず思いつくところを書いてみましたが、もっとあるかもしれません。ご意見等いただければ幸いです。
データセット (学習データセットの中に、予測したい化合物の部分構造は十分に含まれているのか。予測したい化合物は適用範囲内か。)
説明可能性の算出アルゴリズム
描画方法 (カラーマップ、重み/saliencyのスケーリング方法、など)
今回のコードはGistにアップしておきました。