"""Tables Utilities"""
import logging
from typing import Dict
import numpy as np
import tensorflow as tf
from deepr.utils.field import TensorType
# Module-level logger, named after this module.
LOGGER = logging.getLogger(__name__)
class TableContext:
    """Context Manager to reuse Tensorflow tables.

    Tensorflow does not have a ``tf.get_variable`` equivalent for
    tables. The ``TableContext`` is here to provide this functionality.

    Only one ``TableContext`` can be active at a time. A context becomes
    active as soon as it is instantiated (not on ``__enter__``) and stays
    active until :meth:`close` is called, which ``__exit__`` does.

    Example
    -------
    >>> import deepr
    >>> with deepr.utils.TableContext() as tables:
    ...     table = deepr.utils.table_from_mapping(name="my_table", mapping={1: 2})
    ...     tables.get("my_table") is table
    True
    >>> with deepr.utils.TableContext():
    ...     table = deepr.utils.table_from_mapping(name="my_table", mapping={1: 2})
    ...     reused = deepr.utils.table_from_mapping(name="my_table", reuse=True)
    ...     table is reused
    True
    """

    # The currently active context (class-level, shared), or None.
    _ACTIVE = None

    def __init__(self):
        """Create the context and make it the active one.

        Raises
        ------
        ValueError
            If another ``TableContext`` is already active.
        """
        if TableContext._ACTIVE is not None:
            msg = "TableContext already active."
            raise ValueError(msg)
        TableContext._ACTIVE = self
        self._tables = {}

    def __enter__(self):
        return self

    def __exit__(self, *exc):
        # Returns None, so exceptions raised inside the ``with`` propagate.
        self.close()

    def __contains__(self, name: str) -> bool:
        """Return True if a table named ``name`` is registered."""
        return name in self._tables

    def close(self):
        """Deactivate this context and drop all registered tables.

        The class-level active context is reset only if it is ``self``.
        Closing a stale context (e.g. closing twice, or closing an old
        handle after a new context was created) must not deactivate the
        context that is currently active.
        """
        if TableContext._ACTIVE is self:
            TableContext._ACTIVE = None
        self._tables.clear()

    def get(self, name: str):
        """Return the table registered under ``name``.

        Raises
        ------
        KeyError
            If no table named ``name`` is registered.
        """
        if name not in self._tables:
            msg = f"Table '{name}' not in tables. Did you forget a reuse=True?"
            raise KeyError(msg)
        return self._tables[name]

    def set(self, name: str, table):
        """Register ``table`` under ``name``.

        Raises
        ------
        ValueError
            If a table named ``name`` is already registered.
        """
        if name in self._tables:
            msg = f"Table '{name}' already exists. Did you forget a reuse=True?"
            raise ValueError(msg)
        self._tables[name] = table

    @classmethod
    def is_active(cls) -> bool:
        """Return True if a ``TableContext`` is currently active."""
        return cls._ACTIVE is not None

    @classmethod
    def active(cls):
        """Return the active ``TableContext``.

        Raises
        ------
        ValueError
            If no ``TableContext`` is active.
        """
        if cls._ACTIVE is None:
            msg = "No active TableContext found. Wrap your code in a `with TableContext():`"
            raise ValueError(msg)
        return cls._ACTIVE
def table_from_file(name: str, path: str = None, key_dtype=None, reuse: bool = False, default_value: int = -1):
    """Create table from file (``tf.contrib.lookup.index_table_from_file``).

    Parameters
    ----------
    name : str
        Name of the table, also its key in the active ``TableContext``.
    path : str, optional
        Vocabulary file path. Required whenever a new table is created.
    key_dtype : optional
        Forwarded to ``index_table_from_file``.
    reuse : bool or tf.AUTO_REUSE
        ``True``: return the table ``name`` from the active ``TableContext``.
        ``tf.AUTO_REUSE``: return it if a context is active and contains it,
        otherwise create a new table. ``False``: always create a new table.
    default_value : int
        Forwarded to ``index_table_from_file``.

    Returns
    -------
    The table. A newly created table is also registered in the active
    ``TableContext`` under ``name``, if a context is active.

    Raises
    ------
    ValueError
        If ``path`` is None when creating a table, if ``reuse`` is True and
        no ``TableContext`` is active, or if ``name`` is already registered.
    KeyError
        If ``reuse`` is True and the active context has no table ``name``.
    """
    # AUTO_REUSE without an active context has nothing to reuse: create instead
    # of letting ``TableContext.active()`` raise.
    auto_reuse = reuse is tf.AUTO_REUSE and TableContext.is_active() and name in TableContext.active()
    if reuse is True or auto_reuse:
        return TableContext.active().get(name)
    LOGGER.info("Creating table %s from %s", name, path)
    if path is None:
        raise ValueError("Path cannot be None")
    table = tf.contrib.lookup.index_table_from_file(
        vocabulary_file=path, name=name, key_dtype=key_dtype, default_value=default_value
    )
    if TableContext.is_active():
        TableContext.active().set(name=name, table=table)
    return table
def index_to_string_table_from_file(
    name: str, path: str = None, vocab_size: int = None, default_value="UNK", reuse: bool = False
):
    """Create reverse table from file (``index_to_string_table_from_file``).

    Parameters
    ----------
    name : str
        Name of the table, also its key in the active ``TableContext``.
    path : str, optional
        Vocabulary file path. Required whenever a new table is created.
    vocab_size : int, optional
        Forwarded to ``tf.contrib.lookup.index_to_string_table_from_file``.
    default_value : optional
        Forwarded to ``index_to_string_table_from_file`` (``"UNK"`` here).
    reuse : bool or tf.AUTO_REUSE
        ``True``: return the table ``name`` from the active ``TableContext``.
        ``tf.AUTO_REUSE``: return it if a context is active and contains it,
        otherwise create a new table. ``False``: always create a new table.

    Returns
    -------
    The table. A newly created table is also registered in the active
    ``TableContext`` under ``name``, if a context is active.

    Raises
    ------
    ValueError
        If ``path`` is None when creating a table, if ``reuse`` is True and
        no ``TableContext`` is active, or if ``name`` is already registered.
    KeyError
        If ``reuse`` is True and the active context has no table ``name``.
    """
    # AUTO_REUSE without an active context has nothing to reuse: create instead
    # of letting ``TableContext.active()`` raise.
    auto_reuse = reuse is tf.AUTO_REUSE and TableContext.is_active() and name in TableContext.active()
    if reuse is True or auto_reuse:
        return TableContext.active().get(name)
    LOGGER.info("Creating reverse table %s from %s", name, path)
    if path is None:
        raise ValueError("Path cannot be None")
    table = tf.contrib.lookup.index_to_string_table_from_file(
        vocabulary_file=path, name=name, vocab_size=vocab_size, default_value=default_value
    )
    if TableContext.is_active():
        TableContext.active().set(name=name, table=table)
    return table
def table_from_mapping(
    name: str, mapping: Dict = None, default_value=None, key_dtype=None, value_dtype=None, reuse: bool = False
):
    """Create table from mapping (``tf.contrib.lookup.HashTable``).

    Parameters
    ----------
    name : str
        Name of the table, also its key in the active ``TableContext``.
    mapping : Dict, optional
        Non-empty key -> value mapping. Required whenever a new table is
        created.
    default_value : optional
        Table default value. If None, inferred with ``TensorType`` from the
        Python type of the first value.
    key_dtype, value_dtype : optional
        Table dtypes. If None, inferred with ``TensorType`` from the Python
        type of the first key / value.
    reuse : bool or tf.AUTO_REUSE
        ``True``: return the table ``name`` from the active ``TableContext``.
        ``tf.AUTO_REUSE``: return it if a context is active and contains it,
        otherwise create a new table. ``False``: always create a new table.

    Returns
    -------
    The table. A newly created table is also registered in the active
    ``TableContext`` under ``name``, if a context is active.

    Raises
    ------
    ValueError
        If ``mapping`` is None or empty when creating a table, if ``reuse``
        is True and no ``TableContext`` is active, or if ``name`` is already
        registered.
    KeyError
        If ``reuse`` is True and the active context has no table ``name``.
    """
    # AUTO_REUSE without an active context has nothing to reuse: create instead
    # of letting ``TableContext.active()`` raise.
    auto_reuse = reuse is tf.AUTO_REUSE and TableContext.is_active() and name in TableContext.active()
    if reuse is True or auto_reuse:
        return TableContext.active().get(name)
    LOGGER.info("Creating table %s from mapping.", name)
    if mapping is None:
        raise ValueError("Mapping cannot be None")
    if not mapping:
        # Nothing to infer dtypes / default from; zip(*{}.items()) would
        # otherwise fail with an obscure unpacking error.
        raise ValueError(f"Mapping for table '{name}' cannot be empty")

    # Convert mapping to arrays of keys and values
    keys, values = zip(*mapping.items())  # type: ignore
    keys_np = np.array(keys)
    values_np = np.array(values)

    # Infer default value and dtypes from the first key / value if not given
    if default_value is None:
        default_value = TensorType(type(values_np[0].item())).default
    if key_dtype is None:
        key_dtype = TensorType(type(keys_np[0].item())).tf
    if value_dtype is None:
        value_dtype = TensorType(type(values_np[0].item())).tf

    # Create table
    table = tf.contrib.lookup.HashTable(
        tf.contrib.lookup.KeyValueTensorInitializer(
            keys=keys_np, values=values_np, key_dtype=key_dtype, value_dtype=value_dtype
        ),
        name=name,
        default_value=default_value,
    )
    if TableContext.is_active():
        TableContext.active().set(name=name, table=table)
    return table