from datetime import datetime, timedelta
from typing import Union, Iterable
from os.path import join
import os
import shutil
from machin.utils.logging import default_logger
from machin.utils.prepare import \
prep_clear_dirs, prep_create_dirs
[docs]class SaveEnv:
def __init__(self,
env_root: str,
restart_from_trial: Union[str, None] = None,
time_format="%Y_%m_%d_%H_%M_%S"):
"""
Create the default environment for saving. creates something like::
<your environment root>
├── config
├── log
│ ├── images
│ └── train_log
└── model
Args:
env_root: root directory for all trials of the environment.
restart_from_trial: instead of creating a new save environment
for a new trial, use a existing save environment of an older
trial, old trial name should be in format ``time_format``
"""
self.env_root = env_root
self.time_format = time_format
if restart_from_trial is None:
self.env_create_time = datetime.now()
else:
self.env_create_time = datetime.strptime(restart_from_trial,
self.time_format)
self._check_dirs()
self._prep_dirs()
[docs] def create_dirs(self, dirs: Iterable[str]):
"""
Create additional directories in root.
Args:
dirs: Directories.
"""
prep_create_dirs([join(
self.env_root,
self.env_create_time.strftime(self.time_format),
d
) for d in dirs])
[docs] def get_trial_root(self):
# pylint: disable=missing-docstring
return join(self.env_root,
self.env_create_time.strftime(self.time_format))
[docs] def get_trial_config_dir(self):
# pylint: disable=missing-docstring
return join(self.env_root,
self.env_create_time.strftime(self.time_format),
"config")
[docs] def get_trial_model_dir(self):
# pylint: disable=missing-docstring
return join(self.env_root,
self.env_create_time.strftime(self.time_format),
"model")
[docs] def get_trial_image_dir(self):
# pylint: disable=missing-docstring
return join(self.env_root,
self.env_create_time.strftime(self.time_format),
"log", "images")
[docs] def get_trial_train_log_dir(self):
# pylint: disable=missing-docstring
return join(self.env_root,
self.env_create_time.strftime(self.time_format),
"log", "train_log")
[docs] def get_trial_time(self):
# pylint: disable=missing-docstring
return self.env_create_time
[docs] def clear_trial_config_dir(self):
# pylint: disable=missing-docstring
prep_clear_dirs([
join(self.env_root,
self.env_create_time.strftime(self.time_format),
"config")
])
[docs] def clear_trial_model_dir(self):
# pylint: disable=missing-docstring
prep_clear_dirs([
join(self.env_root,
self.env_create_time.strftime(self.time_format),
"model")
])
[docs] def clear_trial_image_dir(self):
# pylint: disable=missing-docstring
prep_clear_dirs([
join(self.env_root,
self.env_create_time.strftime(self.time_format),
"log", "images")
])
[docs] def clear_trial_train_log_dir(self):
# pylint: disable=missing-docstring
prep_clear_dirs([
join(self.env_root,
self.env_create_time.strftime(self.time_format),
"log", "train_log")
])
[docs] def remove_trials_older_than(self,
diff_day: int = 0,
diff_hour: int = 1,
diff_minute: int = 0,
diff_second: int = 0):
"""
By default this function removes all trials started one hour earlier
than current time.
Args:
diff_day: Difference in days.
diff_hour: Difference in hours.
diff_minute: Difference in minutes.
diff_second: Difference in seconds.
"""
trial_list = [f for f in os.listdir(self.env_root)]
current_time = datetime.now()
diff_threshold = timedelta(days=diff_day, hours=diff_hour,
minutes=diff_minute, seconds=diff_second)
for file in trial_list:
try:
time = datetime.strptime(file, self.time_format)
except ValueError:
# not a trial
pass
else:
diff_time = current_time - time
if diff_time > diff_threshold:
rm_path = join(self.env_root, file)
default_logger.info("Removing trial directory: {}"
.format(rm_path))
shutil.rmtree(rm_path)
def _prep_dirs(self):
root_dir = join(self.env_root,
self.env_create_time.strftime(self.time_format))
prep_create_dirs((join(root_dir, "model"),
join(root_dir, "config"),
join(root_dir, "log", "images"),
join(root_dir, "log", "train_log")))
def _check_dirs(self):
"""
Overload this function in your environment class to check directory
mapping
Raises:
RuntimeError if directory mapping is invalid.
"""
pass