From c8669828c94ac7059dd56fb75e216755667e139b Mon Sep 17 00:00:00 2001 From: Arthit Suriyawongkul Date: Sun, 8 Sep 2019 12:23:06 +0700 Subject: [PATCH] load_dict() read with utf-8 encoding specify --- attacut/utils.py | 25 +++++++++++++------------ 1 file changed, 13 insertions(+), 12 deletions(-) diff --git a/attacut/utils.py b/attacut/utils.py index 03a2bc7..dc39a1e 100644 --- a/attacut/utils.py +++ b/attacut/utils.py @@ -1,15 +1,16 @@ +import json import os -import numpy as np -import yaml import time -import json +from typing import Callable, Dict, NamedTuple, Union -from typing import Callable, NamedTuple, Dict, Union +import yaml from attacut import logger + log = logger.get_logger(__name__) + class ModelParams(NamedTuple): name: str params: str @@ -25,7 +26,7 @@ def __enter__(self): def __exit__(self, type, value, traceback): self.stop = time.time() diff = self.stop - self.start - log.info("Finished block: %s with %d seconds" % (self.name, diff)) + log.info("Finished block: %s with %d seconds" % (self.name, diff)) def maybe(cond: bool, func: Callable[[], None], desc: str, verbose=0): @@ -38,9 +39,7 @@ def maybe(cond: bool, func: Callable[[], None], desc: str, verbose=0): def maybe_create_dir(path: str): return maybe( - not os.path.exists(path), - lambda : os.mkdir(path), - "create dir %s" % path + not os.path.exists(path), lambda: os.mkdir(path), "create dir %s" % path ) @@ -57,10 +56,11 @@ def wc_l(path: str) -> int: # count total lines in a file s = 0 with open(path, "r") as f: - for l in f: + for _ in f: s += 1 return s + def save_training_params(dir_path: str, params: ModelParams): dir_path = "%s/params.yml" % dir_path @@ -68,9 +68,10 @@ def save_training_params(dir_path: str, params: ModelParams): params = dict(params._asdict()) - with open(dir_path, 'w') as outfile: + with open(dir_path, "w") as outfile: yaml.dump(params, outfile, default_flow_style=False) + def load_training_params(path: str) -> ModelParams: with open("%s/params.yml" % path, "r") as f: params = yaml.load(f, Loader=yaml.BaseLoader) @@ -90,7 +91,7 @@ def parse_model_params(ss: str) -> Dict[str, Union[int, float]]: def load_dict(data_path: str) -> Dict: - with open(data_path, "r") as f: + with open(data_path, "r", encoding="utf-8") as f: dd = json.load(f) log.info("loaded %d items from dict:%s" % (len(dd), data_path)) @@ -107,4 +108,4 @@ def create_start_stop_indices(seq_lengths): st_indices.append(prev) sp_indices.append(prev + s) - return list(zip(st_indices, sp_indices)) \ No newline at end of file + return list(zip(st_indices, sp_indices))