Commit
Merge remote-tracking branch 'origin/guilherme' into main
Showing 12 changed files with 3,626 additions and 1,680 deletions.
Large diffs are not rendered by default.
@@ -0,0 +1,80 @@
import logging
from typing import List

import pandas as pd
from tqdm import tqdm

from models.marian import MarianModel
from models.mbart import MbartModel
from models.nllb import NllbModel
from models.t5 import t5Model
from models.m2m100 import M2m100Model
from utilities.check_csv_restricoes import verificar_restricoes_csv

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

def select_model(modelname):
    """Return an instantiated translation model for the given name."""
    if modelname == "marian":
        return MarianModel()
    elif modelname == "t5":
        return t5Model()
    elif modelname == "nllb":
        return NllbModel()
    elif modelname == "mbart":
        return MbartModel()
    elif modelname == "m2m100":
        return M2m100Model()
    # Unknown names fall through and return None.
    logger.warning("Provided invalid modelname!")

def translate_csv(
    csv_path, collumns: List[str], models: List[str], output_format="csv", filename=None
):
    """
    Translate selected columns of a CSV file with each of the requested models.

    Args:
        csv_path (str): Path to the input CSV file.
        collumns (List[str]): Names of the columns to be translated.
        models (List[str]): Model names understood by select_model.
        output_format (str): Currently unused; the result is always written as CSV.
        filename (str): Base name for the "<filename>_translation.csv" output file.
    """
    verificar_restricoes_csv(csv_path)
    for modelname in models:
        logger.info(f"Initializing {modelname} model ...")
        model = select_model(modelname)
        logger.info(f"Initializing {modelname} model ... done")
        dataframe = pd.read_csv(csv_path)
        # Only the first 100 rows are translated.
        dataframe = dataframe[:100]
        for collum in collumns:
            translations = []
            for index, row in tqdm(
                dataframe.iterrows(), desc=f"Translating {collum} with {modelname}"
            ):
                text = row[collum]
                translated_text = model.translate_text(text)
                translations.append(translated_text)
            translated_collum_name = f"{collum} {modelname} translation"
            dataframe.insert(len(dataframe.columns), translated_collum_name, translations)
            # dataframe[translated_collum_name] = translations
        # Note: the output name does not include the model name, so running several
        # models in one call overwrites the same file.
        dataframe.to_csv(f"{filename}_translation.csv")
        logger.info("Saving file...")

def translate_webdataset(url, models: List[str]):
    pass

if __name__ == "__main__":
    # Translation parameter setup.
    # caminho_pasta_entrada_original = input("Enter the path to the CSV file folder: ")
    # caminho_pasta_saida_traduzida = input(
    #     "Enter the destination path for the translated CSV file: "
    # )
    translate_csv(
        "/home/guilhermeramirez/nlp/translate-dataset/data/raw/test.csv",
        collumns=["article", "highlights"],
        # " Category", " Question"],
        models=["m2m100"],
        filename="cnn_inteiro_m2m100",
    )
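For quick experimentation, a minimal sketch of calling the select_model factory directly on a single string; the model name and sample sentence below are illustrative, and the checkpoint is downloaded on first use:

# Illustrative sketch: translate one sentence with a model chosen by name.
model = select_model("m2m100")
if model is not None:
    print(model.translate_text("The quick brown fox jumps over the lazy dog."))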
@@ -0,0 +1,22 @@
from google.cloud import translate_v2 as translate
import os


def traduzir_texto(texto, idioma_destino, caminho_chave_api):
    """
    Translates the text into the target language using the Google Translate API.

    Args:
        texto (str): The text to be translated.
        idioma_destino (str): The target language for the translation.
        caminho_chave_api (str):
            Path to the JSON file containing the Google Translate API key.

    Returns:
        str: The translated text.
    """
    os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = caminho_chave_api
    client = translate.Client()
    traducao = client.translate(texto, target_language=idioma_destino)
    return traducao["translatedText"]
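A minimal usage sketch of traduzir_texto; the credentials path and target language below are illustrative and assume a valid Google Cloud service-account key:

# Illustrative call: translate an English sentence into Portuguese ("pt").
# The key path is a placeholder, not a real credential.
texto_traduzido = traduzir_texto(
    "Machine translation is useful.", "pt", "/path/to/service-account-key.json"
)
print(texto_traduzido)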
@@ -0,0 +1,30 @@
import logging

from transformers import M2M100ForConditionalGeneration, M2M100Tokenizer

logger = logging.getLogger(__name__)


class M2m100Model:
    def __init__(self) -> None:
        self.model = M2M100ForConditionalGeneration.from_pretrained(
            "facebook/m2m100_418M"
        )
        self.tokenizer = M2M100Tokenizer.from_pretrained("facebook/m2m100_418M")

    def translate_text(self, sentence):
        inputs = self.tokenizer(sentence, return_tensors="pt", padding=True)

        output_sequences = self.model.generate(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            do_sample=False,  # disable sampling to test if batching affects output
            forced_bos_token_id=self.tokenizer.get_lang_id("pt"),
        )

        sentence_decoded = self.tokenizer.batch_decode(
            output_sequences, skip_special_tokens=True
        )
        return sentence_decoded[0]
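A minimal usage sketch of the M2m100Model wrapper above (imported by the main script as models.m2m100); the sample sentence is illustrative and the facebook/m2m100_418M checkpoint is downloaded on first use:

# Illustrative sketch: one English sentence translated to Portuguese.
m2m = M2m100Model()
print(m2m.translate_text("The weather is nice today."))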
@@ -0,0 +1,31 @@
import logging

from transformers import MarianMTModel, MarianTokenizer

logger = logging.getLogger(__name__)


class MarianModel:
    def __init__(self) -> None:
        model_name = "Helsinki-NLP/opus-mt-en-ROMANCE"
        self.tokenizer = MarianTokenizer.from_pretrained(model_name)
        self.model = MarianMTModel.from_pretrained(model_name)

    def translate_text(self, sentence):
        # The multilingual OPUS model expects a target-language token (">>pt<<")
        # at the start of the input; the text is truncated to 512 characters.
        inputs = self.tokenizer(
            ">>pt<<" + sentence[:512],
            return_tensors="pt",
            padding=True,
        )

        output_sequences = self.model.generate(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            do_sample=False,  # disable sampling to test if batching affects output
            max_length=1024,
        )
        sentence_decoded = self.tokenizer.batch_decode(
            output_sequences, skip_special_tokens=True
        )

        return sentence_decoded[0]
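A minimal usage sketch of MarianModel; the sample sentence is illustrative and assumes the Helsinki-NLP/opus-mt-en-ROMANCE checkpoint can be downloaded:

# Illustrative sketch: the ">>pt<<" target token is added inside translate_text.
marian = MarianModel()
print(marian.translate_text("Translation models are fun to compare."))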
@@ -0,0 +1,28 @@
from transformers import MBart50TokenizerFast, MBartForConditionalGeneration


class MbartModel:
    def __init__(self) -> None:
        self.tokenizer = MBart50TokenizerFast.from_pretrained(
            "Narrativa/mbart-large-50-finetuned-opus-en-pt-translation"
        )
        self.model = MBartForConditionalGeneration.from_pretrained(
            "Narrativa/mbart-large-50-finetuned-opus-en-pt-translation"
        )

    def translate_text(self, sentence):
        self.tokenizer.src_lang = "en_XX"

        inputs = self.tokenizer(sentence, return_tensors="pt", padding=True)

        output_sequences = self.model.generate(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            do_sample=False,  # disable sampling to test if batching affects output
            forced_bos_token_id=self.tokenizer.lang_code_to_id["pt_XX"],
        )

        sentence_decoded = self.tokenizer.batch_decode(
            output_sequences, skip_special_tokens=True
        )
        # Return the single decoded string, matching the other model wrappers.
        return sentence_decoded[0]
@@ -0,0 +1,33 @@
import logging

from transformers import M2M100ForConditionalGeneration, M2M100Tokenizer

logger = logging.getLogger(__name__)


class NlbbModel:
    # Note: despite the name, this wrapper loads the M2M100 checkpoint,
    # not an NLLB one.
    def __init__(self) -> None:
        self.model = M2M100ForConditionalGeneration.from_pretrained(
            "facebook/m2m100_418M"
        )
        self.tokenizer = M2M100Tokenizer.from_pretrained("facebook/m2m100_418M")

    def translate_text(self, sentence):
        self.tokenizer.src_lang = "en"

        inputs = self.tokenizer(sentence, return_tensors="pt", padding=True)

        output_sequences = self.model.generate(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            do_sample=False,  # disable sampling to test if batching affects output
            forced_bos_token_id=self.tokenizer.get_lang_id("pt"),
        )

        sentence_decoded = self.tokenizer.batch_decode(
            output_sequences, skip_special_tokens=True
        )
        # Return the single decoded string, matching the other model wrappers.
        return sentence_decoded[0]
@@ -0,0 +1,19 @@
import logging

from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, pipeline

source_lang = "eng_Latn"
target_lang = "por_Latn"

logger = logging.getLogger(__name__)


class NllbModel:
    def __init__(self) -> None:
        self.model = AutoModelForSeq2SeqLM.from_pretrained("facebook/nllb-200-distilled-600M")
        self.tokenizer = AutoTokenizer.from_pretrained("facebook/nllb-200-distilled-600M")
        self.translator = pipeline(
            "translation",
            model=self.model,
            tokenizer=self.tokenizer,
            src_lang=source_lang,
            tgt_lang=target_lang,
            max_length=400,
        )

    def translate_text(self, sentence):
        out = self.translator(sentence)
        out = [o["translation_text"] for o in out]
        return out[0]
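A minimal usage sketch of the pipeline-based NllbModel; the sample sentence is illustrative and the facebook/nllb-200-distilled-600M checkpoint is downloaded on first use:

# Illustrative sketch: the pipeline handles tokenization and language codes.
nllb = NllbModel()
print(nllb.translate_text("This sentence should come out in Portuguese."))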
@@ -0,0 +1,38 @@
import logging

from transformers import AutoModelForSeq2SeqLM, AutoTokenizer


class t5Model:
    def __init__(self) -> None:
        model_name = "unicamp-dl/translation-en-pt-t5"
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

    def translate_text(self, sentence):
        """
        Translates the input sentence.

        Args:
            sentence (str): The input sentence to be translated.

        Returns:
            str: The translated version of the input sentence.
        """
        # Trailing space so the task prefix does not run into the sentence.
        task_prefix = "translate English to Portuguese: "
        inputs = self.tokenizer(
            task_prefix + sentence, return_tensors="pt", padding=True
        )

        output_sequences = self.model.generate(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            do_sample=False,  # disable sampling to test if batching affects output
        )
        sentence_decoded = self.tokenizer.batch_decode(
            output_sequences, skip_special_tokens=True
        )

        return sentence_decoded[0]
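A minimal usage sketch of t5Model; the sample sentence is illustrative and assumes the unicamp-dl/translation-en-pt-t5 checkpoint is available:

# Illustrative sketch: the task prefix is prepended inside translate_text.
t5 = t5Model()
print(t5.translate_text("The report was published yesterday."))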