Source code for pixyz.losses.losses

import numbers
from copy import deepcopy

from ..utils import tolist


[docs]class Loss(object): def __init__(self, p1, p2=None, input_var=None): self._p1 = p1 self._p2 = p2 self._loss_text = None if input_var is not None: self._input_var = input_var else: _input_var = deepcopy(p1.input_var) if p2 is not None: _input_var += deepcopy(p2.input_var) _input_var = sorted(set(_input_var), key=_input_var.index) self._input_var = _input_var @property def input_var(self): return self._input_var @property def loss_text(self): return "loss({},{})".format(self._p1.prob_text, self._p2.prob_text) def __str__(self): return self.loss_text def __add__(self, other): return AddLoss(self, other) def __radd__(self, other): return AddLoss(self, other) def __sub__(self, other): return SubLoss(self, other) def __rsub__(self, other): return SubLoss(self, other) def __mul__(self, other): return MulLoss(self, other) def __rmul__(self, other): return MulLoss(self, other) def __truediv__(self, other): return DivLoss(self, other) def __rtruediv__(self, other): return DivLoss(self, other) def __neg__(self): return NegLoss(self)
[docs] def abs(self): return AbsLoss(self)
[docs] def mean(self): return BatchMean(self)
[docs] def sum(self): return BatchSum(self)
[docs] def estimate(self, x={}, return_dict=False, **kwargs): if not(set(list(x.keys())) >= set(self._input_var)): raise ValueError("Input keys are not valid, got {}.".format(list(x.keys()))) loss, x = self._get_estimated_value(x, **kwargs) if return_dict: return loss, x return loss
def _get_estimated_value(self, x, **kwargs): return x, x
[docs] def train(self, x={}, **kwargs): """ Train the implicit (adversarial) loss function. """ return 0
[docs] def test(self, x={}, **kwargs): """ Test the implicit (adversarial) loss function. """ return 0
class ValueLoss(Loss): def __init__(self, loss1): self._loss1 = loss1 self._input_var = [] def _get_estimated_value(self, x={}, **kwargs): return self._loss1, x @property def loss_text(self): return str(self._loss1)
[docs]class Parameter(Loss): def __init__(self, input_var): if not isinstance(input_var, str): raise ValueError self._input_var = tolist(input_var) def _get_estimated_value(self, x={}, **kwargs): return x[self._input_var[0]], x @property def loss_text(self): return str(self._input_var[0])
[docs]class LossOperator(Loss): def __init__(self, loss1, loss2): _input_var = [] if isinstance(loss1, Loss): _input_var += deepcopy(loss1.input_var) elif isinstance(loss1, numbers.Number): loss1 = ValueLoss(loss1) elif isinstance(loss2, type(None)): pass else: raise ValueError("{} cannot be operated with {}.".format(type(loss1), type(loss2))) if isinstance(loss2, Loss): _input_var += deepcopy(loss2.input_var) elif isinstance(loss2, numbers.Number): loss2 = ValueLoss(loss2) elif isinstance(loss2, type(None)): pass else: raise ValueError("{} cannot be operated with {}.".format(type(loss2), type(loss1))) _input_var = sorted(set(_input_var), key=_input_var.index) self._input_var = _input_var self._loss1 = loss1 self._loss2 = loss2 @property def _loss_text_list(self): loss_text_list = [] if not isinstance(self._loss1, type(None)): loss_text_list.append(self._loss1.loss_text) if not isinstance(self._loss2, type(None)): loss_text_list.append(self._loss2.loss_text) return loss_text_list @property def loss_text(self): NotImplementedError def _get_estimated_value(self, x={}, **kwargs): if not isinstance(self._loss1, type(None)): loss1, x1 = self._loss1._get_estimated_value(x, **kwargs) else: loss1 = 0 x1 = {} if not isinstance(self._loss2, type(None)): loss2, x2 = self._loss2._get_estimated_value(x, **kwargs) else: loss2 = 0 x2 = {} x1.update(x2) return loss1, loss2, x1
[docs] def train(self, x, **kwargs): """ TODO: Fix """ loss1 = self._loss1.train(x, **kwargs) loss2 = self._loss2.train(x, **kwargs) return loss1 + loss2
[docs] def test(self, x, **kwargs): """ TODO: Fix """ loss1 = self._loss1.test(x, **kwargs) loss2 = self._loss2.test(x, **kwargs) return loss1 + loss2
[docs]class AddLoss(LossOperator): @property def loss_text(self): return " + ".join(self._loss_text_list) def _get_estimated_value(self, x={}, **kwargs): loss1, loss2, x = super()._get_estimated_value(x, **kwargs) return loss1 + loss2, x
[docs]class SubLoss(LossOperator): @property def loss_text(self): return " - ".join(self._loss_text_list) def _get_estimated_value(self, x={}, **kwargs): loss1, loss2, x = super()._get_estimated_value(x, **kwargs) return loss1 - loss2, x
[docs]class MulLoss(LossOperator): @property def loss_text(self): return " * ".join(self._loss_text_list) def _get_estimated_value(self, x={}, **kwargs): loss1, loss2, x = super()._get_estimated_value(x, **kwargs) return loss1 * loss2, x
[docs]class DivLoss(LossOperator): @property def loss_text(self): return " / ".join(self._loss_text_list) def _get_estimated_value(self, x={}, **kwargs): loss1, loss2, x = super()._get_estimated_value(x, **kwargs) return loss1 / loss2, x
[docs]class LossSelfOperator(Loss): def __init__(self, loss1): _input_var = [] if isinstance(loss1, type(None)): raise ValueError if isinstance(loss1, Loss): _input_var = deepcopy(loss1.input_var) elif isinstance(loss1, numbers.Number): loss1 = ValueLoss(loss1) else: raise ValueError self._input_var = _input_var self._loss1 = loss1
[docs] def train(self, x={}, **kwargs): return self._loss1.train(x, **kwargs)
[docs] def test(self, x={}, **kwargs): return self._loss1.test(x, **kwargs)
[docs]class NegLoss(LossSelfOperator): @property def loss_text(self): return "-({})".format(self._loss1.loss_text) def _get_estimated_value(self, x={}, **kwargs): loss, x = self._loss1._get_estimated_value(x, **kwargs) return -loss, x
[docs]class AbsLoss(LossSelfOperator): @property def loss_text(self): return "|{}|".format(self._loss1.loss_text) def _get_estimated_value(self, x={}, **kwargs): loss, x = self._loss1._get_estimated_value(x, **kwargs) return loss.abs(), x
[docs]class BatchMean(LossSelfOperator): r""" Loss averaged over batch data. .. math:: \mathbb{E}_{p_{data}(x)}[\mathcal{L}(x)] \approx \frac{1}{N}\sum_{i=1}^N \mathcal{L}(x_i), where :math:`x_i \sim p_{data}(x)` and :math:`\mathcal{L}` is a loss function. """ @property def loss_text(self): return "mean({})".format(self._loss1.loss_text) # TODO: fix it def _get_estimated_value(self, x={}, **kwargs): loss, x = self._loss1._get_estimated_value(x, **kwargs) return loss.mean(), x
[docs]class BatchSum(LossSelfOperator): r""" Loss summed over batch data. .. math:: \sum_{i=1}^N \mathcal{L}(x_i), where :math:`x_i \sim p_{data}(x)` and :math:`\mathcal{L}` is a loss function. """ @property def loss_text(self): return "sum({})".format(self._loss1.loss_text) # TODO: fix it def _get_estimated_value(self, x={}, **kwargs): loss, x = self._loss1._get_estimated_value(x, **kwargs) return loss.sum(), x