Skip to content

Commit

Permalink
feat: word2vec 사전 학습 모델로 수정
Browse files Browse the repository at this point in the history
  • Loading branch information
joosomi committed Jul 8, 2024
1 parent 4d2cef2 commit cde87e8
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 16 deletions.
41 changes: 26 additions & 15 deletions app/api/similarity.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,21 @@
import os
from fastapi import APIRouter
from app.schemas import SimilarityRequest
from app.utils import get_sentence_embedding, normalize_vector, cosine_similarity
from transformers import AutoModel, AutoTokenizer
from app.utils import normalize_vector, cosine_similarity
import numpy as np
import gensim
from gensim.models import KeyedVectors

# Router that the similarity endpoint below is registered on.
router = APIRouter()

# NOTE(review): this capture is a diff view without +/- markers. Judging by the
# hunk header (-1,14 +1,21), the KLUE/BERT lines below look like the removed
# side of the commit, superseded by the Word2Vec load further down — confirm
# against the actual file.
# Load the KLUE/BERT model and tokenizer
tokenizer = AutoTokenizer.from_pretrained('klue/bert-base')
model = AutoModel.from_pretrained('klue/bert-base')
# Load the Korean Word2Vec model
# Build an absolute path to ko.bin, one directory above this module's directory.
current_directory = os.path.dirname(os.path.abspath(__file__))
file_path = os.path.join(current_directory, '..', 'ko.bin')
absolute_file_path = os.path.abspath(file_path)
# Module-level statement: the resolved path is printed when this module is imported.
print(f"Loading model from: {absolute_file_path}")

# Load the Word2Vec model
# `model` is rebound here; the load also happens at import time.
model = gensim.models.Word2Vec.load(absolute_file_path)

# POST handler registered at "/" on `router`; responds with a
# {"similarity": <float>} dict.
# NOTE(review): this capture is a diff view — removed and added lines are
# interleaved without +/- markers, indentation is stripped, and the
# "Expand All @@ ..." rows stand in for lines that are not shown here.
@router.post("/")
async def calculate_similarity(request: SimilarityRequest):
Expand All @@ -19,9 +26,15 @@ async def calculate_similarity(request: SimilarityRequest):
speakingA = request.speakingA
speakingB = request.speakingB

# Return model.wv[word] when the word is in the model's vocabulary; otherwise
# fall back to an all-zero vector of length model.vector_size.
# NOTE(review): the zero fallback is passed to normalize_vector below —
# confirm that helper tolerates a zero-norm input.
def get_word_vector(word, model):
if word in model.wv:
return model.wv[word]
else:
return np.zeros(model.vector_size)

# Convert each interest into a vector and normalize it
# NOTE(review): vectorsA/vectorsB are assigned twice here; the
# get_sentence_embedding pair looks like the removed side of the diff and the
# get_word_vector pair like the added side — confirm.
vectorsA = [normalize_vector(get_sentence_embedding(interest, tokenizer, model)) for interest in interestsA]
vectorsB = [normalize_vector(get_sentence_embedding(interest, tokenizer, model)) for interest in interestsB]
vectorsA = [normalize_vector(get_word_vector(interest, model)) for interest in interestsA]
vectorsB = [normalize_vector(get_word_vector(interest, model)) for interest in interestsB]

# An empty interest list on either side short-circuits to a similarity of 0.0.
if not vectorsA or not vectorsB:
return {"similarity": 0.0}
Expand All @@ -38,12 +51,10 @@ async def calculate_similarity(request: SimilarityRequest):
# listeningB is defined in lines not shown in this view.
speaking_listening_complementary = (speakingA + listeningB) / 20

# Combine the interest / listening / speaking scores into one overall score
# overall_similarity = (
# 0.5 * interest_similarity +
# 0.25 * listening_speaking_complementary +
# 0.25 * speaking_listening_complementary
# )

overall_similarity = interest_similarity
# Weighted blend: 0.5 interest + 0.25 for each complementary score.
overall_similarity = (
0.5 * interest_similarity +
0.25 * listening_speaking_complementary +
0.25 * speaking_listening_complementary
)

# NOTE(review): two return lines appear because of the diff view. If the
# second one (returning interest_similarity) is the live line, the weighted
# overall_similarity computed above is never returned — confirm which value
# the endpoint is meant to report.
return {"similarity": float(overall_similarity)}
return {"similarity": float(interest_similarity)}
Binary file added app/ko.bin
Binary file not shown.
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,5 @@ fastapi
uvicorn
transformers
torch
pydantic
pydantic
gensim

0 comments on commit cde87e8

Please sign in to comment.