package org.openimaj.math.statistics.distribution;

import gnu.trove.procedure.TObjectDoubleProcedure;
import java.util.ArrayList;
import java.util.List;
import java.util.Random;
import org.openimaj.math.statistics.distribution.kernel.UnivariateKernel;
import org.openimaj.util.pair.ObjectDoublePair;
import org.openimaj.util.tree.DoubleKDTree;

/**
 * A multivariate kernel density estimate built from a fixed set of sample
 * points, a {@link UnivariateKernel} evaluated on (distance / bandwidth), and a
 * single scalar bandwidth.
 * <p>
 * Neighbour look-ups are answered by a {@link DoubleKDTree} over the sample
 * points, restricted to the radius {@code kernel.getCutOff() * bandwidth}.
 */
public class MultivariateKernelDensityEstimate extends AbstractMultivariateDistribution {
    /** The sample points the estimate was built from (stored, not copied). */
    double[][] data;
    /** Kernel applied to (distance / bandwidth) for each neighbouring point. */
    UnivariateKernel kernel;
    /** Bandwidth (scale) of the kernel. */
    private final double bandwidth;
    /** KD-tree over {@link #data}, used for radius queries. */
    DoubleKDTree tree;

    /**
     * Construct an estimate from an array of sample points.
     *
     * @param data the sample points; the array is stored, not copied
     * @param kernel the univariate kernel applied to scaled distances
     * @param bandwidth the kernel bandwidth
     */
    public MultivariateKernelDensityEstimate(double[][] data, UnivariateKernel kernel, double bandwidth) {
        this.data = data;
        this.tree = new DoubleKDTree(data);
        this.kernel = kernel;
        this.bandwidth = bandwidth;
    }

    /**
     * Construct an estimate from a list of sample points. The list is copied
     * into a new outer array, but the individual point arrays are shared.
     *
     * @param data the sample points
     * @param kernel the univariate kernel applied to scaled distances
     * @param bandwidth the kernel bandwidth
     */
    public MultivariateKernelDensityEstimate(List<double[]> data, UnivariateKernel kernel, double bandwidth) {
        // The array must be double[n][] (an Object[] of rows). The decompiled
        // cast of a primitive double[] to Object[] did not compile.
        this(data.toArray(new double[data.size()][]), kernel, bandwidth);
    }

    /**
     * Draw a sample by picking a data point uniformly at random and adding
     * independent, kernel-distributed, bandwidth-scaled noise to each coordinate.
     *
     * @param random the source of randomness
     * @return a new sample vector (the stored data point is not modified)
     */
    @Override // org.openimaj.math.statistics.distribution.MultivariateDistribution
    public double[] sample(Random random) {
        final double h = getBandwidth();
        final double[] result = this.data[random.nextInt(this.data.length)].clone();
        for (int i = 0; i < result.length; i++) {
            result[i] += this.kernel.sample(random) * h;
        }
        return result;
    }

    /**
     * Estimate the density at the given point from the data points that lie
     * within the kernel's support, i.e. within {@code kernel.getCutOff() * h}.
     * <p>
     * The result is the sum of {@code kernel.evaluate(dist / h)} over those
     * points, divided by {@code h * count}. If no data point lies within the
     * support, 0 is returned instead of the NaN that 0/0 would give.
     * <p>
     * NOTE(review): this normalises by the neighbour count and by {@code h}
     * rather than by the total sample count and {@code h^d}, so in more than
     * one dimension it is a relative weight, not a normalised density. Confirm
     * that this is intended.
     *
     * @param dArr the query point
     * @return the estimated value (0 when the point has no kernel support)
     */
    @Override // org.openimaj.math.statistics.distribution.MultivariateDistribution
    public double estimateProbability(double[] dArr) {
        final double h = getBandwidth();
        final UnivariateKernel k = this.kernel;
        // Single-element arrays let the anonymous callback accumulate results.
        final double[] kernelSum = new double[1];
        final int[] count = new int[1];
        this.tree.coordinateRadiusSearch(dArr, k.getCutOff() * h, new TObjectDoubleProcedure<double[]>() {
            public boolean execute(double[] point, double distance) {
                // The tree presumably reports squared distances (hence the
                // sqrt). Confirm against DoubleKDTree.
                kernelSum[0] += k.evaluate(Math.sqrt(distance) / h);
                count[0]++;
                return true;
            }
        });
        if (count[0] == 0) {
            return 0;
        }
        return kernelSum[0] / (h * count[0]);
    }

    /**
     * Get the data points within the kernel's support of the given point,
     * each paired with its kernel weight {@code kernel.evaluate(dist / h)}.
     *
     * @param dArr the query point
     * @return a new mutable list of (point, weight) pairs; empty if none
     */
    public List<ObjectDoublePair<double[]>> getSupport(double[] dArr) {
        final double h = getBandwidth();
        final UnivariateKernel k = this.kernel;
        final List<ObjectDoublePair<double[]>> support = new ArrayList<ObjectDoublePair<double[]>>();
        this.tree.coordinateRadiusSearch(dArr, k.getCutOff() * h, new TObjectDoubleProcedure<double[]>() {
            public boolean execute(double[] point, double distance) {
                support.add(ObjectDoublePair.pair(point, k.evaluate(Math.sqrt(distance) / h)));
                return true;
            }
        });
        return support;
    }

    /**
     * @return the underlying sample points (the internal array, not a copy)
     */
    public double[][] getData() {
        return this.data;
    }

    /**
     * @return the kernel bandwidth
     */
    public double getBandwidth() {
        return this.bandwidth;
    }

    /**
     * @return the bandwidth multiplied by the kernel's cut-off. This is the
     *         radius used for neighbour searches.
     */
    public double getScaledBandwidth() {
        return this.bandwidth * this.kernel.getCutOff();
    }
}
