Skip to content

Commit fc33ae2

Browse files
committed
loading configurational files using hydra
1 parent cdeb12c commit fc33ae2

File tree

1 file changed

+11
-1
lines changed

1 file changed

+11
-1
lines changed

marl-task/src/main.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
from torchrl.objectives import ClipPPOLoss, ValueEstimators
1313
from matplotlib import pyplot as plt
1414
from tqdm import tqdm
15+
import hydra
16+
from omegaconf import DictConfig, OmegaConf
1517

1618
from envs import PSOEnv
1719
from utils import LandscapeWrapper, PSOActionExtractor, PSOObservationWrapper
@@ -99,7 +101,15 @@ def train(collector, replay_buffer, loss_module, optim, num_epochs, max_grad_nor
99101
pbar.close()
100102
return losses, rewards
101103

102-
def main():
104+
@hydra.main(version_base=None, config_path="conf", config_name="config")
105+
def main(cfg: DictConfig):
106+
print(OmegaConf.to_yaml(cfg))
107+
108+
# Access config
109+
dim = cfg.env.landscape_dim
110+
hidden = cfg.model.hidden_sizes
111+
print(f"Dim: {dim}, Hidden: {hidden}")
112+
103113
landscape_dim = 2
104114
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
105115

0 commit comments

Comments
 (0)