/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.knn.index.codec.nativeindex;

import java.io.IOException;
import java.util.HashMap;
import java.util.Locale;
import java.util.Map;
import java.util.function.Supplier;
import lombok.Generated;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.lucene.codecs.CodecUtil;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.SegmentWriteState;
import org.apache.lucene.store.IndexOutput;
import org.opensearch.common.Nullable;
import org.opensearch.common.xcontent.XContentHelper;
import org.opensearch.core.common.bytes.BytesArray;
import org.opensearch.core.common.bytes.BytesReference;
import org.opensearch.core.xcontent.DeprecationHandler;
import org.opensearch.core.xcontent.MediaType;
import org.opensearch.core.xcontent.MediaTypeRegistry;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.knn.common.FieldInfoExtractor;
import org.opensearch.knn.index.KNNSettings;
import org.opensearch.knn.index.SpaceType;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.index.codec.nativeindex.NativeIndexBuildStrategy;
import org.opensearch.knn.index.codec.nativeindex.NativeIndexBuildStrategyFactory;
import org.opensearch.knn.index.codec.nativeindex.model.BuildIndexParams;
import org.opensearch.knn.index.codec.util.KNNCodecUtil;
import org.opensearch.knn.index.engine.KNNEngine;
import org.opensearch.knn.index.engine.qframe.QuantizationConfig;
import org.opensearch.knn.index.quantizationservice.QuantizationService;
import org.opensearch.knn.index.store.IndexOutputWithBuffer;
import org.opensearch.knn.index.util.IndexUtil;
import org.opensearch.knn.index.vectorvalues.KNNVectorValues;
import org.opensearch.knn.indices.Model;
import org.opensearch.knn.indices.ModelCache;
import org.opensearch.knn.plugin.stats.KNNGraphValue;
import org.opensearch.knn.quantization.models.quantizationState.QuantizationState;

/**
 * Builds and writes the native (engine-specific) k-NN index file for one vector field of a segment.
 *
 * <p>The writer is used either when a segment is flushed ({@link #flushIndex}) or when segments are
 * merged ({@link #mergeIndex}). The actual index construction is delegated to a
 * {@link NativeIndexBuildStrategy} obtained from the configured {@link NativeIndexBuildStrategyFactory}.
 */
public class NativeIndexWriter {
    @Generated
    private static final Logger log = LogManager.getLogger(NativeIndexWriter.class);
    private final SegmentWriteState state;
    private final FieldInfo fieldInfo;
    private final NativeIndexBuildStrategyFactory indexBuilderFactory;
    /**
     * Quantization state for the field, or {@code null} when none was supplied. When non-null, the
     * vector data type used for transfer is resolved through {@link QuantizationService}.
     */
    @Nullable
    private final QuantizationState quantizationState;

    /**
     * Creates a writer without quantization state, using a default build-strategy factory.
     *
     * @param fieldInfo the vector field being written
     * @param state     the segment write state
     * @return a new writer
     */
    public static NativeIndexWriter getWriter(FieldInfo fieldInfo, SegmentWriteState state) {
        return NativeIndexWriter.createWriter(fieldInfo, state, null, new NativeIndexBuildStrategyFactory());
    }

    /**
     * Creates a writer with an explicit quantization state and build-strategy factory.
     *
     * @param fieldInfo                       the vector field being written
     * @param state                           the segment write state
     * @param quantizationState               quantization state, may be {@code null}
     * @param nativeIndexBuildStrategyFactory factory that selects the index build strategy
     * @return a new writer
     */
    public static NativeIndexWriter getWriter(
        FieldInfo fieldInfo,
        SegmentWriteState state,
        QuantizationState quantizationState,
        NativeIndexBuildStrategyFactory nativeIndexBuildStrategyFactory
    ) {
        return NativeIndexWriter.createWriter(fieldInfo, state, quantizationState, nativeIndexBuildStrategyFactory);
    }

    /**
     * Builds and writes the native index for a flushed segment, then records a refresh operation.
     *
     * @param knnVectorValuesSupplier supplier of fresh vector-value iterators for the field
     * @param totalLiveDocs           number of live documents; no file is written when 0
     * @throws IOException if writing the index file fails
     */
    public void flushIndex(Supplier<KNNVectorValues<?>> knnVectorValuesSupplier, int totalLiveDocs) throws IOException {
        this.buildAndWriteIndex(knnVectorValuesSupplier, totalLiveDocs, true);
        // NOTE: the refresh counter is incremented even when totalLiveDocs == 0 and nothing was written.
        this.recordRefreshStats();
    }

    /**
     * Builds and writes the native index for a merged segment while tracking merge statistics.
     *
     * <p>If the vector values are already exhausted after initialization, the merge is skipped
     * and no statistics are recorded.
     *
     * @param knnVectorValuesSupplier supplier of fresh vector-value iterators for the field
     * @param totalLiveDocs           number of live documents in the merged segment
     * @throws IOException if writing the index file fails
     */
    public void mergeIndex(Supplier<KNNVectorValues<?>> knnVectorValuesSupplier, int totalLiveDocs) throws IOException {
        KNNVectorValues<?> knnVectorValues = knnVectorValuesSupplier.get();
        KNNCodecUtil.initializeVectorValues(knnVectorValues);
        // Integer.MAX_VALUE is presumably Lucene's NO_MORE_DOCS sentinel: the iterator has no vectors left.
        if (knnVectorValues.docId() == Integer.MAX_VALUE) {
            log.debug("Skipping mergeIndex, vector values are already iterated for {}", this.fieldInfo.name);
            return;
        }
        // The size stats describe the whole merge, so scale the per-vector size by the number of docs.
        long mergeSizeInBytes = (long) totalLiveDocs * knnVectorValues.bytesPerVector();
        this.startMergeStats(totalLiveDocs, mergeSizeInBytes);
        try {
            this.buildAndWriteIndex(knnVectorValuesSupplier, totalLiveDocs, false);
        } finally {
            // Always release the "current" gauges; otherwise a failed merge leaves them inflated forever.
            this.endMergeStats(totalLiveDocs, mergeSizeInBytes);
        }
    }

    /**
     * Creates the engine file for this field, delegates index construction to the selected build
     * strategy, and appends the Lucene codec footer.
     *
     * @param knnVectorValuesSupplier supplier of fresh vector-value iterators for the field
     * @param totalLiveDocs           number of live documents; the method is a no-op when 0
     * @param isFlush                 {@code true} for flush, {@code false} for merge
     * @throws IOException if creating or writing the output fails
     */
    private void buildAndWriteIndex(Supplier<KNNVectorValues<?>> knnVectorValuesSupplier, int totalLiveDocs, boolean isFlush)
        throws IOException {
        if (totalLiveDocs == 0) {
            log.debug("No live docs for field {}", this.fieldInfo.name);
            return;
        }
        KNNEngine knnEngine = FieldInfoExtractor.extractKNNEngine(this.fieldInfo);
        String engineFileName = KNNCodecUtil.buildEngineFileName(
            this.state.segmentInfo.name,
            knnEngine.getVersion(),
            this.fieldInfo.name,
            knnEngine.getExtension()
        );
        try (IndexOutput output = this.state.directory.createOutput(engineFileName, this.state.context)) {
            IndexOutputWithBuffer indexOutputWithBuffer = new IndexOutputWithBuffer(output);
            BuildIndexParams nativeIndexParams = this.indexParams(
                this.fieldInfo,
                indexOutputWithBuffer,
                knnEngine,
                knnVectorValuesSupplier,
                totalLiveDocs,
                isFlush
            );
            // A separate iterator instance is handed to the factory for strategy selection.
            NativeIndexBuildStrategy indexBuilder = this.indexBuilderFactory.getBuildStrategy(
                this.fieldInfo,
                totalLiveDocs,
                knnVectorValuesSupplier.get(),
                nativeIndexParams
            );
            indexBuilder.buildAndWriteIndex(nativeIndexParams);
            CodecUtil.writeFooter(output);
        }
    }

    /**
     * Assembles the {@link BuildIndexParams} for this field.
     *
     * <p>Fields carrying a {@code model_id} attribute use template (model) parameters; all others
     * use parameters derived from the field attributes.
     */
    private BuildIndexParams indexParams(
        FieldInfo fieldInfo,
        IndexOutputWithBuffer indexOutputWithBuffer,
        KNNEngine knnEngine,
        Supplier<KNNVectorValues<?>> knnVectorValuesSupplier,
        int totalLiveDocs,
        boolean isFlush
    ) throws IOException {
        VectorDataType vectorDataType = this.quantizationState != null
            ? QuantizationService.getInstance().getVectorDataTypeForTransfer(fieldInfo, this.state.segmentInfo.getVersion())
            : FieldInfoExtractor.extractVectorDataType(fieldInfo);
        Map<String, Object> parameters;
        if (fieldInfo.attributes().containsKey("model_id")) {
            Model model = this.getModel(fieldInfo);
            parameters = this.getTemplateParameters(fieldInfo, model);
        } else {
            parameters = this.getParameters(fieldInfo, vectorDataType, knnEngine);
        }
        return BuildIndexParams.builder()
            .fieldName(fieldInfo.name)
            .parameters(parameters)
            .vectorDataType(vectorDataType)
            .knnEngine(knnEngine)
            .indexOutputWithBuffer(indexOutputWithBuffer)
            .quantizationState(this.quantizationState)
            .knnVectorValuesSupplier(knnVectorValuesSupplier)
            .totalLiveDocs(totalLiveDocs)
            .segmentWriteState(this.state)
            .isFlush(isFlush)
            .build();
    }

    /**
     * Builds the engine parameter map for a field that is not backed by a trained model.
     *
     * <p>If the field has a {@code parameters} attribute, it is parsed with the default media type.
     * Otherwise the map is assembled from the {@code spaceType}, {@code efConstruction} and {@code M}
     * attributes. The data type, thread count and (for FAISS binary fields) the index-description
     * prefix are then added.
     *
     * @throws IOException if the {@code parameters} attribute cannot be parsed
     */
    private Map<String, Object> getParameters(FieldInfo fieldInfo, VectorDataType vectorDataType, KNNEngine knnEngine)
        throws IOException {
        Map<String, Object> parameters = new HashMap<>();
        Map<String, String> fieldAttributes = fieldInfo.attributes();
        String parametersString = fieldAttributes.get("parameters");
        if (parametersString == null) {
            parameters.put("spaceType", fieldAttributes.getOrDefault("spaceType", SpaceType.DEFAULT.getValue()));
            Map<String, Integer> algoParams = new HashMap<>();
            String efConstruction = fieldAttributes.get("efConstruction");
            if (efConstruction != null) {
                algoParams.put("ef_construction", Integer.parseInt(efConstruction));
            }
            String m = fieldAttributes.get("M");
            if (m != null) {
                algoParams.put("m", Integer.parseInt(m));
            }
            parameters.put("parameters", algoParams);
        } else {
            // The parser is Closeable; release it once the map has been read.
            try (
                XContentParser parser = XContentHelper.createParser(
                    NamedXContentRegistry.EMPTY,
                    DeprecationHandler.THROW_UNSUPPORTED_OPERATION,
                    (BytesReference) new BytesArray(parametersString),
                    (MediaType) MediaTypeRegistry.getDefaultMediaType()
                )
            ) {
                parameters.putAll(parser.map());
            }
        }
        parameters.put("data_type", vectorDataType.getValue());
        this.maybeAddBinaryPrefixForFaissBWC(knnEngine, parameters, fieldAttributes);
        parameters.put("indexThreadQty", KNNSettings.getIndexThreadQty());
        return parameters;
    }

    /**
     * For FAISS fields whose {@code data_type} attribute is binary, prefixes the
     * {@code index_description} parameter with {@code "B"} (if not already prefixed) and marks the
     * parameters as binary. All other fields are left untouched.
     */
    private void maybeAddBinaryPrefixForFaissBWC(
        KNNEngine knnEngine,
        Map<String, Object> parameters,
        Map<String, String> fieldAttributes
    ) {
        if (KNNEngine.FAISS != knnEngine) {
            return;
        }
        if (!VectorDataType.BINARY.getValue().equals(fieldAttributes.getOrDefault("data_type", VectorDataType.DEFAULT.getValue()))) {
            return;
        }
        Object indexDescription = parameters.get("index_description");
        if (indexDescription == null) {
            return;
        }
        String description = indexDescription.toString();
        if (description.startsWith("B")) {
            return;
        }
        parameters.put("index_description", "B" + description);
        IndexUtil.updateVectorDataTypeToParameters(parameters, VectorDataType.BINARY);
    }

    /**
     * Builds the engine parameter map for a field backed by a trained model.
     *
     * <p>The vector data type is forced to binary when the field has a non-empty quantization
     * config; otherwise it is taken from the model metadata.
     */
    private Map<String, Object> getTemplateParameters(FieldInfo fieldInfo, Model model) throws IOException {
        Map<String, Object> parameters = new HashMap<>();
        parameters.put("indexThreadQty", KNNSettings.getIndexThreadQty());
        parameters.put("model_id", fieldInfo.attributes().get("model_id"));
        parameters.put("model_blob", model.getModelBlob());
        if (FieldInfoExtractor.extractQuantizationConfig(fieldInfo, this.state.segmentInfo.getVersion()) != QuantizationConfig.EMPTY) {
            IndexUtil.updateVectorDataTypeToParameters(parameters, VectorDataType.BINARY);
        } else {
            IndexUtil.updateVectorDataTypeToParameters(parameters, model.getModelMetadata().getVectorDataType());
        }
        return parameters;
    }

    /**
     * Looks up the trained model referenced by the field's {@code model_id} attribute.
     *
     * @throws RuntimeException if the cache returns no model or the model has no blob
     */
    private Model getModel(FieldInfo fieldInfo) {
        String modelId = fieldInfo.attributes().get("model_id");
        Model model = ModelCache.getInstance().get(modelId);
        // Guard against a null cache result as well as an untrained (blob-less) model.
        if (model == null || model.getModelBlob() == null) {
            throw new RuntimeException(String.format(Locale.ROOT, "There is no trained model with id \"%s\"", modelId));
        }
        return model;
    }

    /** Increments the current and cumulative merge counters for a merge of {@code numDocs} docs. */
    private void startMergeStats(int numDocs, long sizeInBytes) {
        KNNGraphValue.MERGE_CURRENT_OPERATIONS.increment();
        KNNGraphValue.MERGE_CURRENT_DOCS.incrementBy(numDocs);
        KNNGraphValue.MERGE_CURRENT_SIZE_IN_BYTES.incrementBy(sizeInBytes);
        KNNGraphValue.MERGE_TOTAL_OPERATIONS.increment();
        KNNGraphValue.MERGE_TOTAL_DOCS.incrementBy(numDocs);
        KNNGraphValue.MERGE_TOTAL_SIZE_IN_BYTES.incrementBy(sizeInBytes);
    }

    /** Reverts the "current" merge gauges incremented by {@link #startMergeStats}. */
    private void endMergeStats(int numDocs, long sizeInBytes) {
        KNNGraphValue.MERGE_CURRENT_OPERATIONS.decrement();
        KNNGraphValue.MERGE_CURRENT_DOCS.decrementBy(numDocs);
        KNNGraphValue.MERGE_CURRENT_SIZE_IN_BYTES.decrementBy(sizeInBytes);
    }

    /** Increments the cumulative refresh-operation counter. */
    private void recordRefreshStats() {
        KNNGraphValue.REFRESH_TOTAL_OPERATIONS.increment();
    }

    private static NativeIndexWriter createWriter(
        FieldInfo fieldInfo,
        SegmentWriteState state,
        @Nullable QuantizationState quantizationState,
        NativeIndexBuildStrategyFactory nativeIndexBuildStrategyFactory
    ) {
        return new NativeIndexWriter(state, fieldInfo, nativeIndexBuildStrategyFactory, quantizationState);
    }

    @Generated
    public NativeIndexWriter(
        SegmentWriteState state,
        FieldInfo fieldInfo,
        NativeIndexBuildStrategyFactory indexBuilderFactory,
        QuantizationState quantizationState
    ) {
        this.state = state;
        this.fieldInfo = fieldInfo;
        this.indexBuilderFactory = indexBuilderFactory;
        this.quantizationState = quantizationState;
    }
}

