import torch
from torch import nn
from ..utils import get_dict_values
from .losses import Loss
class SimilarityLoss(Loss):
    """Margin ranking loss on the dot-product similarity of paired samples.

    Reference: "Learning Modality-Invariant Representations for Speech and
    Images" (Leidal et al.).

    One sample is drawn from each of ``p1`` and ``p2``. Row ``i`` of the two
    samples forms a positive pair. Pairing row ``i`` of one side with a row
    of the other side chosen by a random permutation forms a negative pair.
    The loss penalizes negatives whose similarity is not below the
    positive's similarity by at least ``margin``. The result is returned
    unreduced (one value per element).
    """

    def __init__(self, p1, p2, input_var=None, var=None, margin=0):
        """
        Parameters
        ----------
        p1, p2
            The two distributions to compare; forwarded to ``Loss.__init__``.
        input_var : optional
            Forwarded to ``Loss.__init__``.
        var : list, optional
            Keys looked up in each distribution's sample; only the first
            retrieved value is used. Defaults to ``["z"]``.
        margin : float, optional
            Margin passed to ``nn.MarginRankingLoss``. Default ``0``.
        """
        super().__init__(p1, p2, input_var)
        # A None sentinel gives each instance its own default list instead of
        # sharing a single mutable default object across all instances.
        self.var = ["z"] if var is None else var
        # reduction="none" is the supported spelling of the deprecated
        # reduce=False: the loss is kept per element rather than averaged.
        self.loss = nn.MarginRankingLoss(margin=margin, reduction="none")

    def _sim(self, x1, x2):
        """Return the dot product of ``x1`` and ``x2`` summed over dim 1."""
        # NOTE(review): assumes samples are (batch, features) -- TODO confirm.
        return torch.sum(x1 * x2, dim=1)

    def estimate(self, x):
        """Compute the ranking loss for samples drawn given ``x``.

        Parameters
        ----------
        x
            Passed to ``Loss.estimate``. Its result is then used to look up
            each distribution's ``input_var`` (presumably a dict of tensors;
            TODO confirm against ``Loss``).

        Returns
        -------
        torch.Tensor
            ``loss(pos, neg_shuffled2) + loss(pos, neg_shuffled1)``, unreduced.
        """
        x = super().estimate(x)

        inputs = get_dict_values(x, self._p1.input_var, True)
        sample1 = get_dict_values(self._p1.sample(inputs), self.var)[0]
        inputs = get_dict_values(x, self._p2.input_var, True)
        sample2 = get_dict_values(self._p2.sample(inputs), self.var)[0]

        # Negatives come from one random permutation along dim 0, which is
        # applied to both samples.
        # NOTE(review): the permutation may have fixed points; at those rows
        # the "negative" pair coincides with the positive pair.
        batch_size = sample1.shape[0]
        shuffle_id = torch.randperm(batch_size)
        shuffled1 = sample1[shuffle_id]
        shuffled2 = sample2[shuffle_id]

        sim_pos = self._sim(sample1, sample2)
        sim_neg2 = self._sim(sample1, shuffled2)
        sim_neg1 = self._sim(shuffled1, sample2)

        # A target of +1 asks MarginRankingLoss to rank its first argument
        # (the positive-pair similarity) above its second argument.
        target = torch.ones_like(sim_pos)
        loss = self.loss(sim_pos, sim_neg2, target) \
            + self.loss(sim_pos, sim_neg1, target)
        return loss
class MultiModalContrastivenessLoss(Loss):
    """Contrastive margin ranking loss between two distributions' means.

    Reference: "Disentangling by Partitioning: A Representation Learning
    Framework for Multimodal Sensory Data" (Hsu & Glass).

    Similarity is ``exp(-||a - b||_2 / 2)``, computed with the norm taken
    over dim 1. Matching rows of the two ``sample_mean`` outputs are compared
    against rows paired through one random permutation. ``nn.MarginRankingLoss``
    uses its default reduction here, so each of the two terms is averaged.
    """

    def __init__(self, p1, p2, input_var=None, margin=0.5):
        """
        Parameters
        ----------
        p1, p2
            The two distributions to compare; forwarded to ``Loss.__init__``.
        input_var : optional
            Forwarded to ``Loss.__init__``.
        margin : float, optional
            Margin passed to ``nn.MarginRankingLoss``. Default ``0.5``.
        """
        super().__init__(p1, p2, input_var)
        self.loss = nn.MarginRankingLoss(margin=margin)

    def _sim(self, x1, x2):
        """Return ``exp(-||x1 - x2||_2 / 2)`` with the norm taken over dim 1."""
        distance = torch.norm(x1 - x2, 2, dim=1)
        return torch.exp(-distance / 2)

    def estimate(self, x):
        """Compute the contrastive loss for the means obtained given ``x``.

        Parameters
        ----------
        x
            Passed to ``Loss.estimate``. Its result is then used to look up
            each distribution's ``input_var`` (presumably a dict of tensors;
            TODO confirm against ``Loss``).

        Returns
        -------
        torch.Tensor
            Sum of two averaged margin ranking terms: one against the
            shuffled second side and one against the shuffled first side.
        """
        x = super().estimate(x)

        p1_inputs = get_dict_values(x, self._p1.input_var, True)
        mean1 = self._p1.sample_mean(p1_inputs)
        p2_inputs = get_dict_values(x, self._p2.input_var, True)
        mean2 = self._p2.sample_mean(p2_inputs)

        # One random permutation along dim 0 produces both kinds of negatives.
        perm = torch.randperm(mean1.shape[0])
        positive = self._sim(mean1, mean2)
        negative_a = self._sim(mean1, mean2[perm])
        negative_b = self._sim(mean1[perm], mean2)

        # A +1 target asks MarginRankingLoss to rank `positive` higher.
        ones = torch.ones_like(positive)
        first = self.loss(positive, negative_a, ones)
        second = self.loss(positive, negative_b, ones)
        return first + second