package org.openimaj.math.statistics.distribution.metrics;

import Jama.Matrix;
import org.openimaj.math.matrix.MatrixUtils;
import org.openimaj.math.statistics.distribution.MultivariateGaussian;
import org.openimaj.util.comparator.DistanceComparator;

/**
 * Kullback-Leibler divergence between two multivariate Gaussian distributions.
 *
 * <p>For {@code P = N(mu0, Sigma0)} and {@code Q = N(mu1, Sigma1)} in {@code k} dimensions,
 * {@code compare(P, Q)} returns
 *
 * <pre>
 * KL(P || Q) = 0.5 * ( tr(Sigma1^-1 Sigma0)
 *                    + (mu1 - mu0)^T Sigma1^-1 (mu1 - mu0)
 *                    - k
 *                    + ln(det(Sigma1) / det(Sigma0)) )
 * </pre>
 *
 * <p>The value uses the natural logarithm (nats). The divergence is asymmetric, so argument
 * order matters, and it is not a true metric.
 */
public class GaussianKLDivergence implements DistanceComparator<MultivariateGaussian> {
    /**
     * Reports that {@link #compare} yields a distance-like value (smaller means the two
     * distributions are more alike) rather than a similarity.
     *
     * @return always {@code true}
     */
    @Override
    public boolean isDistance() {
        return true;
    }

    /**
     * Computes {@code KL(P || Q)} where {@code P} is the first argument and {@code Q} the second.
     *
     * @param multivariateGaussian the reference distribution {@code P}
     * @param multivariateGaussian2 the distribution {@code Q}; its covariance is inverted, so it
     *     must be non-singular
     * @return the divergence in nats (0 when the two distributions are identical)
     * @throws RuntimeException if the covariance of {@code Q} is singular (raised by Jama when
     *     inverting it)
     */
    @Override
    public double compare(MultivariateGaussian multivariateGaussian, MultivariateGaussian multivariateGaussian2) {
        final Matrix sigma0 = multivariateGaussian.getCovariance();
        final Matrix sigma1 = multivariateGaussian2.getCovariance();
        final Matrix mu0 = multivariateGaussian.getMean();
        final Matrix mu1 = multivariateGaussian2.getMean();
        final int k = multivariateGaussian.numDims();

        final Matrix sigma1Inv = sigma1.inverse();

        // tr(Sigma1^-1 Sigma0)
        final double traceTerm = MatrixUtils.trace(sigma1Inv.times(sigma0));

        // (mu1 - mu0)^T Sigma1^-1 (mu1 - mu0). The mean difference is repacked as a k x 1 column
        // so the quadratic form has compatible dimensions whether the means are stored as row
        // (1 x k) or column (k x 1) vectors.
        final double[] diffValues = mu1.minus(mu0).getRowPackedCopy();
        final Matrix diff = new Matrix(diffValues, diffValues.length);
        final double mahalanobisTerm = diff.transpose().times(sigma1Inv).times(diff).get(0, 0);

        // ln(det(Sigma1) / det(Sigma0)). The previous implementation used norm1() (max column
        // sum) here, which is not the determinant. A difference of log-determinants also avoids
        // the overflow/underflow a direct det() call suffers in high dimensions.
        final double logDetRatio = logAbsDeterminant(sigma1) - logAbsDeterminant(sigma0);

        return 0.5 * (traceTerm + mahalanobisTerm - k + logDetRatio);
    }

    /**
     * Computes {@code ln|det(m)|} as the sum of the logs of the absolute diagonal entries of the
     * LU factor {@code U}. For a valid covariance (positive determinant) this equals
     * {@code ln(det(m))}. A singular matrix yields {@code -Infinity}.
     *
     * @param m a square matrix
     * @return the natural log of the absolute determinant of {@code m}
     */
    private static double logAbsDeterminant(Matrix m) {
        final Matrix u = m.lu().getU();
        double sum = 0.0;
        for (int i = 0; i < u.getRowDimension(); i++) {
            sum += Math.log(Math.abs(u.get(i, i)));
        }
        return sum;
    }
}
