DOORS DX

ベストなDXへの入り口が
見つかるメディア

自然言語処理分野で用いられる敵対的学習手法について

公開日
2022.08.23
更新日
2024.03.15

当社データサイエンティストが、自然言語処理分野でよく用いられる「敵対的学習手法」から、「FGM(Fast Gradient Method)」「AWP(Adversarial Weight Perturbation)」手法をピックアップしてご紹介します。

こんにちは。アナリティクスサービス部の佐々木です。
今回は、自然言語処理の分野においてよく用いられる「敵対的学習手法」についてご紹介します。

敵対的学習とは

深層学習モデルの発展により、自然言語処理分野を含め様々な分野で高い精度のモデルを作成できるようになってきた一方、入力に小さな摂動(perturbation)を加えただけで誤ったクラスに分類してしまう、敵対的サンプルと呼ばれるサンプルが存在することが指摘されています。

Intriguing properties of neural networks
Figure 5 敵対的サンプルの例
各画像の左の列がオリジナルのサンプル、中央の列が摂動、右の列が敵対的サンプルとなっており、
オリジナルのサンプルは正しく分類できているが、敵対的サンプルは全て「ダチョウ」と分類されている

敵対的学習は、オリジナルのサンプルに加え、上記のような敵対的サンプルについても正しく分類できるような学習手法となっています。

敵対的学習を行うことで、単にモデルの頑健性が高まるだけでなく、オリジナルのサンプルに対してのパフォーマンスも向上することが知られており、特に近年の自然言語処理のコンペでは様々な敵対的学習手法が用いられ、それらが実際にモデルの精度向上に寄与することが報告されています。

敵対的学習、と一口にいってもそのアプローチは様々であり、それらの方法を知っておくことは実務家にとっても重要と考えられます。今回は自然言語処理でよく使われる手法をピックアップしてご紹介し、実際のコンペデータでの性能を試してみたいと思います。

敵対的学習の手法

FGM(Fast Gradient Method)

FGMでは、学習のプロセスにおいてコスト関数に以下の項を追加します。

ここで、\(x\)はモデルへの入力(自然言語処理の場合、入力は文章そのものではなく埋め込み表現を想定します)、\(θ\)はモデルのパラメータ、\(r\)は摂動、\(\hat{θ}\)は現在のモデルのパラメータとなっています。

学習の各stepおいて、現在のモデルの精度を、ノルムが\(ϵ\)の範囲で最も悪化させる敵対的な摂動[tex:{r}_{adv}]を求め、その摂動を加えた際の入力に対してもコスト関数が最小化されるよう学習することで、敵対的サンプルについても正しく分類できるような学習を行うことが可能となります。

上記の学習のためには敵対的な摂動を求める必要がありますが、深層学習モデルなど多くのモデルではこの値を正確に求めることが難しくなっています。

そこで、Explaining and Harnessing Adversarial Examples において、\({r}_{adv}\)の求め方として以下のような方法が提案されました。

モデルの精度を悪化させる方向を微分によって求める方法となっており、これは深層学習においては誤差逆伝播を利用することで簡単に計算することができます。

FGMのpytorchでの実装、使い方は以下のようになります。

実装

# reference: https://www.kaggle.com/c/tweet-sentiment-extraction/discussion/143764
class FGM():
    def __init__(self, model):
        self.model = model
        self.backup = {}

    def attack(self, epsilon=1., emb_name='word_embeddings'):
       """
        敵対的な摂動を求め、現在のembedding layerに摂動を加える
        """
        for name, param in self.model.named_parameters():
            if param.requires_grad and emb_name in name:
                self.backup[name] = param.data.clone()
                norm = torch.norm(param.grad)
                if norm != 0:
                    r_at = epsilon * param.grad / norm
                    param.data.add_(r_at)

    def restore(self, emb_name='word_embeddings'):
        """
        敵対的な摂動を求める際に変更してしまったembedding layerのパラメータについて
        元のパラメータを代入する
        """
        for name, param in self.model.named_parameters():
            if param.requires_grad and emb_name in name:
                assert name in self.backup
                param.data = self.backup[name]
            self.backup = {}

使い方

fgm = FGM(model)
 for batch_input, batch_label in data:
       # オリジナルのサンプルについての損失を計算
       loss = model(batch_input, batch_label)
       loss.backward()  

       # adversarial training
       # embedding layerに敵対的な摂動を加える
       fgm.attack() 
       # 敵対的な摂動を加えられた状態での損失を計算
       loss_adv = model(batch_input, batch_label)
       loss_adv.backward() 
       fgm.restore()  

       optimizer.step()
       model.zero_grad()

AWP(Adversarial Weight Perturbation)

FGMにおいては、敵対的な摂動はモデルの入力(入力の埋め込み表現)に対して加えられることになりますが、摂動をモデルの「入力」ではなくモデルの「重み」に加える、という方法も考えられます。この方法として代表的なのがAWP(Adversarial Weight Perturbation)です。

コンペにおいては以下の実装がよく用いられています。原論文においてはモデルの重みに対しての摂動だけではなく、同時に入力に対しての摂動も加えるなど、コンペで用いられている実装とは差異がありますので、原論文の実装が気になる方は論文を参照していただければと思います。

実装

# reference: https://www.kaggle.com/code/wht1996/feedback-nn-train/notebook
class AWP:
    def __init__(
        self,
        model,
        optimizer,
        adv_param="weight",
        adv_lr=1,
        adv_eps=0.2,
        start_epoch=0,
        adv_step=1,
        scaler=None
    ):
        self.model = model
        self.optimizer = optimizer
        self.adv_param = adv_param
        self.adv_lr = adv_lr
        self.adv_eps = adv_eps
        self.start_epoch = start_epoch
        self.adv_step = adv_step
        self.backup = {}
        self.backup_eps = {}
        self.scaler = scaler

    def attack_backward(self, x, y, attention_mask,epoch):
        """
        敵対的な摂動を加えた損失を計算し、パラメータを更新する
        """
        if (self.adv_lr == 0) or (epoch < self.start_epoch):
            return None

        self._save() 
        for i in range(self.adv_step):
            self._attack_step() 
            with torch.cuda.amp.autocast():
                adv_loss, tr_logits = self.model(input_ids=x, attention_mask=attention_mask, labels=y)
                adv_loss = adv_loss.mean()
            self.optimizer.zero_grad()
            self.scaler.scale(adv_loss).backward()
            
        self._restore()

    def _attack_step(self):
        """
        敵対的な摂動を求め、重みに加える
        重みの範囲をbackup_epsで制限している
        """
        e = 1e-6
        for name, param in self.model.named_parameters():
            if param.requires_grad and param.grad is not None and self.adv_param in name:
                norm1 = torch.norm(param.grad)
                norm2 = torch.norm(param.data.detach())
                if norm1 != 0 and not torch.isnan(norm1):
                    r_at = self.adv_lr * param.grad / (norm1 + e) * (norm2 + e)
                    param.data.add_(r_at)
                    param.data = torch.min(
                        torch.max(param.data, self.backup_eps[name][0]), self.backup_eps[name][1]
                    )

    def _save(self):
        """
        重みのバックアップと、重みの範囲を取得する
        重みの範囲はパラメータの絶対値とadv_epsによって決定する
        """
        for name, param in self.model.named_parameters():
            if param.requires_grad and param.grad is not None and self.adv_param in name:
                if name not in self.backup:
                    self.backup[name] = param.data.clone()
                    grad_eps = self.adv_eps * param.abs().detach()
                    self.backup_eps[name] = (
                        self.backup[name] - grad_eps,
                        self.backup[name] + grad_eps,
                    )

    def _restore(self):
        """
        バックアップを取っていたパラメータを代入するとともに初期化する
        """
        for name, param in self.model.named_parameters():
            if name in self.backup:
                param.data = self.backup[name]
        self.backup = {}
        self.backup_eps = {}

おわりに

今回は自然言語処理分野で用いられる敵対的学習手法についていくつか紹介しました。今後も様々な敵対的学習手法が出てくると思いますので、引き続きキャッチアップしていきたいと考えています。


このページをシェアする

あなたにオススメの記事

株式会社ブレインパッドについて

2004年の創業以来、「データ活用の促進を通じて持続可能な未来をつくる」をミッションに掲げ、データの可能性をまっすぐに信じてきたブレインパッドは、データ活用を核としたDX実践経験により、あらゆる社会課題や業界、企業の課題解決に貢献してきました。 そのため、「DXの核心はデータ活用」にあり、日々蓄積されるデータをうまく活用し、データドリブン経営に舵を切ることであると私達は考えています。

メールマガジン

Mail Magazine