Skip to content

Commit

Permalink
Merge pull request #1 from PyThaiNLP/master
Browse files Browse the repository at this point in the history
load_dict(): specify UTF-8 encoding when reading the file
  • Loading branch information
bact authored Sep 8, 2019
2 parents 264c044 + c866982 commit 5cc6a08
Showing 1 changed file with 13 additions and 12 deletions.
25 changes: 13 additions & 12 deletions attacut/utils.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
import json
import os
import numpy as np
import yaml
import time
import json
from typing import Callable, Dict, NamedTuple, Union

from typing import Callable, NamedTuple, Dict, Union
import yaml

from attacut import logger


log = logger.get_logger(__name__)


class ModelParams(NamedTuple):
name: str
params: str
Expand All @@ -25,7 +26,7 @@ def __enter__(self):
def __exit__(self, type, value, traceback):
    """Record the stop time and log how long the timed block took.

    The exception arguments (``type``, ``value``, ``traceback``) are not
    inspected. Nothing is returned, so the implicit ``None`` means an
    exception raised inside the ``with`` body is not suppressed.

    Reads ``self.start`` and ``self.name``; both are expected to have been
    set before this runs (presumably by ``__init__`` / ``__enter__`` --
    confirm against the rest of the class).
    """
    self.stop = time.time()
    diff = self.stop - self.start
    # Logged exactly once per block; ``%d`` truncates the elapsed time to
    # whole seconds.
    log.info("Finished block: %s with %d seconds" % (self.name, diff))


def maybe(cond: bool, func: Callable[[], None], desc: str, verbose=0):
Expand All @@ -38,9 +39,7 @@ def maybe(cond: bool, func: Callable[[], None], desc: str, verbose=0):

def maybe_create_dir(path: str):
    """Ask ``maybe`` to create the directory ``path`` if it is missing.

    Args:
        path: Directory to create. Only the last path component is created
            (``os.mkdir``), so the parent directory must already exist.

    Returns:
        Whatever ``maybe`` returns.
    """
    # The existence check is evaluated here, at call time; the mkdir itself
    # is wrapped in a lambda so it only runs if ``maybe`` invokes it.
    return maybe(
        not os.path.exists(path),
        lambda: os.mkdir(path),
        "create dir %s" % path,
    )


Expand All @@ -57,20 +56,22 @@ def wc_l(path: str) -> int:
# count total lines in a file
s = 0
with open(path, "r") as f:
for l in f:
for _ in f:
s += 1
return s


def save_training_params(dir_path: str, params: ModelParams):
    """Write ``params`` to ``<dir_path>/params.yml`` as block-style YAML.

    Args:
        dir_path: Directory that receives ``params.yml``. It is not created
            here, so it must already exist.
        params: Named tuple of training parameters; each field becomes one
            top-level key in the YAML document.
    """
    # Keep the caller's directory and the derived file path in separate
    # names instead of rebinding ``dir_path``.
    params_path = "%s/params.yml" % dir_path
    print("Saving training params to %s" % params_path)

    # Convert to a plain dict before dumping: ``_asdict()`` can return an
    # OrderedDict on older Python versions, which would not serialise as a
    # plain YAML mapping.
    params_dict = dict(params._asdict())

    # Explicit encoding keeps the written file independent of the platform's
    # default locale, matching how ``load_dict`` reads its input.
    with open(params_path, "w", encoding="utf-8") as outfile:
        yaml.dump(params_dict, outfile, default_flow_style=False)


def load_training_params(path: str) -> ModelParams:
with open("%s/params.yml" % path, "r") as f:
params = yaml.load(f, Loader=yaml.BaseLoader)
Expand All @@ -90,7 +91,7 @@ def parse_model_params(ss: str) -> Dict[str, Union[int, float]]:


def load_dict(data_path: str) -> Dict:
with open(data_path, "r") as f:
with open(data_path, "r", encoding="utf-8") as f:
dd = json.load(f)

log.info("loaded %d items from dict:%s" % (len(dd), data_path))
Expand All @@ -107,4 +108,4 @@ def create_start_stop_indices(seq_lengths):
st_indices.append(prev)
sp_indices.append(prev + s)

return list(zip(st_indices, sp_indices))
return list(zip(st_indices, sp_indices))

0 comments on commit 5cc6a08

Please sign in to comment.