import os
import shutil
import time
import uuid
import warnings
from _thread import start_new_thread
from ast import literal_eval
from contextlib import ContextDecorator
from ctypes import c_long, py_object, pythonapi
from multiprocessing import TimeoutError as MpTimeoutError
from queue import Empty as Queue_Empty
from queue import Queue
from typing import Optional, Sequence
from absl import logging
from rich.table import Table
warnings.simplefilter(action="ignore")
[docs]
def create_dir(dir, overwrite=False):
    """Ensure a directory exists, optionally clearing its current contents.

    Args:
        dir (str): Target directory path.
        overwrite (bool, optional): When True and the directory already
            exists, delete everything inside it (files and subdirectories).
            Defaults to False.
    """
    if not os.path.exists(dir):
        os.makedirs(dir)
        return
    if not overwrite:
        # Directory already present and we were not asked to clear it.
        return
    for entry in os.listdir(dir):
        full_path = os.path.join(dir, entry)
        if os.path.isdir(full_path):
            shutil.rmtree(full_path)
        else:
            os.remove(full_path)
[docs]
def to_eng_str(num, long_than=4, precision=3):
    """Convert a number to a compact string, using scientific notation when needed.

    Numbers whose magnitude lies within ``(10**-long_than, 10**long_than)`` are
    printed plainly (truncated to ``long_than`` characters); anything outside
    that window is formatted in scientific notation.

    Args:
        num (float): The number to convert. Non-numeric values are returned
            as ``str(num)`` unchanged.
        long_than (int, optional): The maximum length of the plain number
            string / exponent threshold. Defaults to 4.
        precision (int, optional): The number of decimal places in scientific
            notation. Defaults to 3.

    Returns:
        str: The compact string representation.
    """
    if not isinstance(num, (int, float)):
        return str(num)
    # Fix: zero lies outside every (10**-n, 10**n) window, so it previously
    # rendered as the verbose "0.000e+00"; return it plainly instead.
    if num == 0:
        return str(num)
    if 10**-long_than < abs(num) < 10**long_than:
        num = str(num)
        if len(num) > long_than:
            # Hard truncation, not rounding — keeps at most long_than chars.
            return num[:long_than]
        return num
    return f"{num:.{precision}e}"
[docs]
def df2richtable(
    df,
    title=None,
    max_row_len=None,
    highlight_columns: Optional[list] = None,
    max_col_width=None,
    col_center=None,
    max_seq_value_len=None,
    list_centers: Optional[dict] = None,
    highlight_list_centers=False,
    alternating_row_colors=False,
    use_eng_str=False,
):
    """Render a pandas DataFrame as a ``rich.table.Table``.

    Args:
        df: DataFrame to render (its index is reset into regular columns).
        title (str, optional): Table title.
        max_row_len (int, optional): Show at most this many head rows; when
            truncated, a "..." separator and the original final row are
            appended at the bottom.
        highlight_columns (list, optional): Column names rendered with a red
            background.
        max_col_width (int, optional): Maximum number of columns shown,
            windowed around ``col_center``; truncated sides get "..." markers.
        col_center (optional): Column name the column window is centered on.
            Defaults to the first column.
        max_seq_value_len (int, optional): Maximum number of elements shown
            for list/tuple cell values; longer sequences are windowed around
            their entry in ``list_centers``.
        list_centers (dict, optional): Column name -> element index to center
            sequence truncation on.
        highlight_list_centers (bool, optional): Highlight the center element
            of truncated sequences.
        alternating_row_colors (bool, optional): Shade every other row.
        use_eng_str (bool, optional): Format numeric values via ``to_eng_str``.

    Returns:
        rich.table.Table: The populated table.
    """

    def _trnc(l_len, width=None, center=0):
        # Compute a window of at most `width` items over a length-`l_len`
        # sequence, centered on `center` and shifted back into valid range.
        # Returns (slice bounds, (left_truncated, right_truncated), center idx).
        if not width or l_len <= width:
            return (0, l_len), (False, False), 0
        mid = width // 2
        raw_lr = (center - mid, center + mid + 1)
        trunc_lr = (raw_lr[0] > 0, raw_lr[1] < l_len)
        shift = 0
        if not trunc_lr[0]:
            shift = -raw_lr[0]
        elif not trunc_lr[1]:
            shift = l_len - raw_lr[1]
        final_lr = (i + shift for i in raw_lr)
        center_idx = mid - shift
        return final_lr, trunc_lr, center_idx

    def add_trnc_ind(l, trc, s="...", sft=0):
        # Insert a truncation indicator `s` at offset `sft` (left edge) and/or
        # at the end (right edge), according to the (left, right) flags `trc`.
        if trc[0]:
            l = l[:sft] + [s] + l[sft:]
        if trc[1]:
            l = l + [s]
        return l

    idx_len = len(df.index.names)
    # Window the columns around `col_center`.
    col_center = df.columns.get_loc(col_center or df.columns[0])
    col_lr, is_trnc_col_lr, _ = _trnc(len(df.columns), max_col_width, col_center)
    df = df[df.columns[slice(*col_lr)]]
    df = df.reset_index()
    df_len = len(df)
    if max_row_len:
        # Remember the last row (prefixed with its positional id) before
        # truncating, so it can be rendered after the "..." separator.
        # Fix: build a list, not a tuple — add_trnc_ind concatenates with
        # lists and would raise TypeError on a tuple when columns are truncated.
        df_tail = [len(df), *df.tail(1).iloc[0].values]
        df = df.head(max_row_len)
    list_centers = list_centers or {}
    # Per-column sequence centers, aligned with the rendered columns
    # (the leading None accounts for the synthetic "id" column).
    centers = [None] + [list_centers.get(n, None) for n in df.columns]
    centers = add_trnc_ind(centers, is_trnc_col_lr, s=None, sft=idx_len + 1)
    highlight_columns = highlight_columns or []
    h_col = [None] + [n in highlight_columns for n in df.columns]
    h_col = add_trnc_ind(h_col, is_trnc_col_lr, s=None, sft=idx_len + 1)
    table = Table(title=title)
    table.add_column("id")
    for f in add_trnc_ind(list(df), is_trnc_col_lr, sft=idx_len):
        table.add_column(f)

    def _process_entry(v, col_i):
        # Format one cell; `clr` wraps highlighted columns in rich red markup.
        clr = lambda s: f"[on red]{s}[/on red]" if h_col[col_i] else str(s)
        if not max_seq_value_len or not isinstance(v, Sequence) or len(v) <= max_seq_value_len:
            if use_eng_str and isinstance(v, (int, float)):
                v = to_eng_str(v, precision=0)
            return clr(v)
        par = "[]" if isinstance(v, list) else "()"
        is_tuple = isinstance(v, tuple)
        v = list(v)
        l2s = lambda l: par[0] + ", ".join(map(str, l)) + par[1]
        # if v is seq
        c = centers[col_i] or 0
        ml = max_seq_value_len
        lr, is_trnc_lr, c_i = _trnc(len(v), ml, c)
        slc_l = v[slice(*lr)]
        slc_l = add_trnc_ind(slc_l, is_trnc_lr, s="...", sft=0)
        if use_eng_str:
            slc_l = [to_eng_str(n, precision=0) for n in slc_l]
        if highlight_list_centers:
            # NOTE(review): the `c_i + idx_len` offset here predates this
            # review; confirm it is intended rather than just `c_i`.
            slc_l[c_i + idx_len] = clr(slc_l[c_i + idx_len])
            return l2s(slc_l)
        else:
            return clr(l2s(tuple(slc_l) if is_tuple else slc_l))

    prev_row = None
    for i, row in enumerate(df.itertuples(name=None)):
        print_row = [*row]
        if prev_row:
            # Blank out values repeated from the previous row for readability.
            print_row = ["" if p == r else r for p, r in zip(prev_row, print_row)]
        print_row = add_trnc_ind(print_row, is_trnc_col_lr, sft=idx_len + 1)
        table.add_row(
            *[_process_entry(v, j) for j, v in enumerate(print_row)],
            style="on bright_black" if (alternating_row_colors and i % 2) else "",
        )
        prev_row = row
    if max_row_len and max_row_len < df_len:
        # Separator rows, then the remembered final row.
        # (A stray debug print of df_tail[0] was removed here.)
        table.add_row("", *([""] * len(df.columns)))
        table.add_row("...", *(["..."] * len(df.columns)))
        table.add_row("", *([""] * len(df.columns)))
        df_tail = add_trnc_ind(df_tail, is_trnc_col_lr, sft=idx_len + 1)
        table.add_row(*[_process_entry(v, j) for j, v in enumerate(df_tail)])
    return table
[docs]
def list2tuple(l):
    """Recursively convert lists into tuples.

    Lists become tuples with their elements converted recursively; dicts keep
    their keys but have their values converted; any other value is returned
    unchanged (tuples themselves are not descended into).
    """
    if isinstance(l, dict):
        return {key: list2tuple(val) for key, val in l.items()}
    if isinstance(l, list):
        converted = [list2tuple(item) for item in l]
        return tuple(converted)
    return l
[docs]
def str2value(value_str):
    """Cast a string back to a standard Python value.

    Handles any literal accepted by ``ast.literal_eval`` plus the special
    tokens "inf" (mapped to an overflowing float literal, i.e. ``float('inf')``)
    and "nan" (mapped to ``None``). Non-strings and unparseable strings are
    returned unchanged.

    Args:
        value_str: The string to cast (non-strings pass through).

    Returns:
        The parsed Python value, or the original input when parsing fails.
    """
    if not isinstance(value_str, str):
        return value_str
    # Substitute on a copy so the caller gets the ORIGINAL string back when
    # parsing fails (the old code returned the mangled copy, e.g.
    # "information" came back as "2e+308ormation").
    substituted = value_str.replace("inf", "2e+308").replace("nan", "None")
    try:
        return literal_eval(substituted)
    except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
        # Everything literal_eval documents as possible; not a plain string.
        return value_str
[docs]
def append_metrics(metric_log=None, **new_metrics):
    """Append scalar metrics to a running log of per-metric value lists.

    Args:
        metric_log (dict, optional): Existing mapping of metric name -> list
            of recorded values. A fresh dict is created when None.
        **new_metrics: Scalar values (int, float, bool or str) to append,
            keyed by metric name.

    Returns:
        dict: ``metric_log`` (mutated in place) with each value appended.

    Raises:
        AssertionError: If a value is not exactly an int, float, bool or str.
    """
    if metric_log is None:  # fix: was `== None`
        metric_log = {}
    for name, value in new_metrics.items():
        # Exact-type check (not isinstance) so subclasses such as numpy
        # scalars are rejected rather than silently logged.
        assert type(value) in {int, float, bool, str}
        metric_log.setdefault(name, []).append(value)
    return metric_log
[docs]
class QueuedFileLock(ContextDecorator):
    """FIFO file-based lock shared between processes.

    Each would-be holder appends a unique id to ``lock_file``; the instance
    whose id sits at the head of the file holds the lock, and everyone else
    polls until they reach the front. ``ContextDecorator`` lets the lock be
    used both as a context manager and as a function decorator.

    NOTE(review): relies on appends to the shared file being effectively
    atomic between writers — confirm for the target filesystem (e.g. NFS).
    """

    # Separator between queued ids inside the lock file.
    __delim = "\n"

    def __init__(self, lock_file: str, timeout: float = 10):
        """Initialize the lock (creating an empty lock file if needed).

        Args:
            lock_file (str): Path of the shared queue file.
            timeout (float, optional): Default seconds :meth:`acquire` waits
                before raising ``TimeoutError``. Defaults to 10.
        """
        self.lock_file = lock_file
        self.timeout = timeout
        # Random id identifying this instance in the shared queue.
        self.id = uuid.uuid4().int
        # Re-entrancy counter: acquire() calls minus release() calls.
        self.acquire_count = 0
        if not os.path.exists(lock_file):
            with open(lock_file, "w") as f:
                f.write("")
        self.__read_queue()

    def __read_queue(self):
        """Load the id queue from the lock file into ``self.queue``.

        Retries up to 10 times (0.1 s apart) to tolerate concurrent writers;
        raises if the file is still unreadable after that.
        """
        success = False
        for i in range(10):
            try:
                with open(self.lock_file) as f:
                    s = f.read()
                # Strip stray NUL bytes (seen after crashes) before parsing.
                parseint = lambda x: int(x.strip("\x00"))
                self.queue = [*map(parseint, filter(bool, s.split(self.__delim)))]
                success = True
                break
            except:
                logging.info(f"Failed to read queue from {self.lock_file} (Attempt: {i + 1}/10). Retrying after 0.1s.")
                time.sleep(0.1)
                continue
        if not success:
            raise Exception(f"Failed to read queue from {self.lock_file}.")

    def __write_queue(self):
        # Rewrite the whole queue file from the in-memory queue.
        with open(self.lock_file, "w") as f:
            s = self.__delim.join(map(str, self.queue))
            f.write(s)

    def __append_write(self):
        # Append our id at the end of the queue file, then reload the queue.
        with open(self.lock_file, "a") as f:
            f.write(f"{self.__delim}{self.id}")
        self.__read_queue()

    @property
    def is_locked(self):
        """True while another id is ahead of ours (or the queue is empty)."""
        self.__read_queue()
        return not self.queue or self.queue[0] != self.id

    def acquire(self, timeout: Optional[float] = None, poll_interval: float = 0.05):
        """Block until this instance reaches the head of the queue.

        Args:
            timeout (float, optional): Seconds to wait before raising
                ``TimeoutError``. Defaults to ``self.timeout``.
            poll_interval (float, optional): Seconds between polls of the
                queue file. Defaults to 0.05.

        Raises:
            TimeoutError: If the lock was not acquired within ``timeout``.
        """
        self.acquire_count += 1
        if timeout is None:
            timeout = self.timeout
        self.__read_queue()
        if self.id not in self.queue:
            # Join the queue only once, even for re-entrant acquires.
            self.__append_write()
        logging.debug(f"Attempting to acquire filelock {self.id} on {self.lock_file}.")
        start_t = time.time()
        while self.is_locked:
            logging.debug(
                f"Failed to acquire filelock {self.id} on {self.lock_file}. Waiting for {poll_interval} seconds."
            )
            time.sleep(poll_interval)
            if time.time() - start_t > timeout:
                raise TimeoutError(f"Timeout while acquiring filelock {self.id} on {self.lock_file}.")
        logging.debug(f"Filelock {self.id} acquired on {self.lock_file}.")

    def release(self, force=False):
        """Decrement the re-entrancy count; leave the queue at zero.

        Args:
            force (bool, optional): If True, fully release regardless of how
                many times :meth:`acquire` was called. Defaults to False.
        """
        if self.acquire_count == 0:
            # Never acquired (or already fully released): nothing to do.
            return
        if self.acquire_count >= 1:
            self.acquire_count -= 1
        if self.acquire_count == 0 or force:
            self.acquire_count = 0
            self.__read_queue()
            self.queue.remove(self.id)
            self.__write_queue()
            logging.debug(f"Released filelock {self.id} on {self.lock_file}.")

    def __enter__(self):
        self.acquire()
        return self

    def __exit__(self, *args):
        self.release()

    def __del__(self):
        # Best-effort cleanup so a dropped instance does not block waiters.
        self.release(force=True)
[docs]
class FuncTimeoutError(Exception):
    """Raised (and injected into worker threads) when a function wrapped by
    ``settimeout_func`` exceeds its time budget."""

    pass
[docs]
def async_raise(tid, exctype=Exception):
    """Raise an Exception in the Thread with id `tid`. Perform cleanup if
    needed.

    Based on Killable Threads By Tomer Filiba
    from http://tomerfiliba.com/recipes/Thread2/
    license: public domain.

    Args:
        tid (int): Target thread id (e.g. from ``start_new_thread`` or
            ``threading.get_ident()``).
        exctype (type, optional): Exception class to raise asynchronously in
            the target thread. Defaults to ``Exception``.

    Raises:
        ValueError: If ``tid`` does not name a live thread.
        SystemError: If the interpreter modified more than one thread state;
            the effect is reverted before raising.
    """
    # Fix: grammar in the assertion message ("must an" -> "must be an").
    assert isinstance(tid, int), "Invalid thread id: must be an integer"
    tid = c_long(tid)
    exception = py_object(exctype)
    res = pythonapi.PyThreadState_SetAsyncExc(tid, exception)
    if res == 0:
        raise ValueError("Invalid thread id.")
    elif res != 1:
        # If it returns a number greater than one, you're in trouble, and you
        # should call it again with exc=NULL to revert the effect. Fix: pass
        # None (marshalled to NULL by ctypes) instead of the int 0.
        pythonapi.PyThreadState_SetAsyncExc(tid, None)
        raise SystemError("PyThreadState_SetAsyncExc failed.")
[docs]
def settimeout_func(func, timeout=3 * 24 * 60 * 60):
    """Wrap ``func`` so it raises ``FuncTimeoutError`` after ``timeout`` seconds.

    Args:
        func: Callable to wrap.
        timeout (float, optional): Seconds allowed before aborting. Defaults
            to 3 days. When None, ``func`` is returned unwrapped.

    Returns:
        A callable with the same call signature as ``func``.
    """
    if timeout is None:
        return func

    def timeoutfunc(*args, **kwargs):
        """Threads-based interruptible runner, but is not reliable and works
        only if everything is picklable.
        """
        # We run `func` in a thread and block on a queue until timeout
        q = Queue()

        def runner():
            try:
                _res = func(*(args or ()), **(kwargs or {}))
                q.put((None, _res))
            except FuncTimeoutError:
                # raised by async_raise to kill the orphan threads
                pass
            except Exception as ex:
                # Forward the worker's exception to the caller's thread.
                q.put((ex, None))

        tid = start_new_thread(runner, ())
        try:
            err, res = q.get(timeout=timeout)
            if err:
                raise err
            return res
        except (Queue_Empty, MpTimeoutError):
            raise FuncTimeoutError(f"{func.__name__} timeout (taking more than {timeout} sec)")
        finally:
            # Best effort: interrupt the worker whether we finished, failed,
            # or timed out; ignore "thread already gone" errors.
            try:
                async_raise(tid, FuncTimeoutError)
            except (SystemExit, ValueError):
                pass

    return timeoutfunc
[docs]
def path_common_decomposition(paths):
    """Decomposes a list of paths into a common prefix and remaining suffixes.

    Args:
        paths: List of file/directory path strings ("/"-separated).

    Returns:
        tuple: A pair of (common_prefix, non_common_suffixes) where
            common_prefix is a list of shared path components and
            non_common_suffixes is a list of remaining path strings
            (an empty string for a path that equals the common prefix).
            Degenerate cases keep the historical behavior: empty input
            returns ("", []) and a single path returns (paths[0], []).
    """
    if len(paths) == 0:
        return "", []
    elif len(paths) == 1:
        return paths[0], []
    parts = [p.split("/") for p in paths]
    common = []
    # Peel off leading components while every path agrees on the next one.
    for _ in range(min(map(len, parts))):
        if len(set(p[0] for p in parts)) != 1:
            break
        common.append(parts[0][0])
        parts = [p[1:] for p in parts]
    # Fix: os.path.join(*[]) raises TypeError, so a path that was entirely
    # consumed by the common prefix maps to "" instead of crashing.
    non_common = [os.path.join(*p) if p else "" for p in parts]
    return common, non_common
[docs]
def get_wandb_sweep_exp_dir(base_dir, entity, project, sweep_id):
    """Build the experiment directory path for a W&B sweep.

    Args:
        base_dir: Root directory under which experiments are stored.
        entity: W&B entity name; accepted for API symmetry but not used
            in the resulting path.
        project: W&B project name.
        sweep_id: W&B sweep identifier.

    Returns:
        str: Path of the form ``base_dir/project/sweep_id``.
    """
    path_parts = (base_dir, project, sweep_id)
    return os.path.join(*path_parts)