import tables
import functools
import pickle
import numpy as np
from pathlib import Path
from astropy.table import Table, MaskedColumn
class TableCache:
    """Cache backend that reads and writes files through ``astropy.table.Table``."""

    @staticmethod
    def save(table, filename, force=False):
        """Write ``table`` to ``filename``.

        Parameters
        ----------
        table : Table or object accepted by the ``Table`` constructor
            The data to store. Anything that is not a ``Table`` is converted first.
        filename : pathlib.Path
            Destination file. Missing parent directories are created.
        force : bool, optional
            If False (default) and the file already exists, nothing is written and
            the call returns silently. If True the file is overwritten.
        """
        target_exists = filename.exists()
        if target_exists and not force:
            return

        parent = filename.parent
        if not parent.exists():
            parent.mkdir(parents=True, exist_ok=True)

        if not isinstance(table, Table):
            table = Table(table)

        # The "description" meta entry is stored under the shorter key "descript".
        # NOTE(review): presumably because of the 8-character limit on FITS header
        # keywords -- confirm. This renames the key on the table that was passed in.
        if "description" in table.meta:
            table.meta["descript"] = table.meta.pop("description")

        table.write(filename, overwrite=force)

    @staticmethod
    def load(filename):
        """Read ``filename`` into a ``Table``, or return None if the file is missing.

        Byte-string columns are converted to unicode before the table is returned.
        """
        if not filename.exists():
            return None
        loaded = Table.read(filename)
        loaded.convert_bytestring_to_unicode()
        return loaded
class HDFCache:
    """Cache backend that stores a table in an HDF5 file through PyTables.

    The data go to the ``/data`` node. The masks of masked columns go to a separate
    ``/mask`` node with one column per masked column; that node is only present when
    at least one column is masked.
    """

    @staticmethod
    def save(table, filename, force=False):
        """Write ``table`` to ``filename``.

        Parameters
        ----------
        table : Table or numpy structured array
            The data to store. A ``Table`` is converted with ``as_array()``.
        filename : pathlib.Path
            Destination file. Missing parent directories are created.
        force : bool, optional
            If True an existing file is overwritten.

        Raises
        ------
        Exception
            If the file already exists and ``force`` is False.
        """
        if filename.exists() and not force:
            raise Exception(f"File {filename} already exists.")
        if not filename.parent.exists():
            filename.parent.mkdir(parents=True, exist_ok=True)
        if isinstance(table, Table):
            table = table.as_array()

        # Collect the mask of every column that has one.
        mask = Table()
        for col in table.dtype.names:
            if hasattr(table[col], "mask"):
                mask[col] = table[col].mask

        with tables.open_file(filename, "w") as h5:
            h5.create_table("/", "data", table)
            # Without masked columns there is nothing to store, and an empty Table
            # does not convert to a structured array that could describe a PyTables
            # table, so the mask node is skipped in that case.
            if mask.colnames:
                h5.create_table("/", "mask", mask.as_array())

    @staticmethod
    def load(filename):
        """Read ``filename`` into a ``Table``, or return None if the file is missing.

        Columns listed in the ``/mask`` node (if the file has one) are returned as
        ``MaskedColumn`` with the stored mask applied.
        """
        if not filename.exists():
            return None
        with tables.open_file(filename) as h5:
            table = Table(h5.root.data[:])
            # Files written from data without masked columns have no mask node.
            if "/mask" in h5:
                mask = h5.root.mask[:]
                for col in mask.dtype.names:
                    table[col] = MaskedColumn(table[col], mask=mask[col])
        return table
class PickleCache:
    """Cache backend that stores an object in a pickle file.

    NOTE: unpickling can execute arbitrary code, so ``load`` must only be used on
    files that come from a trusted source (such as files written by ``save``).
    """

    @staticmethod
    def save(table, filename, force=False):
        """Pickle ``table`` into ``filename``.

        Parameters
        ----------
        table : object
            The data to store. A ``Table`` is converted with ``as_array()`` first, so
            what is pickled (and later returned by ``load``) is the array.
        filename : pathlib.Path
            Destination file. Missing parent directories are created.
        force : bool, optional
            If True an existing file is overwritten.

        Raises
        ------
        Exception
            If the file already exists and ``force`` is False.
        """
        target_exists = filename.exists()
        if target_exists and not force:
            raise Exception(f"File {filename} already exists.")

        parent = filename.parent
        if not parent.exists():
            parent.mkdir(parents=True, exist_ok=True)

        payload = table.as_array() if isinstance(table, Table) else table
        with filename.open("wb") as fh:
            pickle.dump(payload, fh)

    @staticmethod
    def load(filename):
        """Return the object pickled in ``filename``, or None if the file is missing."""
        if not filename.exists():
            return None
        with filename.open("rb") as fh:
            return pickle.load(fh)
# Cache backend for each supported file extension. ``Cache`` looks a file up by the
# concatenation of all its suffixes, which is why ".fits.gz" has its own entry.
CACHES = {
".fits.gz": TableCache,
".fits": TableCache,
".csv": TableCache,
".h5": HDFCache,
".pkl": PickleCache,
}
class Cache:
    """Front end that picks the cache backend of each file from its extension."""

    @staticmethod
    def _as_paths(filenames):
        """Return ``filenames`` (one name, or a list/tuple of names) as a list of Paths."""
        if not isinstance(filenames, (list, tuple)):
            filenames = (filenames,)
        return [Path(filename) for filename in filenames]

    @staticmethod
    def _backend(filename):
        """Return the entry of ``CACHES`` that handles ``filename``.

        Raises ``KeyError`` if no registered extension matches.
        """
        suffix = "".join(filename.suffixes)
        if suffix in CACHES:
            return CACHES[suffix]
        # Dots in the stem (e.g. "catalog.v2.fits.gz") make the joined suffixes longer
        # than any registered extension, so fall back to the longest extension the
        # file name ends with.
        matches = [ext for ext in CACHES if filename.name.endswith(ext)]
        if not matches:
            raise KeyError(suffix)
        return CACHES[max(matches, key=len)]

    @staticmethod
    def save(table, filenames, force=False):
        """Save one result per file.

        Parameters
        ----------
        table : object, or list/tuple of objects
            A list or tuple is taken as one result per file. Anything else (a
            ``Table``, a numpy array, ...) is taken as a single result.
        filenames : str, pathlib.Path, or list/tuple of those
            Destination file(s). The extension selects the backend.
        force : bool, optional
            Passed on to the backend's ``save``.

        Raises
        ------
        KeyError
            If a file extension has no registered backend.
        AssertionError
            If the number of results differs from the number of files.
        """
        paths = Cache._as_paths(filenames)
        backends = [Cache._backend(path) for path in paths]
        # Only list/tuple means "several results". Checking for ``__len__`` instead
        # would mistake a single array (or any other sized object) for a sequence.
        if not isinstance(table, (list, tuple)):
            table = [table]
        assert len(table) == len(
            paths
        ), f"Expected {len(paths)} tables, got {len(table)}."
        for res, backend, path in zip(table, backends, paths):
            backend.save(res, path, force)

    @staticmethod
    def load(filenames):
        """Load the cached result(s).

        Parameters
        ----------
        filenames : str, pathlib.Path, or list/tuple of those
            File(s) to read. The extension selects the backend.

        Returns
        -------
        object, list or None
            None if any backend returned None, the single loaded object if exactly
            one file was given, otherwise the list of loaded objects.

        Raises
        ------
        KeyError
            If a file extension has no registered backend.
        """
        paths = Cache._as_paths(filenames)
        backends = [Cache._backend(path) for path in paths]
        result = [backend.load(path) for backend, path in zip(backends, paths)]
        if any(res is None for res in result):
            # one invalid entry invalidates all the others, since they come from the same function
            return None
        if len(result) == 1:
            result = result[0]
        return result
def cached(name=None, force=False):
    """Decorator to cache the result of a function in a [list of] file[s].

    Parameters
    ----------
    name : str, optional
        Key used to look up the cache file name(s) in ``agasc_gaia.config.FILES``
        when the decorated function is called without ``filenames``.
    force : bool, optional
        Default value of the ``force`` argument of the decorated function.

    Returns
    -------
    callable
        The decorator. The decorated function takes two extra keyword-only
        arguments that are not passed on to the wrapped function:

        - ``filenames``: the cache file(s) to use instead of ``FILES[name]``.
        - ``force``: if True, skip the cache, call the function and overwrite
          the cache file(s).
    """

    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, filenames=None, force=force, **kwargs):
            if filenames is None:
                # Only needed for the default lookup, so it is imported on demand.
                from agasc_gaia.config import FILES

                filenames = FILES[name]
            result = None
            if not force:
                result = Cache.load(filenames)
            if result is None:
                result = func(*args, **kwargs)
                # Getting here means the cache was bypassed, missing or incomplete.
                # Any file that is still around is stale and has to be replaced, so
                # the save must overwrite regardless of the ``force`` argument.
                Cache.save(result, filenames, force=True)
            return result

        return wrapper

    return decorator
class Files:
    """Mapping-like view of a dict of file names, resolved against a data directory.

    Keys are those of ``files``. Values are the stored names joined onto ``data_dir``
    with the ``/`` operator: a single path for a single name, a list of paths for a
    list or tuple of names.
    """

    def __init__(self, data_dir, files):
        self.files = files
        self.data_dir = data_dir

    def _prepend_path_(self, value):
        """Join ``value`` (one name, or a list/tuple of names) onto ``data_dir``."""
        if not isinstance(value, (list, tuple)):
            return self.data_dir / value
        return [self.data_dir / file for file in value]

    def keys(self):
        """Return the keys of the underlying ``files`` mapping."""
        return self.files.keys()

    def values(self):
        """Yield the resolved path(s) of every entry."""
        yield from map(self._prepend_path_, self.files.values())

    def items(self):
        """Yield ``(key, resolved path(s))`` for every entry."""
        yield from (
            (key, self._prepend_path_(value)) for key, value in self.files.items()
        )

    def __len__(self):
        return len(self.files)

    def __iter__(self):
        return iter(self.files)

    def __repr__(self):
        resolved = dict(self)
        return repr(resolved)

    def __str__(self):
        resolved = dict(self)
        return str(resolved)

    def __getitem__(self, key):
        stored = self.files[key]
        return self._prepend_path_(stored)