# Source code for deepr.layers.string

"""String Layers"""

import tensorflow as tf

from deepr.layers import base
from deepr.utils.broadcasting import make_same_shape


class StringJoin(base.Layer):
    """Join several input tensors element-wise into one string tensor.

    Inputs that are not already ``tf.string`` are converted with
    ``tf.as_string``. All inputs are then passed through ``make_same_shape``
    (presumably to align their shapes for element-wise joining; confirm
    against ``deepr.utils.broadcasting``). The results are concatenated with
    ``tf.strings.join``, using ``separator`` between consecutive inputs.
    The layer declares ``n_out=1``: it produces a single output.
    """

    def __init__(self, n_in: int = 2, separator: str = " ", **kwargs):
        """Configure the layer.

        Args:
            n_in: number of input tensors the layer expects.
            separator: string inserted between consecutive inputs when joining.
            **kwargs: forwarded unchanged to ``base.Layer.__init__``.
        """
        super().__init__(n_in=n_in, n_out=1, **kwargs)
        self.separator = separator

    def forward(self, tensors, mode: str = None):
        """Join ``tensors`` element-wise with ``self.separator``.

        Args:
            tensors: iterable of tensors to join (presumably ``n_in`` of them,
                as supplied by ``base.Layer``; confirm).
            mode: accepted for interface compatibility; not used here.

        Returns:
            The string tensor produced by ``tf.strings.join``.
        """
        # tf.strings.join only accepts string tensors, so stringify the rest.
        string_inputs = []
        for tensor in tensors:
            if tensor.dtype == tf.string:
                string_inputs.append(tensor)
            else:
                string_inputs.append(tf.as_string(tensor))
        aligned_inputs = make_same_shape(string_inputs)
        return tf.strings.join(aligned_inputs, separator=self.separator)