/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.knn.training;

import java.util.ArrayList;
import java.util.List;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.knn.index.memory.NativeMemoryAllocation;
import org.opensearch.knn.jni.JNICommons;
import org.opensearch.knn.training.TrainingDataConsumer;
import org.opensearch.search.SearchHit;

/**
 * {@link TrainingDataConsumer} for byte vectors.
 *
 * <p>Pulls list-valued fields out of search hits, narrows each element to a {@code byte},
 * and passes the resulting vectors to {@link JNICommons#storeByteVectorData}. The memory
 * address returned by that call is written back to the training data allocation.
 */
public class ByteTrainingDataConsumer extends TrainingDataConsumer {
    private static final Logger logger = LogManager.getLogger(ByteTrainingDataConsumer.class);

    /**
     * Creates a consumer backed by the given training data allocation.
     *
     * @param trainingDataAllocation allocation whose memory address is read and updated by
     *     {@link #accept(List)}
     */
    public ByteTrainingDataConsumer(NativeMemoryAllocation.TrainingDataAllocation trainingDataAllocation) {
        super(trainingDataAllocation);
    }

    /**
     * Stores a batch of byte vectors through {@link JNICommons#storeByteVectorData}.
     * The memory address that call returns is recorded on the training data allocation.
     *
     * @param byteVectors batch of vectors; every element must be a {@code byte[]}
     * @throws ArrayStoreException if an element of {@code byteVectors} is not a {@code byte[]}
     */
    @Override
    public void accept(List<?> byteVectors) {
        long memoryAddress = this.trainingDataAllocation.getMemoryAddress();
        // Seeding toArray with a byte[][] yields a byte[][] directly, so no unchecked cast is needed.
        byte[][] vectorArray = byteVectors.toArray(new byte[0][]);
        memoryAddress = JNICommons.storeByteVectorData(memoryAddress, vectorArray, byteVectors.size());
        this.trainingDataAllocation.setMemoryAddress(memoryAddress);
    }

    /**
     * Extracts byte vectors from the first {@code vectorsToAdd} hits of a search response
     * and passes them to {@link #accept(List)}.
     *
     * <p>Hits whose field value is {@code null} are skipped and counted, and a warning is
     * logged. Hits whose field value is not a {@link List} are also skipped and reported.
     * The number of vectors actually collected is added to the total-vectors counter.
     *
     * @param searchResponse response whose hits carry the training vectors
     * @param vectorsToAdd maximum number of hits to read; clamped to {@code [0, hits.length]}
     * @param fieldName dot-separated path of the vector field within each hit
     */
    @Override
    public void processTrainingVectors(SearchResponse searchResponse, int vectorsToAdd, String fieldName) {
        SearchHit[] hits = searchResponse.getHits().getHits();
        // Never index past the hits actually returned, even if the caller asks for more.
        int hitsToProcess = Math.min(Math.max(vectorsToAdd, 0), hits.length);
        List<byte[]> vectors = new ArrayList<>(hitsToProcess);
        String[] fieldPath = fieldName.split("\\.");
        int nullVectorCount = 0;
        int nonListValueCount = 0;
        for (int hitIndex = 0; hitIndex < hitsToProcess; ++hitIndex) {
            Object fieldValue = this.extractFieldValue(hits[hitIndex], fieldPath);
            if (fieldValue == null) {
                ++nullVectorCount;
                continue;
            }
            if (!(fieldValue instanceof List)) {
                // A non-list value cannot be converted to a vector; skip it but keep a count.
                ++nonListValueCount;
                continue;
            }
            vectors.add(toByteArray((List<?>) fieldValue));
        }
        if (nullVectorCount > 0) {
            logger.warn("Found {} documents with null byte vectors in field {}", nullVectorCount, fieldName);
        }
        if (nonListValueCount > 0) {
            logger.warn("Skipped {} documents whose value in field {} is not a list", nonListValueCount, fieldName);
        }
        this.setTotalVectorsCountAdded(this.getTotalVectorsCountAdded() + vectors.size());
        this.accept(vectors);
    }

    /** Converts a list of numbers into a byte array, one element per entry. */
    private static byte[] toByteArray(List<?> values) {
        byte[] bytes = new byte[values.size()];
        for (int i = 0; i < bytes.length; ++i) {
            // Number.byteValue() is a narrowing conversion: out-of-range values are not rejected.
            bytes[i] = ((Number) values.get(i)).byteValue();
        }
        return bytes;
    }
}

