# Author: Theo Gnassounou <theo.gnassounou@inria.fr>
# Remi Flamary <remi.flamary@polytechnique.edu>
# Yanis Lalou <yanis.lalou@polytechnique.edu>
# Antoine Collas <contact@antoinecollas.fr>
#
# License: BSD 3-Clause
from functools import partial
import ot
import skorch # noqa: F401
import torch # noqa: F401
import torch.nn.functional as F
from torch.nn.functional import mse_loss
from skada.deep.base import BaseDALoss
from skada.deep.utils import SphericalKMeans
def deepcoral_loss(features, features_target, assume_centered=False):
"""Estimate the Frobenius norm divide by 4*n**2
for DeepCORAL method [12]_.
Parameters
----------
features : tensor
Source features.
features_target : tensor
Target features.
    assume_centered : bool, default=False
If True, data are not centered before computation.
Returns
-------
    loss : tensor
The loss of the method.
References
----------
.. [12] Baochen Sun and Kate Saenko. Deep coral:
Correlation alignment for deep domain
adaptation. In ECCV Workshops, 2016.
"""
if not assume_centered:
features = features - features.mean(0)
features_target = features_target - features_target.mean(0)
cov = torch.cov(features.T)
cov_target = torch.cov(features_target.T)
divergence = mse_loss(cov, cov_target, reduction="sum")
dim = features.shape[1]
loss = (1 / (4 * (dim**2))) * divergence
return loss
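
# Minimal usage sketch for ``deepcoral_loss`` (batch sizes and feature
# dimension below are illustrative assumptions, not library requirements):
#
#   features_s = torch.randn(32, 64, requires_grad=True)  # source batch features
#   features_t = torch.randn(48, 64, requires_grad=True)  # target batch features
#   loss = deepcoral_loss(features_s, features_t)
#   loss.backward()  # gradients flow back to both feature batches
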
def deepjdot_loss(
y_s,
y_pred_t,
features_s,
features_t,
reg_dist,
reg_cl,
sample_weights=None,
target_sample_weights=None,
criterion=None,
):
"""Compute the OT loss for DeepJDOT method [13]_.
Parameters
----------
    y_s : tensor
        Labels of the source data, used to compute the class-loss matrix.
    y_pred_t : tensor
        Predictions of the target data, used to compute the class-loss matrix.
    features_s : tensor
        Features of the source data, used to compute the distance matrix.
    features_t : tensor
        Features of the target data, used to compute the distance matrix.
    reg_dist : float
        Divergence term regularization parameter.
    reg_cl : float
        Class distance term regularization parameter.
    sample_weights : tensor
        Weights of the source samples.
        If None, create uniform weights.
    target_sample_weights : tensor
        Weights of the target samples.
        If None, create uniform weights.
criterion : torch criterion (class)
The criterion (loss) used to compute the
DeepJDOT loss. If None, use the CrossEntropyLoss.
Returns
-------
    loss : tensor
The loss of the method.
References
----------
.. [13] Bharath Bhushan Damodaran, Benjamin Kellenberger,
Remi Flamary, Devis Tuia, and Nicolas Courty.
DeepJDOT: Deep Joint Distribution Optimal Transport
for Unsupervised Domain Adaptation. In ECCV 2018
15th European Conference on Computer Vision,
September 2018. Springer.
"""
dist = torch.cdist(features_s, features_t, p=2) ** 2
y_target_matrix = y_pred_t.repeat(len(y_pred_t), 1, 1).permute(1, 2, 0)
if criterion is None:
criterion = torch.nn.CrossEntropyLoss(reduction="none")
loss_target = criterion(y_target_matrix, y_s.repeat(len(y_s), 1)).T
M = reg_dist * dist + reg_cl * loss_target
# Compute the loss
if sample_weights is None:
sample_weights = torch.full(
(len(features_s),), 1.0 / len(features_s), device=features_s.device
)
if target_sample_weights is None:
target_sample_weights = torch.full(
(len(features_t),), 1.0 / len(features_t), device=features_t.device
)
loss = ot.emd2(sample_weights, target_sample_weights, M)
return loss
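
# Minimal usage sketch for ``deepjdot_loss``. As written, the loss requires
# source and target batches of the same size; sizes, dimensions and
# regularization values below are illustrative assumptions:
#
#   n, d, n_classes = 32, 64, 10
#   y_s = torch.randint(0, n_classes, (n,))   # source labels
#   y_pred_t = torch.randn(n, n_classes)      # target logits
#   features_s = torch.randn(n, d)
#   features_t = torch.randn(n, d)
#   loss = deepjdot_loss(y_s, y_pred_t, features_s, features_t,
#                        reg_dist=1.0, reg_cl=1.0)
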
def _gaussian_kernel(x, y, sigmas):
"""Computes multi gaussian kernel between each pair of the two vectors."""
sigmas = sigmas.view(sigmas.shape[0], 1)
beta = 1.0 / sigmas
dist = torch.cdist(x, y)
dist_ = dist.view(1, -1)
s = torch.matmul(beta, dist_)
return torch.sum(torch.exp(-s), 0).view_as(dist)
def _maximum_mean_discrepancy(x, y, kernel):
"""Computes the maximum mean discrepancy between the vectors
using the given kernel.
"""
cost = torch.mean(kernel(x, x))
cost += torch.mean(kernel(y, y))
cost -= 2 * torch.mean(kernel(x, y))
return cost
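
# The two private helpers above compose into a multi-kernel MMD estimate.
# A small sketch with arbitrary bandwidths (values are illustrative only):
#
#   x = torch.randn(16, 8)
#   y = torch.randn(16, 8) + 1.0
#   kernel = partial(_gaussian_kernel, sigmas=torch.tensor([0.5, 1.0, 2.0]))
#   mmd = _maximum_mean_discrepancy(x, y, kernel=kernel)  # grows as x and y diverge
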
def dan_loss(features_s, features_t, sigmas=None, eps=1e-7):
"""Define the mmd loss based on multi-kernel defined in [14]_.
Parameters
----------
features_s : tensor
Source features used to compute the mmd loss.
features_t : tensor
Target features used to compute the mmd loss.
    sigmas : array-like, default=None
        If array-like, sigmas used for the multi-Gaussian kernel.
        If None, uses the sigmas proposed in [14]_.
eps : float, default=1e-7
Small constant added to median distance calculation for numerical stability.
Returns
-------
loss : float
The loss of the method.
References
----------
.. [14] Mingsheng Long et. al. Learning Transferable
Features with Deep Adaptation Networks.
In ICML, 2015.
"""
if sigmas is None:
median_pairwise_distance = (
torch.median(torch.cdist(features_s, features_s)) + eps
)
sigmas = (
torch.tensor([2 ** (-8) * 2 ** (i * 1 / 2) for i in range(33)]).to(
features_s.device
)
* median_pairwise_distance
)
else:
sigmas = torch.tensor(sigmas).to(features_s.device)
gaussian_kernel = partial(_gaussian_kernel, sigmas=sigmas)
loss = _maximum_mean_discrepancy(features_s, features_t, kernel=gaussian_kernel)
return loss
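
# Minimal usage sketch for ``dan_loss`` (batch sizes, feature dimension and
# the fixed bandwidths are illustrative assumptions):
#
#   features_s = torch.randn(32, 64)
#   features_t = torch.randn(32, 64) + 0.5
#   loss_default = dan_loss(features_s, features_t)  # median-heuristic sigmas
#   loss_fixed = dan_loss(features_s, features_t, sigmas=[0.5, 1.0, 2.0])
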
def cdd_loss(
y_s,
features_s,
features_t,
target_kmeans,
sigmas=None,
distance_threshold=0.5,
class_threshold=3,
eps=1e-7,
):
"""Define the contrastive domain discrepancy loss based on [33]_.
Parameters
----------
    y_s : tensor
        Labels of the source data used to compute the loss.
    features_s : tensor
        Features of the source data used to compute the loss.
    features_t : tensor
        Features of the target data used to compute the loss.
target_kmeans : SphericalKMeans
Pre-computed target KMeans clustering model.
    sigmas : array-like, default=None
        If array-like, sigmas used for the multi-Gaussian kernel.
        If None, uses the sigmas proposed in [14]_.
    distance_threshold : float, optional (default=0.5)
        Distance threshold for discarding the samples that are
        too far from the centroids.
class_threshold : int, optional (default=3)
Minimum number of samples in a class to be considered for the loss.
eps : float, default=1e-7
Small constant added to median distance calculation for numerical stability.
Returns
-------
loss : float
The loss of the method.
References
----------
.. [33] Kang, G., Jiang, L., Yang, Y., & Hauptmann, A. G. (2019).
Contrastive adaptation network for unsupervised domain adaptation.
In Proceedings of the IEEE/CVF Conference on Computer Vision
and Pattern Recognition (pp. 4893-4902).
"""
n_classes = len(y_s.unique())
# Use pre-computed target_kmeans
    if not isinstance(target_kmeans, SphericalKMeans):
        raise ValueError(
            "cdd_loss: Please ensure `target_kmeans` is initialized before "
            "proceeding. A fitted SphericalKMeans should be provided."
        )
# Predict clusters for target samples
cluster_labels_t = target_kmeans.predict(features_t)
# Discard ambiguous target samples
    similarities = F.cosine_similarity(
        features_t.unsqueeze(1), target_kmeans.cluster_centers_.unsqueeze(0), dim=2
    )
mask_t = 0.5 * (1 - similarities.max(dim=1)[0]) < distance_threshold
features_t = features_t[mask_t]
cluster_labels_t = cluster_labels_t[mask_t]
# Discard ambiguous classes
class_counts = torch.bincount(cluster_labels_t, minlength=n_classes)
valid_classes = class_counts >= class_threshold
mask_t = valid_classes[cluster_labels_t]
features_t = features_t[mask_t]
cluster_labels_t = cluster_labels_t[mask_t]
# Define sigmas
if sigmas is None:
median_pairwise_distance = (
torch.median(torch.cdist(features_s, features_s)) + eps
)
sigmas = (
torch.tensor([2 ** (-8) * 2 ** (i * 1 / 2) for i in range(33)]).to(
features_s.device
)
* median_pairwise_distance
)
else:
sigmas = torch.tensor(sigmas).to(features_s.device)
# Compute CDD
intraclass = 0
interclass = 0
for c1 in range(n_classes):
for c2 in range(c1, n_classes):
if valid_classes[c1] and valid_classes[c2]:
# Compute e1
kernel_ss = _gaussian_kernel(features_s, features_s, sigmas)
mask_c1_c1 = (y_s == c1).float()
                # e1 measures the intra-class domain discrepancy
# Thus if mask_c1_c1.sum() = 0 --> e1 = 0
if mask_c1_c1.sum() > 0:
e1 = (kernel_ss * mask_c1_c1).sum() / (mask_c1_c1.sum() ** 2)
else:
e1 = 0
# Compute e2
kernel_tt = _gaussian_kernel(features_t, features_t, sigmas)
mask_c2_c2 = (cluster_labels_t == c2).float()
                # e2 measures the intra-class domain discrepancy
# Thus if mask_c2_c2.sum() = 0 --> e2 = 0
if mask_c2_c2.sum() > 0:
e2 = (kernel_tt * mask_c2_c2).sum() / (mask_c2_c2.sum() ** 2)
else:
e2 = 0
# Compute e3
kernel_st = _gaussian_kernel(features_s, features_t, sigmas)
mask_c1 = (y_s == c1).float().unsqueeze(1)
mask_c2 = (cluster_labels_t == c2).float().unsqueeze(0)
mask_c1_c2 = mask_c1 * mask_c2
                # e3 measures the inter-class domain discrepancy
# Thus if mask_c1_c2.sum() = 0 --> e3 = 0
if mask_c1_c2.sum() > 0:
e3 = (kernel_st * mask_c1_c2).sum() / (mask_c1_c2.sum() ** 2)
else:
e3 = 0
if c1 == c2:
intraclass += e1 + e2 - 2 * e3
else:
interclass += e1 + e2 - 2 * e3
cdd = (intraclass / len(valid_classes)) - (
interclass / (len(valid_classes) ** 2 - len(valid_classes))
)
return cdd
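
# Minimal usage sketch for ``cdd_loss``. Only ``predict`` and
# ``cluster_centers_`` of ``target_kmeans`` are used inside the loss; the
# ``fit`` call below therefore assumes a scikit-learn-like SphericalKMeans
# API, and all shapes are illustrative:
#
#   n, d, n_classes = 64, 32, 5
#   y_s = torch.randint(0, n_classes, (n,))
#   features_s = torch.randn(n, d)
#   features_t = torch.randn(n, d)
#   target_kmeans = SphericalKMeans(n_clusters=n_classes).fit(features_t)
#   loss = cdd_loss(y_s, features_s, features_t, target_kmeans)
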
class TestLoss(BaseDALoss):
"""Test Loss to check the deep API"""
def __init__(
self,
):
super().__init__()
def forward(
self,
**kwargs,
):
"""Compute the domain adaptation loss"""
return 0
def probability_scaling(logits, temperature=1):
"""Probability scaling.
Parameters
----------
logits : torch.Tensor
The logits.
temperature : float, default=1
The temperature.
Returns
-------
torch.Tensor
The scaled probabilities.
"""
return torch.nn.functional.softmax(logits / temperature, dim=1)
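
# For instance, raising the temperature flattens the class probabilities
# (values below are illustrative):
#
#   logits = torch.tensor([[2.0, 0.5, -1.0]])
#   probability_scaling(logits)                   # sharper distribution (T=1)
#   probability_scaling(logits, temperature=5.0)  # flatter distribution
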
def mcc_loss(y, T=1, eps=1e-7):
"""Estimate the Frobenius norm divide by 4*n**2
for DeepCORAL method [33]_.
Parameters
----------
y : tensor
The output of target domain of the model.
T : float, default=1
The temperature for the scaling.
    eps : float, default=1e-7
        Small constant added inside the logarithm for numerical stability.
Returns
-------
    loss : tensor
The loss of the method.
References
----------
    .. [35] Ying Jin, Ximei Wang, Mingsheng Long, Jianmin Wang.
Minimum Class Confusion for Versatile Domain Adaptation.
In ECCV, 2020.
"""
# Probability Rescaling
y_scaled = probability_scaling(y, temperature=T)
# Uncertainty Reweighting & class correlation matrix
H = -torch.sum(y_scaled * torch.log(y_scaled + eps), axis=1)
W = (1 + torch.exp(-H)) / torch.mean(1 + torch.exp(-H))
y_weighted = torch.matmul(torch.diag(W), y_scaled)
C = torch.einsum("ij,ik->jk", y_scaled, y_weighted)
# Category Normalization
C_tilde = C / torch.sum(C, axis=1, keepdim=True)
# MCC Loss
C_ = C_tilde - torch.diag(torch.diag(C_tilde))
loss = torch.mean(torch.sum(torch.abs(C_), axis=1))
return loss
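
# Minimal usage sketch for ``mcc_loss`` on a batch of target logits
# (batch size, class count and temperature are illustrative assumptions):
#
#   logits_t = torch.randn(32, 10)     # unscaled model outputs on target data
#   loss = mcc_loss(logits_t, T=2.5)   # smaller when class confusion is low
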
def _adj(s, t, metric="euc"):
"""Inspired by https://github.com/CrownX/SPA"""
# s, t [bsize, dim], [bsize, dim] -> [bsize, bsize]
if metric == "cos":
s_norm = F.normalize(s, p=2, dim=1)
t_norm = F.normalize(t, p=2, dim=1)
return torch.mm(s_norm, t_norm.t())
elif metric == "gauss":
squared_dist = torch.cdist(s, t, p=2) ** 2
sigma_ = 1.5
return torch.exp(-0.5 * squared_dist / sigma_**2)
elif metric == "euc":
return torch.cdist(s, t, p=2)
raise ValueError(f"Unknown metric: {metric}")
def _laplacian(A, laplac="laplac1"):
"""Inspired by https://github.com/CrownX/SPA"""
eps = 1e-7 # For numerical stability
v = torch.sum(A, dim=1)
if laplac == "laplac1":
v_inv = 1 / (v + eps)
D_inv = torch.diag(v_inv)
return -D_inv @ A
elif laplac == "laplac2":
D = torch.diag(v)
return D - A
elif laplac == "laplac3":
v_sqrt = 1 / torch.sqrt(v + eps)
D_sqrt = torch.diag(v_sqrt)
iden = torch.eye(A.shape[0], device=A.device)
return iden - D_sqrt @ A @ D_sqrt
raise ValueError(f"Unknown Laplacian type: {laplac}")
def gda_loss(s, t, metric="euc", laplac="laplac1"):
"""Compute the GDA loss between two graphs.
Inspired by https://github.com/CrownX/SPA
Parameters
----------
s : torch.Tensor
Source features.
t : torch.Tensor
Target features.
    metric : str, default="euc"
        The metric used to build the adjacency matrix ("euc", "cos" or "gauss").
    laplac : str, default="laplac1"
        The Laplacian normalization to use ("laplac1", "laplac2" or "laplac3").

    Returns
    -------
    loss : tensor
        The norm of the difference between the singular values of the
        source and target graph Laplacians.
    """
# s, t [bsize, dim], [bsize, dim]
s_matrix = _adj(s, s, metric)
t_matrix = _adj(t, t, metric)
s_matrix = _laplacian(s_matrix, laplac)
t_matrix = _laplacian(t_matrix, laplac)
_, s_v, _ = torch.linalg.svd(s_matrix)
_, t_v, _ = torch.linalg.svd(t_matrix)
svd_loss = torch.linalg.norm(s_v - t_v)
return svd_loss
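
# Minimal usage sketch for ``gda_loss``: it compares the singular-value
# spectra of the source and target graph Laplacians, so the two batches must
# have the same size (shapes below are illustrative):
#
#   s = torch.randn(32, 64)
#   t = torch.randn(32, 64)
#   loss = gda_loss(s, t, metric="gauss", laplac="laplac3")
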
def nap_loss(features_t, y_pred_t, memory_features, memory_outputs, sample_idx_t, K=5):
"""Compute the NAP loss.
Inspired by https://github.com/CrownX/SPA
Parameters
----------
features_t : torch.Tensor
Target features.
y_pred_t : torch.Tensor
Target predictions.
memory_features : torch.Tensor
Memory features.
memory_outputs : torch.Tensor
Memory outputs.
    sample_idx_t : torch.Tensor
        Indices of the batch samples ``features_t`` in the memory bank,
        used to exclude them from the nearest-neighbor search.
    K : int, default=5
        The number of nearest neighbors.

    Returns
    -------
    loss : tensor
        Weighted cross-entropy between the target predictions and the
        pseudo-labels aggregated from the memory neighbors.
    """
dis = torch.cdist(features_t.detach(), memory_features, p=2) ** 2
dis[..., sample_idx_t] = float("+inf")
# Get top-K neighbors
_, top_k_indices = torch.topk(-dis, k=K, dim=1)
batch_size, mem_size = features_t.size(0), memory_features.size(0)
w = torch.zeros(batch_size, mem_size, device=features_t.device)
w.scatter_(1, top_k_indices, 1.0 / K)
weight_, pred = torch.max(w.mm(memory_outputs), 1)
loss_ = torch.nn.CrossEntropyLoss(reduction="none")(y_pred_t, pred)
classifier_loss = torch.sum(weight_ * loss_) / (torch.sum(weight_).item() + 1e-7)
return classifier_loss
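
# Minimal usage sketch for ``nap_loss`` with a memory bank of stored target
# features and softmax outputs (the random memory below is purely
# illustrative):
#
#   n_mem, batch, d, n_classes = 256, 32, 64, 10
#   memory_features = torch.randn(n_mem, d)
#   memory_outputs = torch.softmax(torch.randn(n_mem, n_classes), dim=1)
#   features_t = torch.randn(batch, d)
#   y_pred_t = torch.randn(batch, n_classes)   # target logits
#   sample_idx_t = torch.arange(batch)         # positions of the batch in the memory
#   loss = nap_loss(features_t, y_pred_t, memory_features, memory_outputs, sample_idx_t)
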