Naive-FAQ-Chatbot-2

간단한 FAQ 챗봇을 만들어보겠습니다.

이 챗봇은 간단한 형태로 챗봇을 처음 접하시는 분들을 위해 작성한 코드정도로 생각하시면 될듯합니다.

아래의 파일은 trainer.py 입니다.

해당 파일의 기능은 모델을 훈련하고 검증하는 역할을 수행합니다.

class Trainer():
def __init__(self, model, optimizer, crit):
    self.model = model
    self.optimizer = optimizer
    self.loss = loss

    super().__init__()

위의 부분은 Trainer 클래스의 선언부로 model, optimizer, loss 값을 전달 받습니다.

def train(self, train_data, valid_data, config):
        lowest_loss = np.inf
        best_model = None

        for epoch_index in range(config.n_epochs):
            train_loss = self._train(train_data[0], train_data[1], config)
            valid_loss = self._validate(valid_data[0], valid_data[1], config)

            # You must use deep copy to take a snapshot of current best weights.
            if valid_loss <= lowest_loss:
                lowest_loss = valid_loss
                best_model = deepcopy(self.model.state_dict())

            print("Epoch(%d/%d): train_loss=%.4e  valid_loss=%.4e  lowest_loss=%.4e" % (
                epoch_index + 1,
                config.n_epochs,
                train_loss,
                valid_loss,
                lowest_loss,
            ))

        # Restore to best model.
        self.model.load_state_dict(best_model)

train() 함수는 입력 받은 데이터를 epoch 만큼 학습을 시작합니다.
이때 _train()과 _valid()함수를 호출하는데 _train()은 학습을 _valid()는 검증을 수행합니다.

입력 데이터는 단어의 one-hot 데이터를 사용합니다. 더 좋은 결과를 얻기 위해서는 one-hot 보다는 embedding된 데이터를 사용하는 것이 좋습니다. 그 이유는 one-hot의 특징상 단어간의 관계를 표현 할 수 없기 때문이며 one-hot 데이터가 sparse하기 때문입니다.

word를 embedding하는 가장 대표적인 방법인 word2vec을 사용하기를 추천합니다. 다만 여기서는 naive한 형태의 챗봇이기 때문에 one-hot을 사용하여 테스트했습니다.

    def _train(self, x, y, config):
        self.model.train()

        # Shuffle before begin.
        indices = torch.randperm(x.size(0), device=x.device)
        x = torch.index_select(x, dim=0, index=indices).split(config.batch_size, dim=0)
        y = torch.index_select(y, dim=0, index=indices).split(config.batch_size, dim=0)

        total_loss = 0

        for i, (x_i, y_i) in enumerate(zip(x, y)):
            y_hat_i = self.model(x_i)
            loss_i = self.crit(y_hat_i, y_i.squeeze())

            # Initialize the gradients of the model.
            self.optimizer.zero_grad()
            loss_i.backward()

            self.optimizer.step()

            if config.verbose >= 2:
                print("Train Iteration(%d/%d): loss=%.4e" % (i + 1, len(x), float(loss_i)))

            # Don't forget to detach to prevent memory leak.
            total_loss += float(loss_i)

        return total_loss / len(x)

위의 코드와 같이 학습을 시작합니다.
학습 데이터는 사전에 정의한 배치 사이즈에 맞춰 분할 학습을 수행합니다.
이때 현재 모델이 학습 중이라는 것을 알려주기 위해 model.train()을 선언합니다.

def _validate(self, x, y, config):
        # Turn evaluation mode on.
        self.model.eval()

        # Turn on the no_grad mode to make more efficintly.
        with torch.no_grad():
            # Shuffle before begin.
            indices = torch.randperm(x.size(0), device=x.device)
            x = torch.index_select(x, dim=0, index=indices).split(config.batch_size, dim=0)
            y = torch.index_select(y, dim=0, index=indices).split(config.batch_size, dim=0)

            total_loss = 0

            for i, (x_i, y_i) in enumerate(zip(x, y)):
                y_hat_i = self.model(x_i)
                loss_i = self.crit(y_hat_i, y_i.squeeze())

                if config.verbose >= 2:
                    print("Valid Iteration(%d/%d): loss=%.4e" % (i + 1, len(x), float(loss_i)))

                total_loss += float(loss_i)

            return total_loss / len(x)

validate 코드도 train 코드와 거의 동일합니다.
다른 점은 train에서 학습에 관련된 부분이 validate에서는 빠져있다는 부분입니다. 단순히 검증만 하는 데이터이기 때문에 학습이 일어나지 않습니다.
특히 잊지 말아야 할 것은 model.eval()을 실행시켜줘야 한다는 것입니다.

validation은 과적합을 방지하기 위해서 실행하는 것으로 대부분 데이터셋을 8:2, 7:3 정도로 분리하여 학습과 검증을 수행합니다.

이제 남은 부분은 이렇게 만들어진 모델을 통해 예측을 수행하는 코드가 남아 있습니다.

해당 코드는 Naive-FAQ-Chatbot-3에서 설명하겠습니다.

답글 남기기

이메일 주소는 공개되지 않습니다. 필수 항목은 *(으)로 표시합니다