/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.vectors.query;

import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.List;
import java.util.Map;
import org.apache.lucene.util.BytesRef;
import org.elasticsearch.ExceptionsHelper;
import org.elasticsearch.Version;
import org.elasticsearch.common.logging.DeprecationCategory;
import org.elasticsearch.common.logging.DeprecationLogger;
import org.elasticsearch.script.ScoreScript;
import org.elasticsearch.xpack.vectors.mapper.VectorEncoderDecoder;
import org.elasticsearch.xpack.vectors.query.VectorScriptDocValues;

public class ScoreScriptUtils {
    /** Logger used by the vector functions in this class to emit deprecation warnings. */
    private static final DeprecationLogger deprecationLogger = DeprecationLogger.getLogger(ScoreScriptUtils.class);

    /**
     * Warning logged when a vector function is given the field's doc values ({@code doc['field']})
     * instead of the field name.
     */
    static final String DEPRECATION_MESSAGE = "The vector functions of the form function(query, doc['field']) are deprecated, "
        + "and the form function(query, 'field') should be used instead. For example, "
        + "cosineSimilarity(query, doc['field']) is replaced by cosineSimilarity(query, 'field').";

    /**
     * Computes the dot product of two sparse vectors, each given as parallel dimension/value arrays.
     * Only dimensions present in both vectors contribute. The walk is a merge over both dimension
     * arrays, so it assumes each dims array is sorted ascending (the query side is sorted in
     * {@code SparseVectorFunction}; the document side presumably by the encoder — not verifiable here).
     *
     * @param v1Values values of the first vector
     * @param v1Dims   dimensions of the first vector, parallel to {@code v1Values}
     * @param v2Values values of the second vector
     * @param v2Dims   dimensions of the second vector, parallel to {@code v2Values}
     * @return sum of {@code v1 * v2} over shared dimensions; each product is taken in float, then accumulated in double
     */
    private static double intDotProductSparse(float[] v1Values, int[] v1Dims, float[] v2Values, int[] v2Dims) {
        double sum = 0.0;
        int i = 0;
        int j = 0;
        while (i < v1Values.length && j < v2Values.length) {
            final int leftDim = v1Dims[i];
            final int rightDim = v2Dims[j];
            if (leftDim < rightDim) {
                i++;
            } else if (leftDim > rightDim) {
                j++;
            } else {
                sum += v1Values[i] * v2Values[j];
                i++;
                j++;
            }
        }
        return sum;
    }

    /**
     * Script-facing cosine similarity between a sparse query vector and the current document's sparse vector.
     */
    public static final class CosineSimilaritySparse extends SparseVectorFunction {
        /** Euclidean length of the query vector, computed once at construction. */
        final double queryVectorMagnitude;

        public CosineSimilaritySparse(ScoreScript scoreScript, Map<String, Number> queryVector, Object docVector) {
            super(scoreScript, queryVector, docVector);
            double sumOfSquares = 0.0;
            for (float component : queryValues) {
                sumOfSquares += component * component;
            }
            this.queryVectorMagnitude = Math.sqrt(sumOfSquares);
        }

        /**
         * Computes the cosine similarity between the query vector and the current document's vector.
         *
         * @return the sparse dot product divided by the product of both magnitudes; if either magnitude
         *         is zero the division is not finite
         * @throws IllegalArgumentException if the current document has no value for the field
         */
        public double cosineSimilaritySparse() {
            final BytesRef encoded = getEncodedVector();
            final Version indexVersion = docValues.indexVersion();
            final int[] docDims = VectorEncoderDecoder.decodeSparseVectorDims(indexVersion, encoded);
            final float[] docComponents = VectorEncoderDecoder.decodeSparseVector(indexVersion, encoded);
            final double dotProduct = intDotProductSparse(queryValues, queryDims, docComponents, docDims);
            // For indices created on/after 7.5.0 the magnitude is decoded from the encoded bytes;
            // for older indices it is recomputed from the decoded values.
            final double docMagnitude = indexVersion.onOrAfter(Version.V_7_5_0)
                ? VectorEncoderDecoder.decodeMagnitude(indexVersion, encoded)
                : legacyMagnitude(docComponents);
            return dotProduct / (docMagnitude * queryVectorMagnitude);
        }

        /**
         * Recomputes a vector's magnitude from its values, rounded to float precision
         * (presumably to match the precision of the decoded magnitude — confirm).
         */
        private static float legacyMagnitude(float[] components) {
            double sumOfSquares = 0.0;
            for (float component : components) {
                sumOfSquares += component * component;
            }
            return (float) Math.sqrt(sumOfSquares);
        }
    }

    /**
     * Script-facing dot product between a sparse query vector and the current document's sparse vector.
     */
    public static final class DotProductSparse extends SparseVectorFunction {

        public DotProductSparse(ScoreScript scoreScript, Map<String, Number> queryVector, Object docVector) {
            super(scoreScript, queryVector, docVector);
        }

        /**
         * Computes the dot product between the query vector and the current document's vector.
         *
         * @return sum of products over the dimensions present in both vectors
         * @throws IllegalArgumentException if the current document has no value for the field
         */
        public double dotProductSparse() {
            final BytesRef encoded = getEncodedVector();
            final Version indexVersion = docValues.indexVersion();
            final int[] docDims = VectorEncoderDecoder.decodeSparseVectorDims(indexVersion, encoded);
            final float[] docComponents = VectorEncoderDecoder.decodeSparseVector(indexVersion, encoded);
            return intDotProductSparse(queryValues, queryDims, docComponents, docDims);
        }
    }

    /**
     * Script-facing Euclidean (L2) distance between a sparse query vector and the current document's sparse vector.
     * A dimension present in only one of the two vectors is treated as 0 in the other.
     */
    public static final class L2NormSparse extends SparseVectorFunction {

        public L2NormSparse(ScoreScript scoreScript, Map<String, Number> queryVector, Object docVector) {
            super(scoreScript, queryVector, docVector);
        }

        /**
         * Computes the L2 distance between the query vector and the current document's vector.
         *
         * @return square root of the sum of squared per-dimension differences
         * @throws IllegalArgumentException if the current document has no value for the field
         */
        public double l2normSparse() {
            BytesRef vector = this.getEncodedVector();
            int[] docDims = VectorEncoderDecoder.decodeSparseVectorDims(this.docValues.indexVersion(), vector);
            float[] values = VectorEncoderDecoder.decodeSparseVector(this.docValues.indexVersion(), vector);
            int queryIndex = 0;
            int docIndex = 0;
            double l2norm = 0.0;
            // Merge walk over both dimension arrays (assumed sorted ascending).
            while (queryIndex < this.queryDims.length && docIndex < docDims.length) {
                double diff;
                if (this.queryDims[queryIndex] == docDims[docIndex]) {
                    diff = this.queryValues[queryIndex] - values[docIndex];
                    ++queryIndex;
                    ++docIndex;
                } else if (this.queryDims[queryIndex] > docDims[docIndex]) {
                    // dimension only in the document
                    diff = values[docIndex];
                    ++docIndex;
                } else {
                    // dimension only in the query
                    diff = this.queryValues[queryIndex];
                    ++queryIndex;
                }
                l2norm += diff * diff;
            }
            // Remaining dimensions exist in only one vector. Square them in double precision, exactly as
            // the merge loop does, so a dimension's contribution does not depend on which loop consumes it.
            for (; queryIndex < this.queryDims.length; ++queryIndex) {
                double diff = this.queryValues[queryIndex];
                l2norm += diff * diff;
            }
            for (; docIndex < docDims.length; ++docIndex) {
                double diff = values[docIndex];
                l2norm += diff * diff;
            }
            return Math.sqrt(l2norm);
        }
    }

    /**
     * Script-facing Manhattan (L1) distance between a sparse query vector and the current document's sparse vector.
     * A dimension present in only one of the two vectors is treated as 0 in the other.
     */
    public static final class L1NormSparse extends SparseVectorFunction {

        public L1NormSparse(ScoreScript scoreScript, Map<String, Number> queryVector, Object docVector) {
            super(scoreScript, queryVector, docVector);
        }

        /**
         * Computes the L1 distance between the query vector and the current document's vector.
         *
         * @return sum of absolute per-dimension differences
         * @throws IllegalArgumentException if the current document has no value for the field
         */
        public double l1normSparse() {
            final BytesRef encoded = getEncodedVector();
            final Version indexVersion = docValues.indexVersion();
            final int[] docDims = VectorEncoderDecoder.decodeSparseVectorDims(indexVersion, encoded);
            final float[] docComponents = VectorEncoderDecoder.decodeSparseVector(indexVersion, encoded);
            final int queryLength = queryDims.length;
            final int docLength = docDims.length;
            int q = 0;
            int d = 0;
            double sum = 0.0;
            // Merge walk over both dimension arrays (assumed sorted ascending).
            while (q < queryLength && d < docLength) {
                final int queryDim = queryDims[q];
                final int docDim = docDims[d];
                if (queryDim < docDim) {
                    sum += Math.abs(queryValues[q++]);
                } else if (queryDim > docDim) {
                    sum += Math.abs(docComponents[d++]);
                } else {
                    sum += Math.abs(queryValues[q++] - docComponents[d++]);
                }
            }
            for (; q < queryLength; q++) {
                sum += Math.abs(queryValues[q]);
            }
            for (; d < docLength; d++) {
                sum += Math.abs(docComponents[d]);
            }
            return sum;
        }
    }

    /**
     * Base class for the sparse-vector script functions. It parses the query vector into parallel
     * dimension/value arrays, sorts them by dimension, and resolves the document's sparse-vector doc values.
     */
    public static class SparseVectorFunction {
        final ScoreScript scoreScript;
        final float[] queryValues;
        final int[] queryDims;
        final VectorScriptDocValues.SparseVectorScriptDocValues docValues;

        /**
         * @param scoreScript the script whose current document is being scored
         * @param queryVector query vector as a map from dimension (a string that must parse as an int) to value.
         *                    NOTE(review): keys such as "1" and "01" parse to the same dimension; duplicates are not rejected here.
         * @param field       the field name ({@link String}), or (deprecated) the field's sparse-vector doc values
         * @throws IllegalArgumentException if a dimension key is not an integer or {@code field} has an unsupported type
         */
        public SparseVectorFunction(ScoreScript scoreScript, Map<String, Number> queryVector, Object field) {
            this.scoreScript = scoreScript;
            final int size = queryVector.size();
            this.queryDims = new int[size];
            this.queryValues = new float[size];
            int slot = 0;
            for (Map.Entry<String, Number> entry : queryVector.entrySet()) {
                queryDims[slot] = parseDimension(entry.getKey());
                queryValues[slot] = entry.getValue().floatValue();
                slot++;
            }
            // Sort the dims and permute the values alongside them; the merge walks in this file assume ascending dims.
            VectorEncoderDecoder.sortSparseDimsFloatValues(queryDims, queryValues, size);
            this.docValues = resolveDocValues(scoreScript, field);
            deprecationLogger.critical(DeprecationCategory.MAPPINGS, "sparse_vector_function",
                "The [sparse_vector] field type is deprecated and will be removed in 8.0.");
        }

        /** Parses a query-vector key into an integer dimension. */
        private static int parseDimension(String key) {
            try {
                return Integer.parseInt(key);
            } catch (NumberFormatException e) {
                throw new IllegalArgumentException("Failed to parse a query vector dimension, it must be an integer!", e);
            }
        }

        /** Looks up the sparse-vector doc values by field name, or accepts them directly (deprecated form). */
        private static VectorScriptDocValues.SparseVectorScriptDocValues resolveDocValues(ScoreScript scoreScript, Object field) {
            if (field instanceof String) {
                return (VectorScriptDocValues.SparseVectorScriptDocValues) scoreScript.getDoc().get((String) field);
            }
            if (field instanceof VectorScriptDocValues.SparseVectorScriptDocValues) {
                final VectorScriptDocValues.SparseVectorScriptDocValues direct = (VectorScriptDocValues.SparseVectorScriptDocValues) field;
                deprecationLogger.critical(DeprecationCategory.SCRIPTING, "vector_function_signature", DEPRECATION_MESSAGE);
                return direct;
            }
            throw new IllegalArgumentException("For vector functions, the 'field' argument must be of type String or VectorScriptDocValues");
        }

        /**
         * Positions the doc values on the script's current document and returns that document's encoded vector.
         *
         * @throws IllegalArgumentException if the current document has no value for the field
         */
        BytesRef getEncodedVector() {
            try {
                docValues.setNextDocId(scoreScript._getDocId());
            } catch (IOException e) {
                throw ExceptionsHelper.convertToElastic(e);
            }
            final BytesRef encoded = docValues.getEncodedValue();
            if (encoded != null) {
                return encoded;
            }
            throw new IllegalArgumentException("A document doesn't have a value for a vector field!");
        }
    }

    /**
     * Script-facing cosine similarity between a dense query vector and the current document's dense vector.
     * The query is normalized to unit length at construction, so only the document magnitude is divided out here.
     */
    public static final class CosineSimilarity extends DenseVectorFunction {

        public CosineSimilarity(ScoreScript scoreScript, List<Number> queryVector, Object field) {
            super(scoreScript, queryVector, field, true);
        }

        /**
         * @return dot product of the normalized query with the document vector, divided by the document's magnitude
         * @throws IllegalArgumentException if the current document has no value for the field
         */
        public double cosineSimilarity() {
            final BytesRef encoded = getEncodedVector();
            // The encoded bytes are read as consecutive floats in ByteBuffer's default (big-endian) order.
            final ByteBuffer buffer = ByteBuffer.wrap(encoded.bytes, encoded.offset, encoded.length);
            double dot = 0.0;
            for (int dim = 0; dim < queryVector.length; dim++) {
                dot += queryVector[dim] * buffer.getFloat();
            }
            return dot / docValues.getMagnitude();
        }
    }

    /**
     * Script-facing dot product between a dense query vector and the current document's dense vector.
     */
    public static final class DotProduct extends DenseVectorFunction {

        public DotProduct(ScoreScript scoreScript, List<Number> queryVector, Object field) {
            super(scoreScript, queryVector, field);
        }

        /**
         * @return sum over dimensions of query component times document component
         * @throws IllegalArgumentException if the current document has no value for the field
         */
        public double dotProduct() {
            final BytesRef encoded = getEncodedVector();
            final ByteBuffer buffer = ByteBuffer.wrap(encoded.bytes, encoded.offset, encoded.length);
            double sum = 0.0;
            for (int dim = 0; dim < queryVector.length; dim++) {
                sum += queryVector[dim] * buffer.getFloat();
            }
            return sum;
        }
    }

    /**
     * Script-facing Euclidean (L2) distance between a dense query vector and the current document's dense vector.
     */
    public static final class L2Norm extends DenseVectorFunction {

        public L2Norm(ScoreScript scoreScript, List<Number> queryVector, Object field) {
            super(scoreScript, queryVector, field);
        }

        /**
         * @return square root of the sum of squared per-dimension differences
         * @throws IllegalArgumentException if the current document has no value for the field
         */
        public double l2norm() {
            final BytesRef encoded = getEncodedVector();
            final ByteBuffer buffer = ByteBuffer.wrap(encoded.bytes, encoded.offset, encoded.length);
            double sumOfSquares = 0.0;
            for (int dim = 0; dim < queryVector.length; dim++) {
                final double delta = queryVector[dim] - buffer.getFloat();
                sumOfSquares += delta * delta;
            }
            return Math.sqrt(sumOfSquares);
        }
    }

    /**
     * Script-facing Manhattan (L1) distance between a dense query vector and the current document's dense vector.
     */
    public static final class L1Norm extends DenseVectorFunction {

        public L1Norm(ScoreScript scoreScript, List<Number> queryVector, Object field) {
            super(scoreScript, queryVector, field);
        }

        /**
         * @return sum of absolute per-dimension differences
         * @throws IllegalArgumentException if the current document has no value for the field
         */
        public double l1norm() {
            final BytesRef encoded = getEncodedVector();
            final ByteBuffer buffer = ByteBuffer.wrap(encoded.bytes, encoded.offset, encoded.length);
            double sum = 0.0;
            for (int dim = 0; dim < queryVector.length; dim++) {
                sum += Math.abs(queryVector[dim] - buffer.getFloat());
            }
            return sum;
        }
    }

    /**
     * Base class for the dense-vector script functions. It resolves the document's dense-vector doc values,
     * checks that the query has the same number of dimensions, and copies (optionally normalizing) the query.
     */
    public static class DenseVectorFunction {
        final ScoreScript scoreScript;
        final float[] queryVector;
        final VectorScriptDocValues.DenseVectorScriptDocValues docValues;

        public DenseVectorFunction(ScoreScript scoreScript, List<Number> queryVector, Object field) {
            this(scoreScript, queryVector, field, false);
        }

        /**
         * @param scoreScript    the script whose current document is being scored
         * @param queryVector    the query vector's components
         * @param field          the field name ({@link String}), or (deprecated) the field's dense-vector doc values
         * @param normalizeQuery whether to divide the query components by the query's Euclidean length.
         *                       NOTE(review): an all-zero query then yields NaN components (0 / 0).
         * @throws IllegalArgumentException if {@code field} has an unsupported type or the dimension counts differ
         */
        public DenseVectorFunction(ScoreScript scoreScript, List<Number> queryVector, Object field, boolean normalizeQuery) {
            this.scoreScript = scoreScript;
            this.docValues = resolveDocValues(scoreScript, field);

            final int queryDimCount = queryVector.size();
            final int docDimCount = docValues.dims();
            if (docDimCount != queryDimCount) {
                throw new IllegalArgumentException("The query vector has a different number of dimensions [" + queryDimCount
                    + "] than the document vectors [" + docDimCount + "].");
            }

            this.queryVector = new float[queryDimCount];
            double sumOfSquares = 0.0;
            int dim = 0;
            for (Number component : queryVector) {
                final float value = component.floatValue();
                this.queryVector[dim++] = value;
                sumOfSquares += value * value;
            }
            final double magnitude = Math.sqrt(sumOfSquares);
            if (normalizeQuery) {
                for (int i = 0; i < this.queryVector.length; i++) {
                    this.queryVector[i] /= magnitude;
                }
            }
        }

        /** Looks up the dense-vector doc values by field name, or accepts them directly (deprecated form). */
        private static VectorScriptDocValues.DenseVectorScriptDocValues resolveDocValues(ScoreScript scoreScript, Object field) {
            if (field instanceof String) {
                return (VectorScriptDocValues.DenseVectorScriptDocValues) scoreScript.getDoc().get((String) field);
            }
            if (field instanceof VectorScriptDocValues.DenseVectorScriptDocValues) {
                final VectorScriptDocValues.DenseVectorScriptDocValues direct = (VectorScriptDocValues.DenseVectorScriptDocValues) field;
                deprecationLogger.critical(DeprecationCategory.SCRIPTING, "vector_function_signature", DEPRECATION_MESSAGE);
                return direct;
            }
            throw new IllegalArgumentException("For vector functions, the 'field' argument must be of type String or VectorScriptDocValues");
        }

        /**
         * Positions the doc values on the script's current document and returns that document's encoded vector.
         *
         * @throws IllegalArgumentException if the current document has no value for the field
         */
        BytesRef getEncodedVector() {
            try {
                docValues.setNextDocId(scoreScript._getDocId());
            } catch (IOException e) {
                throw ExceptionsHelper.convertToElastic(e);
            }
            final BytesRef encoded = docValues.getEncodedValue();
            if (encoded != null) {
                return encoded;
            }
            throw new IllegalArgumentException("A document doesn't have a value for a vector field!");
        }
    }
}

