package org.openimaj.image.processing.face.alignment;

import Jama.Matrix;
import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import org.openimaj.image.FImage;
import org.openimaj.image.ImageUtilities;
import org.openimaj.image.processing.face.detection.keypoints.FKEFaceDetector;
import org.openimaj.image.processing.face.detection.keypoints.FacialKeypoint;
import org.openimaj.image.processing.face.detection.keypoints.KEDetectedFace;
import org.openimaj.image.processing.transform.PiecewiseMeshWarp;
import org.openimaj.math.geometry.point.Point2d;
import org.openimaj.math.geometry.point.Point2dImpl;
import org.openimaj.math.geometry.shape.Polygon;
import org.openimaj.math.geometry.shape.Shape;
import org.openimaj.math.geometry.transforms.TransformUtilities;
import org.openimaj.util.pair.Pair;

/* loaded from: input_file:org/openimaj/image/processing/face/alignment/MeshWarpAligner.class */
public class MeshWarpAligner implements FaceAligner<KEDetectedFace> {
    private static final String[][] DEFAULT_MESH_DEFINITION = {new String[]{"EYE_LEFT_RIGHT", "EYE_RIGHT_LEFT", "NOSE_MIDDLE"}, new String[]{"EYE_LEFT_LEFT", "EYE_LEFT_RIGHT", "NOSE_LEFT"}, new String[]{"EYE_RIGHT_RIGHT", "EYE_RIGHT_LEFT", "NOSE_RIGHT"}, new String[]{"EYE_LEFT_RIGHT", "NOSE_LEFT", "NOSE_MIDDLE"}, new String[]{"EYE_RIGHT_LEFT", "NOSE_RIGHT", "NOSE_MIDDLE"}, new String[]{"MOUTH_LEFT", "MOUTH_RIGHT", "NOSE_MIDDLE"}, new String[]{"MOUTH_LEFT", "NOSE_LEFT", "NOSE_MIDDLE"}, new String[]{"MOUTH_RIGHT", "NOSE_RIGHT", "NOSE_MIDDLE"}, new String[]{"MOUTH_LEFT", "NOSE_LEFT", "EYE_LEFT_LEFT"}, new String[]{"MOUTH_RIGHT", "NOSE_RIGHT", "EYE_RIGHT_RIGHT"}};
    private static final Point2d P0 = new Point2dImpl(0.0f, 0.0f);
    private static final Point2d P1 = new Point2dImpl(80.0f, 0.0f);
    private static final Point2d P2 = new Point2dImpl(80.0f, 80.0f);
    private static final Point2d P3 = new Point2dImpl(0.0f, 80.0f);
    private static FacialKeypoint[] canonical = loadCanonicalPoints();
    String[][] meshDefinition;
    FImage mask;

    public MeshWarpAligner() {
        this(DEFAULT_MESH_DEFINITION);
    }

    public MeshWarpAligner(String[][] strArr) {
        this.meshDefinition = DEFAULT_MESH_DEFINITION;
        this.meshDefinition = strArr;
        List<Pair<Shape>> createMesh = createMesh(canonical);
        this.mask = new FImage((int) P2.getX(), (int) P2.getY());
        this.mask.fill(1.0f);
        this.mask = this.mask.processInplace(new PiecewiseMeshWarp(createMesh));
    }

    private static FacialKeypoint[] loadCanonicalPoints() {
        FacialKeypoint[] facialKeypointArr = new FacialKeypoint[AffineAligner.Pmu[0].length];
        for (int i = 0; i < facialKeypointArr.length; i++) {
            facialKeypointArr[i] = new FacialKeypoint(FacialKeypoint.FacialKeypointType.valueOf(i));
            facialKeypointArr[i].position = new Point2dImpl((2.0f * AffineAligner.Pmu[0][i]) - 40.0f, (2.0f * AffineAligner.Pmu[1][i]) - 40.0f);
        }
        return facialKeypointArr;
    }

    protected FacialKeypoint[] getActualPoints(FacialKeypoint[] facialKeypointArr, Matrix matrix) {
        FacialKeypoint[] facialKeypointArr2 = new FacialKeypoint[AffineAligner.Pmu[0].length];
        for (int i = 0; i < facialKeypointArr2.length; i++) {
            facialKeypointArr2[i] = new FacialKeypoint(FacialKeypoint.FacialKeypointType.valueOf(i));
            facialKeypointArr2[i].position = new Point2dImpl(FacialKeypoint.getKeypoint(facialKeypointArr, FacialKeypoint.FacialKeypointType.valueOf(i)).position.transform(matrix));
        }
        return facialKeypointArr2;
    }

    protected List<Pair<Shape>> createMesh(FacialKeypoint[] facialKeypointArr) {
        ArrayList arrayList = new ArrayList();
        for (String[] strArr : this.meshDefinition) {
            Polygon polygon = new Polygon();
            Polygon polygon2 = new Polygon();
            for (String str : strArr) {
                polygon.getVertices().add(lookupVertex(str, facialKeypointArr));
                polygon2.getVertices().add(lookupVertex(str, canonical));
            }
            arrayList.add(new Pair(polygon, polygon2));
        }
        return arrayList;
    }

    private Point2d lookupVertex(String str, FacialKeypoint[] facialKeypointArr) {
        return str.equals("P0") ? P0 : str.equals("P1") ? P1 : str.equals("P2") ? P2 : str.equals("P3") ? P3 : FacialKeypoint.getKeypoint(facialKeypointArr, FacialKeypoint.FacialKeypointType.valueOf(str)).position;
    }

    @Override // org.openimaj.image.processing.face.alignment.FaceAligner
    public FImage align(KEDetectedFace kEDetectedFace) {
        Matrix scaleMatrix = TransformUtilities.scaleMatrix(P2.getX() / kEDetectedFace.getFacePatch().width, P2.getY() / kEDetectedFace.getFacePatch().height);
        Matrix inverse = scaleMatrix.inverse();
        return getWarpedImage(kEDetectedFace.getKeypoints(), FKEFaceDetector.extractPatch(FKEFaceDetector.pyramidResize(kEDetectedFace.getFacePatch(), inverse), inverse, 80, 0), scaleMatrix);
    }

    protected FImage getWarpedImage(FacialKeypoint[] facialKeypointArr, FImage fImage, Matrix matrix) {
        return fImage.process(new PiecewiseMeshWarp(createMesh(getActualPoints(facialKeypointArr, matrix))));
    }

    @Override // org.openimaj.image.processing.face.alignment.FaceAligner
    public FImage getMask() {
        return this.mask;
    }

    /* JADX WARN: Type inference failed for: r1v1, types: [java.lang.String[], java.lang.String[][]] */
    public void readBinary(DataInput dataInput) throws IOException {
        this.meshDefinition = new String[dataInput.readInt()];
        for (int i = 0; i < this.meshDefinition.length; i++) {
            this.meshDefinition[i] = new String[dataInput.readInt()];
            for (int i2 = 0; i2 < this.meshDefinition[i].length; i2++) {
                this.meshDefinition[i][i2] = dataInput.readUTF();
            }
        }
        this.mask = ImageUtilities.readF(dataInput);
    }

    public byte[] binaryHeader() {
        return getClass().getName().getBytes();
    }

    public void writeBinary(DataOutput dataOutput) throws IOException {
        dataOutput.writeInt(this.meshDefinition.length);
        for (String[] strArr : this.meshDefinition) {
            dataOutput.writeInt(strArr.length);
            for (String str : strArr) {
                dataOutput.writeUTF(str);
            }
        }
        ImageUtilities.write(this.mask, "png", dataOutput);
    }
}
