Visualization with chainer-saliency

はじめに

 以下の2回で補助図を用いた化合物の可視化を試してみました。今回は補助図描画をchainer-saliency の算出結果に適用してみます。

rkakamilan.hatenablog.com

rkakamilan.hatenablog.com

chainer-saliencyについて

 以下の論文(1)をchainerで実装したもので、記事(2)にて紹介されています。記事(2)が公開された時点では、chainer-saliencyのレポジトリにありましたが、現在はchainer-chemistry内に入っています。

  1. BayesGrad: Explaining Predictions of Graph Convolutional Networks

  2. NNの予測根拠可視化をライブラリ化する

Saliency計算と可視化

 chainer-saliencyにあるexampleのjupyter notebookを参考にし、一部書き方を修正しながら進めます。

 notebookでは GraphConvPredictor から定義してますが、chainer-chemistryにはモデルを簡単に定義出来るchainer_chemistry.models.prediction.set_up_predictorが入ってますのでこれを使います。データセットはexampleと同じdelaneyの溶解度データセットです。

# パラメータはexampleのそのまま
n_unit = 32
conv_layers = 4
class_num = 1

predictor = set_up_predictor('ggnn',  n_unit=n_unit, conv_layers=conv_layers,
    class_num=class_num, postprocess_fn=activation_relu_dropout)
regressor = Regressor(predictor, device=device)

# fitはexampleのまま
fit(regressor, train, batchsize=16, epoch=40, device=device)

 学習したモデルを用いてsaliencyを計算します。chainer-chemistry にはVanillaGrad, SmoothGrad, BayesGradの3種類 (文献1,2参照)が実装され、example notebookで紹介されています。

calculator = IntegratedGradientsCalculator(
    predictor, steps=5, eval_fun=eval_fun, target_extractor=VariableMonitorLinkHook(predictor.graph_conv.embed, timing='post'),
    device=device)

M = 5 # サンプリング数
saliency_samples_vanilla = calculator.compute(
    train, M=1, converter=concat_mols)
saliency_samples_smooth = calculator.compute(
    train, M=M, converter=concat_mols, noise_sampler=GaussianNoiseSampler())
saliency_samples_bayes = calculator.compute(
    train, M=M, converter=concat_mols, train=True)

 得られたsaliencyはBaceCalculatorクラスのaggregateで集計します。

method = 'raw'
saliency_agg_vanilla = calculator.aggregate(
    saliency_samples_vanilla, ch_axis=3, method=method)
print(f'Before aggregation: {saliency_samples_vanilla.shape} -> After: {saliency_agg_vanilla.shape}')
# Before aggregation: (1, 902, 55, 32) -> After: (902, 55)
saliency_agg_smooth = calculator.aggregate(
    saliency_samples_smooth, ch_axis=3, method=method)
print(f'Before aggregation: {saliency_samples_smooth.shape} -> After: {saliency_agg_smooth.shape}')
# Before aggregation: (5, 902, 55, 32) -> After: (902, 55)
saliency_agg_bayes = calculator.aggregate(
    saliency_samples_bayes, ch_axis=3, method=method)
print(f'Before aggregation: {saliency_samples_bayes.shape} -> After: {saliency_agg_bayes.shape}')
# Before aggregation: (5, 902, 55, 32) -> After: (902, 55)

 これらsaliency_agg_xxxには学習に用いた902化合物ごとに原子の重みが入っており、各化合物のインデックスにアクセスすることで重み/saliencyが取り出せます。このsaliencyを用いてプロットする関数は、これまで用いていた関数を少し変更してます。

 重みのscalingにrdkit.Chem.Draw.SimilarityMaps.GetStandardizedWeightsを用いていましたが、chainer-chemistryにも幾つかscalingのメソッドが入っているためこれを指定できる形式にしました。

def plot_mol_saliency(mol, saliency, scaler=None, atoms=['C', 'N', 'O', 'S', 'F', 'Cl', 'P', 'Br']):
    n_atom = mol.GetNumAtoms()
    symbols = [f'{mol.GetAtomWithIdx(i).GetSymbol()}_{i}' for i in range(mol.GetNumAtoms())]
    df = pd.DataFrame(columns=atoms)
    saliency = saliency[:n_atom]
    num_atoms = mol.GetNumAtoms()
    if scaler is not None:
        vmax = np.max(np.abs(saliency))
        weights = scaler(saliency)
    else:
        weights, vmax = SimilarityMaps.GetStandardizedWeights(saliency)
    
    arr = np.zeros((num_atoms, len(atoms)))
    for i in range(mol.GetNumAtoms()):
        _a = mol.GetAtomWithIdx(i).GetSymbol()
        arr[i,atoms.index(_a)] = weights[i]
    df = pd.DataFrame(arr, index=symbols, columns=atoms)

# 以下同じ

 試しに、ID699の化合物を図示してみます。

i = 699
mol = Chem.MolFromSmiles(smiles[i])
plot_mol_saliency(mol, saliency=saliency_agg_bayes[i], scaler=abs_max_scaler, atoms=['C', 'N', 'O', 'S', 'F', 'Cl', 'P', 'Br'])

 元々化合物の描画の際のカラーマップ等はchainer-chemistryから拝借していたのでexampleとほぼ同じ描画が出来ました。(今回の記事では補助図は除いてます。)

f:id:rkakamilan:20200113161833p:plain
ID699_abs_max

スケーラーによる描画の違い

 前回の記事までは化合物の描画の際、化合物ごとに原子の重み/saliencyはrdkitのGetStandardizedWeightsを用いて、絶対値最大の値でスケーリングしていました。chainer-chemistryでもabs_max_scalerで同じスケーリングを施しています。一方、chainer-chemistryにはabs_max_scaler以外にも、最小値と最大値を用いて0-1にスケーリングするmin_max_scalerが実装されています。(他にもnormalize_scalerがありますが、負の値を含まないsaliencyのためのメソッドなので今回は割愛します。)

 min_max_scalerを使って先のID699を可視化してみます。

i = 699
mol = Chem.MolFromSmiles(smiles[i])
plot_mol_saliency(mol, saliency=saliency_agg_bayes[i], scaler=min_max_scaler, atoms=['C', 'N', 'O', 'S', 'F', 'Cl', 'P', 'Br'])

f:id:rkakamilan:20200113162037p:plain
ID699_min_max

 溶解度にプラスに寄与しているとされる原子が濃く描画されているものの、色の濃淡のみで分かりづらくなってしまいました。原子ごとの色はchainer_chemistry.saliency.visualizer.visualizer_utils.red_blue_cmapで決定していましたが、これは元々-1~1にノーマライズされた重み/saliencyに対応するために定義されたものであるため、min_max_scalerとの相性は良くないようです。

 では、abs_max_scalerが最適なのか。これを考えるために新しくscalerを考えます。与えた重み/saliencyの絶対値最大ではなく、指定した値でスケーリングするfixed_val_scalerを定義しました。

def _fixed_val_scaler(saliency, max_val, logger=None):
    try: 
        xp = cuda.get_array_module(saliency)
    except:
        xp = saliency
    maxv = max_val
    if maxv <= 0:
        logger = logger or getLogger(__name__)
        logger.info('All saliency value is 0')
        return xp.zeros_like(saliency)
    else:
        return saliency / maxv

def fixed_val_scaler(max_val):
    return functools.partial(_fixed_val_scaler, max_val=max_val)
np.unravel_index(np.argmax(np.abs(saliency_agg_bayes)), saliency_agg_bayes.shape), np.max(np.abs(saliency_agg_bayes))
# ((455, 10), 1.1827562)

 ID455の化合物がデータセット中で最大の重み (1.1827562)を有することがわかりました。この値をfixed_val_scalerのmax_valとして指定します。ID455の場合は、当然ですがabs_max_scalerfixed_val_scalerで違いはありません。

f:id:rkakamilan:20200113162644p:plain
ID455_abs_max
f:id:rkakamilan:20200113162705p:plain
ID455_fixed

 次に重み/saliencyの絶対値最大の値が小さい化合物で描画してみます。下記の通り、ID762が2番目に小さい重み (0.114224076)を有するようです。(一番小さい化合物はメタンだったので例には不適でした)

cpid_2nd_min = np.argsort(np.max(np.abs(saliency_agg_bayes), axis=1))[1]
cpid_2nd_min, np.max(np.abs(saliency_agg_bayes[cpid_2nd_min]))
# (762, 0.114224076)

f:id:rkakamilan:20200113163207p:plain
ID762_abs_max
f:id:rkakamilan:20200113163226p:plain
ID762_fixed

 fixed_val_scalerを用いた方が色が薄くなっています。

 このように、abs_max_scalerを用いた場合は、一つの化合物"内"の原子の寄与度を比較する場合には良いですが、化合物"間"で比較する場合には注意が必要そうです。一方、fixed_val_scalerで単純にデータセット内の重み最大値を用いると、各化合物の描画が分かりづらくなってしまいます。下記の通り、個々の化合物の重み/saliencyの和は目的変数とよく相関しており、実際の目的変数/予測値を併記したり、目的変数/予測値で重みを補正する等の対処も可能かもしれません。

scipy.stats.pearsonr(y_train, np.sum(saliency_agg_bayes, axis=1))
# (0.8418738520866434, 2.617349583427535e-243)

おわりに

 今回はchainer-chemistryで計算できるsaliencyを元に、スケーリングによる違いを試してみました。説明可能性は機械学習において注目されており、化合物関連でも色々な手法が出てきています。一方、下記のTJOさんの記事でも紹介されていましたが、そもそも機械学習のモデルの説明可能性に対して批判的な意見も出てきているようです。

tjo.hatenablog.com

 自分としては、これら出てきた手法をそのまま実務での活用にあたってはまだまだ研究・考察が必要であると言うのが個人的な感想です。単純にケミストのintuitionとの比較だけで良し悪しを論じる前に幾つかのポイントを検証していく必要があると思います。ひとまず思いつくところを書いてみましたが、もっとあるかもしれません。ご意見等いただければ幸いです。

  1. データセット (学習データセットの中に、予測したい化合物の部分構造は十分に含まれているのか。予測したい化合物は適用範囲内か。)

  2. 機械学習アルゴリズム

  3. 説明可能性の算出アルゴリズム

  4. 描画方法 (カラーマップ、重み/saliencyのスケーリング方法、など)

 今回のコードはGistにアップしておきました。