Naive-FAQ-Chatbot-3

간단한 FAQ 챗봇을 만들어보겠습니다.

이 챗봇은 간단한 형태로 챗봇을 처음 접하시는 분들을 위해 작성한 코드정도로 생각하시면 될듯합니다.

아래의 파일은 predict.py입니다.
해당 파일은 입력 받은 텍스트를 통해서 해당 텍스트가 어떤 질문인지를 예측하는 기능을 수행합니다.

예를 들어 사용자가 “회원정보를 수정하고 싶어요”라고 질문을 하면 해당 클래스는 입력 받은 질문을 형태소 분석하고 이 정보를 학습한 모델에 입력하여 해당 질문이 어떤 내용의 질문인지 찾아내 적절한 답변을 표시해주는 기능입니다.

아래의 블록은 predict.py 클래스 실행시 외부에서 입력 받는 parameter 값입니다.
parameter는 총 3가지로 질문 내용( q_message), 모델명(model_fn), 워드 벡터를 만들기 위해 입력한 파일(word_data)입니다.

def define_argparser():
    p = argparse.ArgumentParser()
    p.add_argument('--q_message', required=True)
    p.add_argument('--model_fn', required=True)
    p.add_argument('--word_data', required=True)
    config = p.parse_args()

    return config

입력 받은 질문은 미리 학습된 모델에 넣어서 적절한 값을 예측해냅니다.

def main(config):
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

    # model load
    load = torch.load(config.model_fn, map_location=device)
    labels = load['labels']
    print(labels)

    IPT = 196
    H = 100
    OPT = 6

    model = FaqCategoryClassifier(IPT, H, OPT)
    model.load_state_dict(load['model'])
    
    okt = Okt()

    predict = PredictCategory(okt, model)
    
    words = fileRead()
    morphs = okt.morphs(config.q_message)
    x_data = myutils.bag_of_words(morphs,words)
    
    p = predict.getCategory(torch.FloatTensor(x_data))
    idx = torch.argmax(p)
    print('{}\n{}\n'.format(p,idx))
    print(labels[idx])

이것으로 3번에 나눠서 간단한 FAQ 챗봇에 대한 설명을 마무리하겠습니다.

해당 코드에서 입력 데이터를 만드는 부분을 자세히 설명하지 않았는데 그 이유는 입력 데이터는 각각 다양한 방법으로 만들 수 있기 때문입니다.

그에 따라서 모델의 모양도 변하기 때문입니다.
먼저는 어떤 데이터를 어떻게 만들지에 대해서 설계해보는 것이 중요합니다.

답글 남기기

이메일 주소는 공개되지 않습니다. 필수 항목은 *(으)로 표시합니다