Skip to content

Commit

Permalink
load_dict() read with utf-8 encoding specify
Browse files Browse the repository at this point in the history
  • Loading branch information
bact committed Sep 8, 2019
1 parent 264c044 commit c866982
Showing 1 changed file with 13 additions and 12 deletions.
25 changes: 13 additions & 12 deletions attacut/utils.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
import json
import os
import numpy as np
import yaml
import time
import json
from typing import Callable, Dict, NamedTuple, Union

from typing import Callable, NamedTuple, Dict, Union
import yaml

from attacut import logger


log = logger.get_logger(__name__)


class ModelParams(NamedTuple):
name: str
params: str
Expand All @@ -25,7 +26,7 @@ def __enter__(self):
def __exit__(self, type, value, traceback):
self.stop = time.time()
diff = self.stop - self.start
log.info("Finished block: %s with %d seconds" % (self.name, diff))
log.info("Finished block: %s with %d seconds" % (self.name, diff))


def maybe(cond: bool, func: Callable[[], None], desc: str, verbose=0):
Expand All @@ -38,9 +39,7 @@ def maybe(cond: bool, func: Callable[[], None], desc: str, verbose=0):

def maybe_create_dir(path: str):
return maybe(
not os.path.exists(path),
lambda : os.mkdir(path),
"create dir %s" % path
not os.path.exists(path), lambda: os.mkdir(path), "create dir %s" % path
)


Expand All @@ -57,20 +56,22 @@ def wc_l(path: str) -> int:
# count total lines in a file
s = 0
with open(path, "r") as f:
for l in f:
for _ in f:
s += 1
return s


def save_training_params(dir_path: str, params: ModelParams):

dir_path = "%s/params.yml" % dir_path
print("Saving training params to %s" % dir_path)

params = dict(params._asdict())

with open(dir_path, 'w') as outfile:
with open(dir_path, "w") as outfile:
yaml.dump(params, outfile, default_flow_style=False)


def load_training_params(path: str) -> ModelParams:
with open("%s/params.yml" % path, "r") as f:
params = yaml.load(f, Loader=yaml.BaseLoader)
Expand All @@ -90,7 +91,7 @@ def parse_model_params(ss: str) -> Dict[str, Union[int, float]]:


def load_dict(data_path: str) -> Dict:
with open(data_path, "r") as f:
with open(data_path, "r", encoding="utf-8") as f:
dd = json.load(f)

log.info("loaded %d items from dict:%s" % (len(dd), data_path))
Expand All @@ -107,4 +108,4 @@ def create_start_stop_indices(seq_lengths):
st_indices.append(prev)
sp_indices.append(prev + s)

return list(zip(st_indices, sp_indices))
return list(zip(st_indices, sp_indices))

0 comments on commit c866982

Please sign in to comment.