/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.agent.tools.utils.clustering;

import java.util.ArrayList;
import java.util.List;

/**
 * Bottom-up (agglomerative) hierarchical clustering of row vectors under cosine distance.
 *
 * <p>The constructor precomputes the full pairwise cosine-distance matrix, where distance is
 * {@code 1 - cosineSimilarity} with similarity clamped to [-1, 1], so distances lie in [0, 2].
 * {@link #fit} then repeatedly merges the closest pair of active clusters under the chosen
 * {@link LinkageMethod} until no pair is strictly closer than the given threshold.
 */
public class HierarchicalAgglomerativeClustering {
    /**
     * Single-linkage early-exit cutoff: once a cross-cluster distance this small is seen, it is
     * returned immediately instead of scanning the remaining pairs.
     */
    private static final double SINGLE_LINKAGE_EARLY_EXIT = 1.0E-10;

    /** Symmetric pairwise cosine-distance matrix; {@code distanceMatrix[i][j]} for samples i, j. */
    private final double[][] distanceMatrix;
    private final int nSamples;

    /**
     * Creates a clusterer and precomputes all pairwise cosine distances.
     *
     * @param data one row per sample. Rows are compared element-wise over the first row's length
     *     in each pair, so all rows are expected to have the same length. An empty array is
     *     allowed and produces a clusterer with no samples.
     * @throws NullPointerException if {@code data} or any of its rows is null
     */
    public HierarchicalAgglomerativeClustering(double[][] data) {
        this.nSamples = data.length;
        this.distanceMatrix = new double[this.nSamples][this.nSamples];
        this.computeCosineDistanceMatrix(data);
    }

    /** Fills {@link #distanceMatrix} with {@code 1 - similarity} for every pair (diagonal stays 0). */
    private void computeCosineDistanceMatrix(double[][] data) {
        // Norms are computed once per row rather than once per pair.
        double[] norms = new double[this.nSamples];
        for (int i = 0; i < this.nSamples; ++i) {
            norms[i] = euclideanNorm(data[i]);
        }
        for (int i = 0; i < this.nSamples; ++i) {
            for (int j = i + 1; j < this.nSamples; ++j) {
                double similarity = calculateCosineSimilarity(data[i], data[j], norms[i], norms[j]);
                // Similarity is clamped to [-1, 1], so this is never negative.
                double distance = 1.0 - similarity;
                this.distanceMatrix[i][j] = distance;
                this.distanceMatrix[j][i] = distance;
            }
        }
    }

    /** Returns the L2 norm of {@code v}. */
    private static double euclideanNorm(double[] v) {
        double sumOfSquares = 0.0;
        for (double x : v) {
            sumOfSquares += x * x;
        }
        return Math.sqrt(sumOfSquares);
    }

    /**
     * Cosine similarity given precomputed norms; 0 if either vector has zero norm.
     * The dot product runs over {@code a.length}, so {@code b} must be at least that long.
     */
    private static double calculateCosineSimilarity(double[] a, double[] b, double normA, double normB) {
        if (normA == 0.0 || normB == 0.0) {
            return 0.0;
        }
        double dotProduct = 0.0;
        for (int i = 0; i < a.length; ++i) {
            dotProduct += a[i] * b[i];
        }
        return clampSimilarity(dotProduct / (normA * normB));
    }

    /**
     * Clamps a similarity into [-1, 1]. Rounding can push {@code dot / (|a||b|)} slightly
     * outside that range (e.g. {@code sqrt(3) * sqrt(3) < 3}), which would otherwise produce a
     * negative distance that slips under a threshold of 0.
     */
    private static double clampSimilarity(double similarity) {
        return Math.max(-1.0, Math.min(1.0, similarity));
    }

    /**
     * Runs agglomerative clustering.
     *
     * <p>Starts with one cluster per sample (ids {@code 0..n-1}). It repeatedly merges the pair
     * whose linkage distance is smallest and strictly below {@code threshold}. Each merged cluster
     * receives the next id starting at {@code n}. Merging stops when one cluster remains or no
     * pair qualifies.
     *
     * @param linkage how the distance between two clusters is derived from sample distances;
     *     must be non-null whenever there are two or more samples
     * @param threshold merge only pairs whose distance is strictly less than this value
     * @return the clusters still active when merging stops (empty if there are no samples)
     * @throws IllegalArgumentException if {@code threshold} is negative or NaN
     */
    public List<ClusterNode> fit(LinkageMethod linkage, double threshold) {
        // Written as !(>=) so that NaN is rejected too; a NaN threshold would silently never merge.
        if (!(threshold >= 0.0)) {
            throw new IllegalArgumentException("Distance threshold must be non-negative");
        }
        List<ClusterNode> activeClusters = new ArrayList<>(this.nSamples);
        for (int i = 0; i < this.nSamples; ++i) {
            activeClusters.add(new ClusterNode(i, i));
        }
        int nextClusterId = this.nSamples;
        while (activeClusters.size() > 1) {
            int[] closestPair = this.findClosestClusters(activeClusters, linkage, threshold);
            if (closestPair == null) {
                break;
            }
            int i = closestPair[0];
            int j = closestPair[1]; // findClosestClusters guarantees i < j
            ClusterNode merged = new ClusterNode(nextClusterId++, activeClusters.get(i), activeClusters.get(j));
            // Remove the higher index first so the lower index still points at the right element.
            activeClusters.remove(j);
            activeClusters.remove(i);
            activeClusters.add(merged);
        }
        return activeClusters;
    }

    /**
     * Finds the pair (i, j), i &lt; j, with the smallest linkage distance strictly below
     * {@code threshold}. Ties keep the first pair found in row-major order.
     *
     * @return {@code {i, j}} or {@code null} if no pair is closer than {@code threshold}
     */
    private int[] findClosestClusters(List<ClusterNode> clusters, LinkageMethod linkage, double threshold) {
        double bestDistance = threshold;
        int bestI = -1;
        int bestJ = -1;
        int count = clusters.size();
        for (int i = 0; i < count; ++i) {
            ClusterNode left = clusters.get(i);
            for (int j = i + 1; j < count; ++j) {
                double distance = this.computeClusterDistance(left, clusters.get(j), linkage);
                if (distance < bestDistance) {
                    bestDistance = distance;
                    bestI = i;
                    bestJ = j;
                }
            }
        }
        return bestI < 0 ? null : new int[] {bestI, bestJ};
    }

    /** Dispatches to the linkage function; throws NullPointerException for a null linkage. */
    private double computeClusterDistance(ClusterNode c1, ClusterNode c2, LinkageMethod linkage) {
        return switch (linkage) {
            case SINGLE -> this.singleLinkage(c1, c2);
            case COMPLETE -> this.completeLinkage(c1, c2);
            case AVERAGE -> this.averageLinkage(c1, c2);
        };
    }

    /**
     * Minimum sample-to-sample distance between the clusters. It returns early once a distance
     * below {@link #SINGLE_LINKAGE_EARLY_EXIT} is found, so the result may exceed the exact
     * minimum by at most that amount.
     */
    private double singleLinkage(ClusterNode c1, ClusterNode c2) {
        double minDist = Double.MAX_VALUE;
        for (int i : c1.samples) {
            double[] row = this.distanceMatrix[i];
            for (int j : c2.samples) {
                double dist = row[j];
                if (dist < minDist) {
                    minDist = dist;
                    if (minDist < SINGLE_LINKAGE_EARLY_EXIT) {
                        return minDist;
                    }
                }
            }
        }
        return minDist;
    }

    /** Maximum sample-to-sample distance between the clusters. */
    private double completeLinkage(ClusterNode c1, ClusterNode c2) {
        // NEGATIVE_INFINITY, not Double.MIN_VALUE: MIN_VALUE is the smallest *positive* double,
        // which would turn an all-zero-distance pair into 4.9e-324 instead of 0.
        double maxDist = Double.NEGATIVE_INFINITY;
        for (int i : c1.samples) {
            double[] row = this.distanceMatrix[i];
            for (int j : c2.samples) {
                maxDist = Math.max(maxDist, row[j]);
            }
        }
        return maxDist;
    }

    /** Mean sample-to-sample distance between the clusters. */
    private double averageLinkage(ClusterNode c1, ClusterNode c2) {
        double sumDist = 0.0;
        int count = 0;
        for (int i : c1.samples) {
            double[] row = this.distanceMatrix[i];
            for (int j : c2.samples) {
                sumDist += row[j];
                ++count;
            }
        }
        return sumDist / (double) count;
    }

    /**
     * Returns the cluster's medoid: the member sample index whose summed distance to all other
     * members is smallest. Ties go to the earliest sample in the cluster's sample order.
     * Despite the name, this is always an actual sample, not a computed mean.
     *
     * @param cluster a cluster produced by this instance's {@link #fit}
     * @return index into the original {@code data} rows
     */
    public int getClusterCentroid(ClusterNode cluster) {
        List<Integer> members = cluster.samples;
        int medoidIndex = members.get(0);
        if (members.size() == 1) {
            return medoidIndex;
        }
        double minTotalDistance = Double.MAX_VALUE;
        for (int candidate : members) {
            double[] row = this.distanceMatrix[candidate];
            double totalDistance = 0.0;
            for (int other : members) {
                totalDistance += row[other];
            }
            if (totalDistance < minTotalDistance) {
                minTotalDistance = totalDistance;
                medoidIndex = candidate;
            }
        }
        return medoidIndex;
    }

    /**
     * Cosine similarity of {@code a} and {@code b}, clamped to [-1, 1].
     *
     * <p>Returns 0 if either vector has zero norm. Iteration runs over {@code a.length}, so
     * {@code b} must be at least that long; extra trailing elements of {@code b} are ignored.
     */
    public static double calculateCosineSimilarity(double[] a, double[] b) {
        double dotProduct = 0.0;
        double normA = 0.0;
        double normB = 0.0;
        for (int i = 0; i < a.length; ++i) {
            dotProduct += a[i] * b[i];
            normA += a[i] * a[i];
            normB += b[i] * b[i];
        }
        if (normA == 0.0 || normB == 0.0) {
            return 0.0;
        }
        return clampSimilarity(dotProduct / (Math.sqrt(normA) * Math.sqrt(normB)));
    }

    /** A cluster: its id and the indices of the samples it contains. */
    public static class ClusterNode {
        /** Leaf clusters use the sample index; merged clusters get increasing ids from n. */
        final int id;
        /** Member sample indices: left child's samples followed by the right child's. */
        final List<Integer> samples;
        /** Number of member samples. */
        final int size;

        /** Leaf cluster containing a single sample. */
        ClusterNode(int id, int sample) {
            this.id = id;
            this.samples = new ArrayList<>();
            this.samples.add(sample);
            this.size = 1;
        }

        /** Cluster formed by merging {@code left} and {@code right}. */
        ClusterNode(int id, ClusterNode left, ClusterNode right) {
            this.id = id;
            this.samples = new ArrayList<>(left.samples.size() + right.samples.size());
            this.samples.addAll(left.samples);
            this.samples.addAll(right.samples);
            this.size = left.size + right.size;
        }
    }

    /** How the distance between two clusters is derived from their sample distances. */
    public enum LinkageMethod {
        /** Minimum pairwise distance. */
        SINGLE,
        /** Maximum pairwise distance. */
        COMPLETE,
        /** Mean pairwise distance. */
        AVERAGE
    }
}

