From 68dc29a39ee7bf4fa60249e7204fa292a9fda732 Mon Sep 17 00:00:00 2001
From: hoit1302
Date: Tue, 17 May 2022 22:00:40 +0900
Subject: [PATCH] hoit1302/attiary_model:2.2

---
 .dockerignore                   |  4 ++--
 Dockerfile                      |  3 +--
 app.py                          | 17 +++++++++--------
 model/chatbot/kobert/chatbot.py | 10 ++++++++++
 model/chatbot/kogpt2/chatbot.py | 30 +++++++++++++++++++++---------
 5 files changed, 43 insertions(+), 21 deletions(-)

diff --git a/.dockerignore b/.dockerignore
index 49d5166..0ab2461 100644
--- a/.dockerignore
+++ b/.dockerignore
@@ -12,5 +12,5 @@
 Lib
 *.cfg
 preprocess
-kss_example.py
-*/__pycache__/*
\ No newline at end of file
+*/__pycache__/*
+checkpoint/emotion_pn_v5.pth
\ No newline at end of file
diff --git a/Dockerfile b/Dockerfile
index 4021af1..ce7b21a 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -8,5 +8,4 @@ RUN pip install -r requirements.txt
 
 EXPOSE 5000
 
-CMD python ./app.py
-# CMD ["python3", "-m", "flask", "run", "--host=0.0.0.0"]
\ No newline at end of file
+CMD python ./app.py
\ No newline at end of file
diff --git a/app.py b/app.py
index 4686106..d4cc0bd 100644
--- a/app.py
+++ b/app.py
@@ -1,6 +1,6 @@
 import os
-from model.chatbot.kogpt2 import chatbot as ch_v1
-from model.chatbot.kobert import chatbot as ch_v2
+from model.chatbot.kogpt2 import chatbot as ch_kogpt2
+from model.chatbot.kobert import chatbot as ch_kobert
 from model.emotion import service as emotion
 from util.emotion import Emotion
 from util.depression import Depression
@@ -12,6 +12,7 @@
 Emotion = Emotion()
 Depression = Depression()
 
+@app.route('/')
 def hello():
     return "deep learning server is running 💗"
 
@@ -19,7 +20,7 @@ def hello():
 @app.route('/emotion')
 def classifyEmotion():
     sentence = request.args.get("s")
-    if sentence is None or len(sentence) == 0:
+    if sentence is None or len(sentence) == 0 or sentence == '\n':
         return jsonify({
             "emotion_no": 2,
             "emotion": "중립"
@@ -36,7 +37,7 @@ def classifyEmotion():
 @app.route('/diary')
 def classifyEmotionDiary():
     sentence = request.args.get("s")
-    if sentence is None or len(sentence) == 0:
+    if sentence is None or len(sentence) == 0 or sentence == '\n':
         return jsonify({
             "joy": 0,
             "hope": 0,
@@ -66,12 +67,12 @@ def classifyEmotionDiary():
 @app.route('/chatbot/g')
 def reactChatbotV1():
     sentence = request.args.get("s")
-    if sentence is None or len(sentence) == 0:
+    if sentence is None or len(sentence) == 0 or sentence == '\n':
         return jsonify({
             "answer": "듣고 있어요. 더 말씀해주세요~ (도닥도닥)"
         })
 
-    answer = ch_v1.predict(sentence)
+    answer = ch_kogpt2.predict(sentence)
     return jsonify({
         "answer": answer
     })
@@ -80,12 +81,12 @@ def reactChatbotV1():
 @app.route('/chatbot/b')
 def reactChatbotV2():
     sentence = request.args.get("s")
-    if sentence is None or len(sentence) == 0:
+    if sentence is None or len(sentence) == 0 or sentence == '\n':
         return jsonify({
             "answer": "듣고 있어요. 더 말씀해주세요~ (도닥도닥)"
         })
 
-    answer, category, desc, softmax = ch_v2.chat(sentence)
+    answer, category, desc, softmax = ch_kobert.chat(sentence)
     return jsonify({
         "answer": answer,
         "category": category,
diff --git a/model/chatbot/kobert/chatbot.py b/model/chatbot/kobert/chatbot.py
index 189a569..86123ec 100644
--- a/model/chatbot/kobert/chatbot.py
+++ b/model/chatbot/kobert/chatbot.py
@@ -132,4 +132,14 @@ def chat(sent):
 print(chat("남들이 나를 어떻게 생각할지 신경쓰게 돼"))
 print("\'자존감이 낮아지는 것 같아\' 챗봇 응답: ", end='')
 print(chat("자존감이 낮아지는 것 같아"))
+print("\'뭘 해도 금방 지쳐\' 챗봇 응답: ", end='')
+print(chat("뭘 해도 금방 지쳐"))
+print("\'걔한테 진짜 크게 배신 당했어\' 챗봇 응답: ", end='')
+print(chat("걔한테 진짜 크게 배신 당했어"))
+print("\'내일 놀이공원 갈건데 사람 별로 없었으면 좋겠다\' 챗봇 응답: ", end='')
+print(chat("내일 놀이공원 갈건데 사람 별로 없었으면 좋겠다"))
+print("\'오늘은 구름이랑 달이 너무너무 예쁘더라\' 챗봇 응답: ", end='')
+print(chat("오늘은 구름이랑 달이 너무너무 예쁘더라"))
+print("\'그래도 내가 머리는 좀 좋아\' 챗봇 응답: ", end='')
+print(chat("그래도 내가 머리는 좀 좋아"))
 print("=" * 50)
diff --git a/model/chatbot/kogpt2/chatbot.py b/model/chatbot/kogpt2/chatbot.py
index eaa61a8..91e9474 100644
--- a/model/chatbot/kogpt2/chatbot.py
+++ b/model/chatbot/kogpt2/chatbot.py
@@ -73,7 +73,6 @@ def add_model_specific_args(parent_parser):
         return parser
 
     def forward(self, inputs):
-        # (batch, seq_len, hiddens)
        output = self.kogpt2(inputs, return_dict=True)
        return output.logits
 
@@ -82,19 +81,28 @@ def chat(self, input_sentence, sent='0'):
         sent_tokens = tok.tokenize(sent)
         with torch.no_grad():
             q = input_sentence.strip()
-            q = q[len(q) - 32:]
             a = ''
             while 1:
                 input_ids = torch.LongTensor(tok.encode(U_TKN + q + SENT + sent + S_TKN + a)).unsqueeze(dim=0)
                 pred = self(input_ids)
-                gen = tok.convert_ids_to_tokens(
-                    torch.argmax(
-                        pred,
-                        dim=-1).squeeze().numpy().tolist())[-1]
-                if gen == EOS:
+                gen = tok.convert_ids_to_tokens(torch.argmax(pred, dim=-1).squeeze().numpy().tolist())[-1]
+                # print(gen) #
+                if gen == EOS or gen == PAD:  # also stop on PAD to prevent an infinite generation loop
                     break
                 a += gen.replace('▁', ' ')
-            return a.strip()
+            a = a.strip()
+            period_pos = a.rfind(".")
+            question_pos = a.rfind("?")
+            exclamation_pos = a.rfind("!")
+            last_pos = len(a) - 1
+            # print (str(period_pos) + " " + str(question_pos) + " " + str(exclamation_pos))
+            if last_pos == period_pos or last_pos == question_pos or last_pos == exclamation_pos:
+                return a
+            mark_pos = max(max(period_pos, question_pos), exclamation_pos)
+            a = a[:mark_pos + 1]
+            if a == "":
+                return "(도닥도닥) 듣고 있어요. 더 말씀해주세요!"
+            return a
 
 
 parser = KoGPT2Chat.add_model_specific_args(parser)
@@ -111,10 +119,14 @@ def predict(sent):
 
 print("=" * 50)
 print("[*] kogpt2 chatbot test")
-print("\'특별한 이유가 없는데 그냥 불안해\' 챗봇 응답: " + predict("특별한 이유가 없는데 그냥 불안해"))
 print("\'특별한 이유가 없는데 그냥 불안하고 눈물이 나와\' 챗봇 응답: " + predict("특별한 이유가 없는데 그냥 불안하고 눈물이 나와"))
 print("\'이 세상에서 완전히 사라지고 싶어\' 챗봇 응답: " + predict("이 세상에서 완전히 사라지고 싶어"))
 print("\'가슴이 답답해서 터질 것만 같아요.\' 챗봇 응답: " + predict("가슴이 답답해서 터질 것만 같아요."))
 print("\'남들이 나를 어떻게 생각할지 신경쓰게 돼\' 챗봇 응답: " + predict("남들이 나를 어떻게 생각할지 신경쓰게 돼"))
 print("\'자존감이 낮아지는 것 같아\' 챗봇 응답: " + predict("자존감이 낮아지는 것 같아"))
+print("\'뭘 해도 금방 지쳐\' 챗봇 응답: " + predict("뭘 해도 금방 지쳐"))
+print("\'걔한테 진짜 크게 배신 당했어\' 챗봇 응답: " + predict("걔한테 진짜 크게 배신 당했어"))
+print("\'내일 놀이공원 갈건데 사람 별로 없었으면 좋겠다\' 챗봇 응답: " + predict("내일 놀이공원 갈건데 사람 별로 없었으면 좋겠다"))
+print("\'오늘은 구름이랑 달이 너무너무 예쁘더라\' 챗봇 응답: " + predict("오늘은 구름이랑 달이 너무너무 예쁘더라"))
+print("\'그래도 내가 머리는 좀 좋아\' 챗봇 응답: " + predict("그래도 내가 머리는 좀 좋아"))
 print("=" * 50)
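
Not from the commit — a minimal standalone sketch of the reply post-processing this patch adds to KoGPT2Chat.chat: the generated answer is trimmed back to the last sentence-ending mark so an unfinished fragment is never returned, with a stock reply when no mark is found. The function name trim_to_last_sentence is illustrative only.

# Sketch only: mirrors the post-processing added to KoGPT2Chat.chat in this patch.
def trim_to_last_sentence(a: str) -> str:
    a = a.strip()
    period_pos = a.rfind(".")
    question_pos = a.rfind("?")
    exclamation_pos = a.rfind("!")
    last_pos = len(a) - 1
    # Answer already ends on ".", "?" or "!": keep it unchanged.
    if last_pos == period_pos or last_pos == question_pos or last_pos == exclamation_pos:
        return a
    # Otherwise drop the unfinished fragment after the last sentence mark.
    mark_pos = max(period_pos, question_pos, exclamation_pos)
    a = a[:mark_pos + 1]
    if a == "":
        # No sentence mark at all: fall back to the stock reply used by the patch.
        return "(도닥도닥) 듣고 있어요. 더 말씀해주세요!"
    return a

print(trim_to_last_sentence("I see. That sounds hard! so maybe"))  # -> "I see. That sounds hard!"
print(trim_to_last_sentence("well then"))                          # -> stock fallback reply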
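
Likewise not from the commit — a small usage sketch of the tightened input guard in app.py. It assumes the server is running locally on port 5000 (the port the Dockerfile EXPOSEs); the endpoint and stock answer come from the patch.

# Sketch only: the three "no usable input" cases (missing, empty, bare newline)
# all short-circuit to the stock answer instead of calling the model.
import requests

BASE = "http://localhost:5000"  # assumed local run; the Dockerfile exposes port 5000

for s in (None, "", "\n"):
    params = {} if s is None else {"s": s}
    r = requests.get(BASE + "/chatbot/g", params=params)
    print(repr(s), "->", r.json()["answer"])  # -> "듣고 있어요. 더 말씀해주세요~ (도닥도닥)"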