Skip to content

Commit

Permalink
ultimas mudancas
Browse files Browse the repository at this point in the history
  • Loading branch information
ramireguilherme committed Nov 5, 2023
1 parent 3ca511f commit f0cf0da
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 13 deletions.
27 changes: 16 additions & 11 deletions src/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,9 @@

from models.marian import MarianModel
from models.mbart import MbartModel
from models.nlbb import NlbbModel
from models.nllb import NllbModel
from models.t5 import t5Model
from models.m2m100 import M2m100Model
from utilities.check_csv_restricoes import verificar_restricoes_csv

logger = logging.getLogger(__name__)
Expand All @@ -19,10 +20,12 @@ def select_model(modelname):
return MarianModel()
elif modelname == "t5":
return t5Model()
elif modelname == "nlbb":
return NlbbModel()
elif modelname == "nllb":
return NllbModel()
elif modelname == "mbart":
return MbartModel()
elif modelname == "m2m100":
return M2m100Model()
logger.debug("Provided invalid modelname!")


Expand All @@ -42,7 +45,7 @@ def translate_csv(
model = select_model(modelname)
logger.info(f"Initializing {modelname} model ... done")
dataframe = pd.read_csv(csv_path)
dataframe = dataframe[:10]
dataframe = dataframe[:100]
for collum in collumns:
translations = []
for index, row in tqdm(
Expand All @@ -52,9 +55,10 @@ def translate_csv(
translated_text = model.translate_text(text)
translations.append(translated_text)
translated_collum_name = f"{collum} {modelname} translation"
dataframe[translated_collum_name] = translations

dataframe.to_csv(f"{filename}_translation.csv")
dataframe.insert(len(dataframe.columns),translated_collum_name,translations)
#dataframe[translated_collum_name] = translations
dataframe.to_csv(f"{filename}_translation.csv")
logger.info(f"Saving file...")


def translate_webdataset(url, models: List[str]):
Expand All @@ -68,8 +72,9 @@ def translate_webdataset(url, models: List[str]):
# "Digite o caminho do destino do arquivo CSV traduzido: "
# )
translate_csv(
"PATH TO FILE HERE",
collumns=[" Category", " Question"],
models=["marian"],
filename="jeopardy",
"/home/guilhermeramirez/nlp/translate-dataset/data/raw/test.csv",
collumns=["article","highlights"],
#" Category", " Question"],
models=["m2m100"],
filename="cnn_inteiro_m2m100",
)
4 changes: 2 additions & 2 deletions src/models/m2m100.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
logger = logging.getLogger(__name__)


class MarianModel:
class M2m100Model:
def __init__(self) -> None:
self.model = M2M100ForConditionalGeneration.from_pretrained(
"facebook/m2m100_418M"
Expand All @@ -27,4 +27,4 @@ def translate_text(self, sentence):
sentence_decoded = self.tokenizer.batch_decode(
output_sequences, skip_special_tokens=True
)
return sentence_decoded
return sentence_decoded[0]
19 changes: 19 additions & 0 deletions src/models/nllb.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
import logging
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, pipeline


source_lang = "eng_Latn"
target_lang = "por_Latn"

logger = logging.getLogger(__name__)

class NllbModel:
    """Translation wrapper around the ``facebook/nllb-200-distilled-600M`` checkpoint.

    The source and target languages are taken from the module-level
    ``source_lang`` and ``target_lang`` constants.
    """

    def __init__(self) -> None:
        """Load the pretrained model and tokenizer and build the translation pipeline."""
        # Name the checkpoint once so model and tokenizer cannot drift apart.
        checkpoint = "facebook/nllb-200-distilled-600M"
        self.model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)
        self.tokenizer = AutoTokenizer.from_pretrained(checkpoint)
        self.translator = pipeline(
            "translation",
            model=self.model,
            tokenizer=self.tokenizer,
            src_lang=source_lang,
            tgt_lang=target_lang,
            max_length=400,
        )

    def translate_text(self, sentence):
        """Translate ``sentence`` and return the first translation.

        Args:
            sentence: Text handed unchanged to the translation pipeline.

        Returns:
            The ``"translation_text"`` value of the first result produced by
            the pipeline.

        Raises:
            IndexError: If the pipeline returns no results.
        """
        results = self.translator(sentence)
        # Only the first result is returned, so read it directly instead of
        # materializing a list of every translation.
        return results[0]["translation_text"]

0 comments on commit f0cf0da

Please sign in to comment.