package org.openimaj.demos.sandbox.ml.regression;

import java.io.IOException;
import java.util.Iterator;
import javax.swing.JFrame;
import javax.swing.SwingUtilities;
import org.jfree.chart.ChartFactory;
import org.jfree.chart.ChartPanel;
import org.jfree.data.time.Day;
import org.jfree.data.time.TimeSeries;
import org.jfree.data.time.TimeSeriesCollection;
import org.joda.time.DateTime;
import org.joda.time.format.DateTimeFormat;
import org.joda.time.format.DateTimeFormatter;
import org.openimaj.io.Cache;
import org.openimaj.ml.timeseries.IncompatibleTimeSeriesException;
import org.openimaj.ml.timeseries.aggregator.MeanSquaredDifferenceAggregator;
import org.openimaj.ml.timeseries.processor.LinearRegressionProcessor;
import org.openimaj.ml.timeseries.processor.MovingAverageProcessor;
import org.openimaj.ml.timeseries.processor.WindowedLinearRegressionProcessor;
import org.openimaj.ml.timeseries.series.DoubleTimeSeries;
import org.openimaj.twitter.finance.YahooFinanceData;
import org.openimaj.util.pair.IndependentPair;

/* loaded from: input_file:org/openimaj/demos/sandbox/ml/regression/LinearRegressionPlayground.class */
/**
 * Sandbox demo that plots a year of AAPL "High" prices next to moving-average and windowed
 * linear-regression fits, in two JFreeChart windows.
 */
public class LinearRegressionPlayground {
    /**
     * Joda-Time pattern used both to parse the training-slice bounds and to request the Yahoo data.
     * In Joda-Time, {@code Y} is year-of-era, so for AD dates this behaves like {@code yyyy}.
     */
    private static final String DATE_PATTERN = "YYYY-MM-dd";

    /** Ticker whose price history is loaded; also used as the chart title. */
    private static final String TICKER = "AAPL";

    /** Moving-average window: 30 days in milliseconds (== 2592000000L). */
    private static final long MOVING_AVERAGE_WINDOW_MILLIS = 30L * 24 * 60 * 60 * 1000;

    /**
     * Loads the 2010 "High" series, then opens two chart windows.
     *
     * <p>The "unseen" fits use only the 2010-01-01..2010-05-01 slice as their regression input.
     *
     * @param args ignored
     * @throws IOException if the cached/remote price data cannot be loaded
     * @throws IncompatibleTimeSeriesException if propagated by the time-series processing or
     *     error-aggregation calls
     */
    public static void main(String[] args) throws IOException, IncompatibleTimeSeriesException {
        DateTimeFormatter formatter = DateTimeFormat.forPattern(DATE_PATTERN);
        long trainingStart = formatter.parseDateTime("2010-01-01").getMillis();
        long trainingEnd = formatter.parseDateTime("2010-05-01").getMillis();

        YahooFinanceData finance = (YahooFinanceData) Cache.load(
                new YahooFinanceData(TICKER, "2010-01-01", "2010-12-31", DATE_PATTERN));
        DoubleTimeSeries high = (DoubleTimeSeries) finance.seriesMap().get("High");
        DoubleTimeSeries highTraining = high.get(trainingStart, trainingEnd);

        displayTimeSeries(movingAverageDataset(high, highTraining), TICKER, "Date", "Price");
        displayTimeSeries(regressionErrorDataset(high, highTraining), TICKER, "Date", "Price");
    }

    /**
     * Builds the first chart's data.
     *
     * <p>It contains the raw series, its 30-day moving average, and two windowed regressions of
     * that average. The first regression is fitted on the whole average. The second is given the
     * moving average of the training slice.
     */
    private static TimeSeriesCollection movingAverageDataset(DoubleTimeSeries high,
            DoubleTimeSeries highTraining) throws IncompatibleTimeSeriesException {
        TimeSeriesCollection dataset = new TimeSeriesCollection();
        dataset.addSeries(timeSeriesToChart("High Value", high));

        DoubleTimeSeries highMa = high.process(new MovingAverageProcessor(MOVING_AVERAGE_WINDOW_MILLIS));
        DoubleTimeSeries trainingMa =
                highTraining.process(new MovingAverageProcessor(MOVING_AVERAGE_WINDOW_MILLIS));
        dataset.addSeries(timeSeriesToChart("High Value MA", highMa));
        dataset.addSeries(timeSeriesToChart("High Value MA Regressed (all seen)",
                highMa.process(new WindowedLinearRegressionProcessor(10, 7))));
        // NOTE(review): the training slice ends 2010-05-01, so "latter half" in this legend is
        // only approximate.
        dataset.addSeries(timeSeriesToChart("High Value MA Regressed (latter half unseen)",
                highMa.process(new WindowedLinearRegressionProcessor(trainingMa, 10, 7))));
        return dataset;
    }

    /**
     * Builds the second chart's data.
     *
     * <p>It contains the raw series plus three windowed regressions of it. Each regression's
     * legend reports its mean squared difference from the raw series.
     */
    private static TimeSeriesCollection regressionErrorDataset(DoubleTimeSeries high,
            DoubleTimeSeries highTraining) throws IncompatibleTimeSeriesException {
        TimeSeriesCollection dataset = new TimeSeriesCollection();
        dataset.addSeries(timeSeriesToChart("High Value", high));

        // Constructor arguments (10, 7) / (3, 1) correspond to the legend labels n/m below.
        DoubleTimeSeries olr10x7 = high.process(new WindowedLinearRegressionProcessor(10, 7));
        DoubleTimeSeries olr3x1 = high.process(new WindowedLinearRegressionProcessor(3, 1));
        DoubleTimeSeries olrUnseen =
                high.process(new WindowedLinearRegressionProcessor(highTraining, 10, 7));

        dataset.addSeries(seriesWithError("OLR (m=7,n=10) (MSE=%.2f)", olr10x7, high));
        dataset.addSeries(seriesWithError("OLR (m=1,n=3) (MSE=%.2f)", olr3x1, high));
        dataset.addSeries(seriesWithError("OLR unseen (m=7,n=10) (MSE=%.2f)", olrUnseen, high));
        return dataset;
    }

    /**
     * Converts {@code predicted} to a chart series whose name is {@code labelFormat} filled in
     * with the mean squared difference between {@code predicted} and {@code actual}.
     */
    private static TimeSeries seriesWithError(String labelFormat, DoubleTimeSeries predicted,
            DoubleTimeSeries actual) throws IncompatibleTimeSeriesException {
        double mse = MeanSquaredDifferenceAggregator
                .error(new DoubleTimeSeries[] { predicted, actual })
                .doubleValue();
        return timeSeriesToChart(String.format(labelFormat, Double.valueOf(mse)), predicted);
    }

    /**
     * Shows {@code dataset} as a time-series chart in a new window.
     *
     * <p>The chart has a legend but no tooltips or URLs. Closing any window exits the application.
     * The Swing components are created on the Event Dispatch Thread, as Swing's threading rules
     * require. The dataset must not be modified after this call.
     */
    private static void displayTimeSeries(final TimeSeriesCollection dataset, final String title,
            final String xAxisLabel, final String yAxisLabel) {
        SwingUtilities.invokeLater(new Runnable() {
            @Override
            public void run() {
                ChartPanel chartPanel = new ChartPanel(ChartFactory.createTimeSeriesChart(
                        title, xAxisLabel, yAxisLabel, dataset, true, false, false));
                chartPanel.setFillZoomRectangle(true);
                JFrame frame = new JFrame();
                frame.setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE);
                frame.setContentPane(chartPanel);
                frame.pack();
                frame.setVisible(true);
            }
        });
    }

    /**
     * Copies an OpenIMAJ series into a JFreeChart series with one point per calendar day.
     *
     * <p>Timestamps are interpreted in the JVM default time zone. JFreeChart's
     * {@code TimeSeries.add} rejects duplicate periods, so this assumes at most one sample per day.
     */
    private static TimeSeries timeSeriesToChart(String name, DoubleTimeSeries series) {
        TimeSeries chartSeries = new TimeSeries(name);
        Iterator<IndependentPair<Long, Double>> it = series.iterator();
        while (it.hasNext()) {
            IndependentPair<Long, Double> point = it.next();
            DateTime time = new DateTime(point.firstObject());
            Day day = new Day(time.getDayOfMonth(), time.getMonthOfYear(), time.getYear());
            chartSeries.add(day, point.secondObject());
        }
        return chartSeries;
    }
}
