diff --git a/setup.py b/setup.py index 4858897..5871273 100644 --- a/setup.py +++ b/setup.py @@ -9,11 +9,13 @@ # This source code is licensed under the Apache-2.0 license found in the # LICENSE file in the root directory of this source tree. + import os import sys import shutil import setuptools + # ------------------Package Meta-Data------------------ PACKAGE_INFO = {} diff --git a/ynmt/binaries/default/preprocess.py b/ynmt/binaries/default/preprocess.py index 6094b12..65b1d6d 100644 --- a/ynmt/binaries/default/preprocess.py +++ b/ynmt/binaries/default/preprocess.py @@ -10,15 +10,12 @@ # LICENSE file in the root directory of this source tree. -import os -import collections - import ynmt.hocon.arguments as harg from ynmt.tasks import build_task from ynmt.utilities.random import fix_random_procedure -from ynmt.utilities.logging import setup_logger, get_logger, logging_level +from ynmt.utilities.logging import setup_logger, logging_level def preprocess(args): diff --git a/ynmt/binaries/default/test.py b/ynmt/binaries/default/test.py index 4fc7a58..3355190 100644 --- a/ynmt/binaries/default/test.py +++ b/ynmt/binaries/default/test.py @@ -9,10 +9,9 @@ # This source code is licensed under the Apache-2.0 license found in the # LICENSE file in the root directory of this source tree. + import os import torch -import importlib -import pickle import ynmt.hocon.arguments as harg diff --git a/ynmt/binaries/default/train.py b/ynmt/binaries/default/train.py index 35b2b0e..37260f8 100644 --- a/ynmt/binaries/default/train.py +++ b/ynmt/binaries/default/train.py @@ -10,10 +10,7 @@ # LICENSE file in the root directory of this source tree. -import os import torch -import importlib -import pickle import ynmt.hocon.arguments as harg diff --git a/ynmt/criterions/__init__.py b/ynmt/criterions/__init__.py index 69b0ea8..fe30dd3 100644 --- a/ynmt/criterions/__init__.py +++ b/ynmt/criterions/__init__.py @@ -16,6 +16,7 @@ from ynmt.utilities.registration import Registration, import_modules + criterion_registration = Registration(Criterion) diff --git a/ynmt/criterions/criterion.py b/ynmt/criterions/criterion.py index 34869cc..9e4b994 100644 --- a/ynmt/criterions/criterion.py +++ b/ynmt/criterions/criterion.py @@ -12,7 +12,6 @@ import torch - from ynmt.utilities.statistics import Statistics diff --git a/ynmt/criterions/cross_entropy.py b/ynmt/criterions/cross_entropy.py index fc58d68..ee03e9e 100644 --- a/ynmt/criterions/cross_entropy.py +++ b/ynmt/criterions/cross_entropy.py @@ -12,7 +12,6 @@ import torch - from ynmt.criterions import register_criterion, Criterion diff --git a/ynmt/criterions/label_smoothing_cross_entropy.py b/ynmt/criterions/label_smoothing_cross_entropy.py index 953ff40..75ae80e 100644 --- a/ynmt/criterions/label_smoothing_cross_entropy.py +++ b/ynmt/criterions/label_smoothing_cross_entropy.py @@ -12,7 +12,6 @@ import torch - from ynmt.criterions import register_criterion, Criterion diff --git a/ynmt/data/batch.py b/ynmt/data/batch.py index d0e7562..2561654 100644 --- a/ynmt/data/batch.py +++ b/ynmt/data/batch.py @@ -10,9 +10,6 @@ # LICENSE file in the root directory of this source tree. -import torch - - class Batch(object): def __init__(self, structure, instances=list()): assert isinstance(structure, set), 'Type of structure should be set().' diff --git a/ynmt/data/iterator.py b/ynmt/data/iterator.py index cb67b5c..46fb226 100644 --- a/ynmt/data/iterator.py +++ b/ynmt/data/iterator.py @@ -13,10 +13,9 @@ import random from ynmt.data.batch import Batch -from ynmt.data.instance import Instance, InstanceComparator +from ynmt.data.instance import InstanceComparator from ynmt.utilities.file import load_datas from ynmt.utilities.random import shuffled -from ynmt.utilities.statistics import Statistics class Iterator(object): diff --git a/ynmt/hocon/arguments.py b/ynmt/hocon/arguments.py index 90d855f..c7cbd0a 100644 --- a/ynmt/hocon/arguments.py +++ b/ynmt/hocon/arguments.py @@ -15,7 +15,6 @@ import argparse import collections - from ynmt.utilities.constant import Constant diff --git a/ynmt/models/__init__.py b/ynmt/models/__init__.py index cc9e45e..e02ef9c 100644 --- a/ynmt/models/__init__.py +++ b/ynmt/models/__init__.py @@ -16,6 +16,7 @@ from ynmt.utilities.registration import Registration, import_modules + model_registration = Registration(Model) diff --git a/ynmt/modules/decoders/transformer_decoder.py b/ynmt/modules/decoders/transformer_decoder.py index 83816ba..18e979c 100644 --- a/ynmt/modules/decoders/transformer_decoder.py +++ b/ynmt/modules/decoders/transformer_decoder.py @@ -13,7 +13,6 @@ import math import torch - from ynmt.modules.embeddings import TrigonometricPositionalEmbedding from ynmt.modules.attentions import MultiHeadAttention from ynmt.modules.perceptrons import PositionWiseFeedForward diff --git a/ynmt/modules/embeddings/learned_positional_embedding.py b/ynmt/modules/embeddings/learned_positional_embedding.py index 7f3273b..73c347d 100644 --- a/ynmt/modules/embeddings/learned_positional_embedding.py +++ b/ynmt/modules/embeddings/learned_positional_embedding.py @@ -10,7 +10,6 @@ # LICENSE file in the root directory of this source tree. -import math import torch diff --git a/ynmt/modules/encoders/transformer_encoder.py b/ynmt/modules/encoders/transformer_encoder.py index cb6f37d..6d6fd1c 100644 --- a/ynmt/modules/encoders/transformer_encoder.py +++ b/ynmt/modules/encoders/transformer_encoder.py @@ -13,7 +13,6 @@ import math import torch - from ynmt.modules.embeddings import TrigonometricPositionalEmbedding from ynmt.modules.attentions import MultiHeadAttention from ynmt.modules.perceptrons import PositionWiseFeedForward diff --git a/ynmt/schedulers/__init__.py b/ynmt/schedulers/__init__.py index 560af59..d41c547 100644 --- a/ynmt/schedulers/__init__.py +++ b/ynmt/schedulers/__init__.py @@ -16,6 +16,7 @@ from ynmt.utilities.registration import Registration, import_modules + scheduler_registration = Registration(Scheduler) diff --git a/ynmt/tasks/mixins/seq.py b/ynmt/tasks/mixins/seq.py index fdb8d1e..1d03073 100644 --- a/ynmt/tasks/mixins/seq.py +++ b/ynmt/tasks/mixins/seq.py @@ -12,7 +12,6 @@ import collections - from ynmt.utilities.file import load_plain from ynmt.utilities.multiprocessing import multi_process diff --git a/ynmt/tasks/task.py b/ynmt/tasks/task.py index 102e042..cace8da 100644 --- a/ynmt/tasks/task.py +++ b/ynmt/tasks/task.py @@ -12,7 +12,7 @@ from ynmt.data.dataset import Dataset -from ynmt.utilities.file import mk_temp, rm_temp, load_data, load_datas, dump_data, dump_datas +from ynmt.utilities.file import mk_temp, rm_temp, load_data, dump_data, dump_datas from ynmt.utilities.multiprocessing import multi_process diff --git a/ynmt/testers/seq2seq.py b/ynmt/testers/seq2seq.py index 20790cf..d88c8ad 100644 --- a/ynmt/testers/seq2seq.py +++ b/ynmt/testers/seq2seq.py @@ -10,8 +10,6 @@ # LICENSE file in the root directory of this source tree. -import os -import re import torch from ynmt.testers import register_tester, Tester diff --git a/ynmt/testers/tester.py b/ynmt/testers/tester.py index 63800ae..845e2de 100644 --- a/ynmt/testers/tester.py +++ b/ynmt/testers/tester.py @@ -10,7 +10,6 @@ # LICENSE file in the root directory of this source tree. -import re import torch from ynmt.utilities.timer import Timer diff --git a/ynmt/trainers/trainer.py b/ynmt/trainers/trainer.py index ab59ed7..f81b451 100644 --- a/ynmt/trainers/trainer.py +++ b/ynmt/trainers/trainer.py @@ -10,12 +10,10 @@ # LICENSE file in the root directory of this source tree. -import re import torch - from ynmt.utilities.timer import Timer -from ynmt.utilities.statistics import Statistics, perplexity +from ynmt.utilities.statistics import Statistics from ynmt.utilities.checkpoint import save_checkpoint from ynmt.utilities.distributed import reduce_all, gather_all @@ -133,7 +131,9 @@ def launch(self, accumulated_train_batches, accumulated_valid_batches): if self.step % self.training_period == 0: self.save() - self.save() + if self.step % self.training_period != 0: + self.save() + return def train(self, accumulated_train_batch): diff --git a/ynmt/utilities/distributed.py b/ynmt/utilities/distributed.py index 53f3c9e..aa7c714 100644 --- a/ynmt/utilities/distributed.py +++ b/ynmt/utilities/distributed.py @@ -13,7 +13,6 @@ import torch import threading - from ynmt.utilities.file import dumps, loads diff --git a/ynmt/utilities/logging.py b/ynmt/utilities/logging.py index 317aed8..bf9a423 100644 --- a/ynmt/utilities/logging.py +++ b/ynmt/utilities/logging.py @@ -12,7 +12,6 @@ import logging - from ynmt.utilities.file import mk_temp diff --git a/ynmt/utilities/registration.py b/ynmt/utilities/registration.py index 427bd26..ff1c130 100644 --- a/ynmt/utilities/registration.py +++ b/ynmt/utilities/registration.py @@ -14,9 +14,6 @@ import importlib -from ynmt.utilities.constant import Constant - - def import_modules(father, directory): file_names = os.listdir(directory) for file_name in file_names: diff --git a/ynmt/utilities/timer.py b/ynmt/utilities/timer.py index 966b7e0..13bc67c 100644 --- a/ynmt/utilities/timer.py +++ b/ynmt/utilities/timer.py @@ -12,7 +12,6 @@ import time - from ynmt.utilities.constant import Constant diff --git a/ynmt/utilities/tracker.py b/ynmt/utilities/tracker.py index f1ab5a0..35911dd 100644 --- a/ynmt/utilities/tracker.py +++ b/ynmt/utilities/tracker.py @@ -17,7 +17,6 @@ import inspect import datetime - from ynmt.utilities.file import get_temp_file_path diff --git a/ynmt/utilities/visualizing.py b/ynmt/utilities/visualizing.py index 52e78f4..fc01471 100644 --- a/ynmt/utilities/visualizing.py +++ b/ynmt/utilities/visualizing.py @@ -14,7 +14,6 @@ import json import visdom - from ynmt.utilities.file import mk_temp from ynmt.utilities.constant import Constant diff --git a/ynmt_cli/default/preprocess.py b/ynmt_cli/default/preprocess.py index 10f5419..d0a6f9d 100644 --- a/ynmt_cli/default/preprocess.py +++ b/ynmt_cli/default/preprocess.py @@ -9,6 +9,7 @@ # This source code is licensed under the Apache-2.0 license found in the # LICENSE file in the root directory of this source tree. + from ynmt.binaries.default.preprocess import main diff --git a/ynmt_cli/default/test.py b/ynmt_cli/default/test.py index 01c8689..17ca6f7 100644 --- a/ynmt_cli/default/test.py +++ b/ynmt_cli/default/test.py @@ -9,6 +9,7 @@ # This source code is licensed under the Apache-2.0 license found in the # LICENSE file in the root directory of this source tree. + from ynmt.binaries.default.test import main diff --git a/ynmt_cli/default/train.py b/ynmt_cli/default/train.py index 680bc4c..b85d046 100644 --- a/ynmt_cli/default/train.py +++ b/ynmt_cli/default/train.py @@ -9,6 +9,7 @@ # This source code is licensed under the Apache-2.0 license found in the # LICENSE file in the root directory of this source tree. + from ynmt.binaries.default.train import main