-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathoptions.py
31 lines (20 loc) · 898 Bytes
/
options.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
from typing import Dict, List
import numpy as np
# Fallback used by parse_float_option when the config has no "value" key.
DEFAULT_FLOAT_VALUE: float = 0.0
# Fallback used by parse_int_option when the config has no "value" key.
DEFAULT_INT_VALUE: int = 0
def parse_list_option(option_config: Dict) -> np.ndarray:
    """Convert the ``"value"`` entry of *option_config* into a NumPy array.

    A config without a ``"value"`` key produces an empty array.
    """
    return np.array(option_config.get("value", []))
def parse_repeated_list_option(option_config: Dict) -> np.ndarray:
    """Return the ``"value"`` sequence repeated ``"times"`` times, as an array.

    ``"value"`` defaults to an empty list and ``"times"`` to 1. Repetition is
    sequence concatenation along the first axis (``list * n`` semantics), e.g.
    ``{"value": [1, 2], "times": 2}`` -> ``array([1, 2, 1, 2])``.

    Raises:
        TypeError: if ``"value"`` is a 0-d NumPy array (nothing to repeat).
    """
    option_value = option_config.get("value", [])
    times = option_config.get("times", 1)
    # ``*`` on a NumPy array multiplies element-wise instead of repeating the
    # sequence, so turn arrays into a list of their first-axis items first.
    if isinstance(option_value, np.ndarray):
        option_value = list(option_value)
    return np.array(option_value * times)
def parse_random_list_option(option_config: Dict) -> np.ndarray:
    """Return an array of standard-normal random samples.

    Keys read from *option_config*:
      - ``"size"`` (default 1): an int length, or a tuple/list giving a shape.
      - ``"seed"`` (optional): when present and not ``None``, samples come from
        a fresh ``np.random.default_rng(seed)``, so the result is reproducible
        for that config. When absent, the legacy global NumPy RNG is used. That
        is the same stream as ``np.random.randn(size)``, so ``np.random.seed``
        still controls it.
    """
    size = option_config.get("size", 1)
    seed = option_config.get("seed")
    # NOTE: the original author marked this as a placeholder; only the
    # standard-normal distribution is supported here.
    if seed is None:
        # randn(n) delegates to standard_normal((n,)), so draws are unchanged.
        return np.random.standard_normal(size)
    return np.random.default_rng(seed).standard_normal(size)
def parse_float_option(option_config: Dict) -> float:
    """Return the ``"value"`` entry of *option_config* as a ``float``.

    Falls back to ``DEFAULT_FLOAT_VALUE`` when ``"value"`` is absent. The value
    is passed through ``float()`` so ints and numeric strings come back as real
    floats, matching the declared return type.

    Raises:
        ValueError: if the value is a string that is not a valid float.
        TypeError: if the value is of a type ``float()`` cannot convert.
    """
    raw = (
        option_config["value"] if "value" in option_config else DEFAULT_FLOAT_VALUE
    )
    return float(raw)
def parse_int_option(option_config: Dict) -> int:
    """Return the ``"value"`` entry of *option_config* as an ``int``.

    Falls back to ``DEFAULT_INT_VALUE`` when ``"value"`` is absent. The value
    is passed through ``int()`` so integral floats and numeric strings come
    back as real ints, matching the declared return type.

    Raises:
        ValueError: if the value is a non-integral float (e.g. ``2.5``) or a
            string that is not a valid integer.
        TypeError: if the value is of a type ``int()`` cannot convert.
    """
    raw = option_config["value"] if "value" in option_config else DEFAULT_INT_VALUE
    # int() would silently truncate 2.5 -> 2; reject lossy conversions instead.
    if isinstance(raw, float) and not raw.is_integer():
        raise ValueError(f"expected an integer option value, got {raw!r}")
    return int(raw)