-
Notifications
You must be signed in to change notification settings - Fork 38
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Sikan Li
committed
Aug 6, 2024
1 parent
325fc8d
commit 06dfb25
Showing
4 changed files
with
465 additions
and
164 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
defaults: | ||
- _self_ | ||
- override hydra/hydra_logging: disabled | ||
- override hydra/job_logging: disabled | ||
|
||
hydra: | ||
output_subdir: null | ||
run: | ||
dir: . | ||
|
||
# Top-level configuration | ||
mode: train | ||
|
||
# Data configuration | ||
data: | ||
path: ../gns-sample/Cylinder/dataset/ | ||
batch_size: 2 | ||
noise_std: 2e-3 | ||
input_sequence_length: 1 | ||
node_type_embedding_size: 9 | ||
dt: 0.01 | ||
|
||
# Model configuration | ||
model: | ||
path: ../gns-sample/Cylinder/models/ | ||
file: model-100.pt | ||
train_state_file: train_state-100.pt | ||
|
||
# Output configuration | ||
output: | ||
path: ../gns-sample/Cylinder/rollouts/ | ||
filename: rollout | ||
|
||
# Training configuration | ||
training: | ||
steps: 100 | ||
validation_interval: null | ||
save_steps: 500 | ||
resume: False | ||
learning_rate: | ||
initial: 1e-4 | ||
decay: 0.1 | ||
decay_steps: 5000000 | ||
|
||
# Hardware configuration | ||
hardware: | ||
cuda_device_number: null | ||
|
||
# Logging configuration | ||
logging: | ||
tensorboard_dir: logs/ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
from dataclasses import dataclass, field | ||
from typing import Optional | ||
from omegaconf import MISSING | ||
from hydra.core.config_store import ConfigStore | ||
|
||
|
||
@dataclass | ||
class DataConfig: | ||
path: str = MISSING | ||
batch_size: int = 2 | ||
noise_std: float = 2e-3 | ||
input_sequence_length: int = 1 | ||
node_type_embedding_size: int = 9 | ||
dt: float = 0.01 | ||
|
||
|
||
@dataclass | ||
class ModelConfig: | ||
path: str = "../gns-sample/Cylinder/models/" | ||
file: Optional[str] = None | ||
train_state_file: Optional[str] = None | ||
|
||
|
||
@dataclass | ||
class OutputConfig: | ||
path: str = "../gns-sample/Cylinder/rollouts/" | ||
filename: str = "rollout" | ||
|
||
|
||
@dataclass | ||
class LearningRateConfig: | ||
initial: float = 1e-4 | ||
decay: float = 0.1 | ||
decay_steps: int = 5000000 | ||
|
||
|
||
@dataclass | ||
class TrainingConfig: | ||
steps: int = 100 | ||
validation_interval: Optional[int] = None | ||
save_steps: int = 500 | ||
resume: Optional[bool] = False | ||
learning_rate: LearningRateConfig = field(default_factory=LearningRateConfig) | ||
|
||
|
||
@dataclass | ||
class HardwareConfig: | ||
cuda_device_number: Optional[int] = None | ||
|
||
|
||
@dataclass | ||
class LoggingConfig: | ||
tensorboard_dir: str = "logs/" | ||
|
||
|
||
@dataclass | ||
class Config: | ||
mode: str = "train" | ||
data: DataConfig = field(default_factory=DataConfig) | ||
model: ModelConfig = field(default_factory=ModelConfig) | ||
output: OutputConfig = field(default_factory=OutputConfig) | ||
training: TrainingConfig = field(default_factory=TrainingConfig) | ||
hardware: HardwareConfig = field(default_factory=HardwareConfig) | ||
logging: LoggingConfig = field(default_factory=LoggingConfig) | ||
|
||
|
||
# Hydra configuration | ||
cs = ConfigStore.instance() | ||
cs.store(name="base_config", node=Config) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
import os | ||
import json | ||
|
||
|
||
def read_metadata(data_path: str, purpose: str, file_name: str = "metadata.json"): | ||
"""Read metadata of datasets | ||
Args: | ||
data_path (str): Path to metadata JSON file | ||
purpose (str): Optional str whether "train" or "rollout" | ||
file_name (str): Name of metadata file | ||
Returns: | ||
metadata json object | ||
""" | ||
try: | ||
with open(os.path.join(data_path, file_name), "rt") as fp: | ||
# New version use separate metadata for `train` and `rollout`. | ||
metadata = json.loads(fp.read())[purpose] | ||
|
||
except: | ||
with open(os.path.join(data_path, file_name), "rt") as fp: | ||
# The previous format of the metadata does not distinguish the purpose of metadata | ||
metadata = json.loads(fp.read()) | ||
|
||
return metadata | ||
|
||
|
||
def flags_to_dict(FLAGS): | ||
flags_dict = {} | ||
for name in FLAGS: | ||
flag_value = FLAGS[name].value | ||
flags_dict[name] = flag_value | ||
return flags_dict |
Oops, something went wrong.