/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.knn.plugin.script;

import java.io.IOException;
import java.math.BigInteger;
import java.util.Locale;
import java.util.Map;
import java.util.Set;
import java.util.function.BiFunction;
import lombok.Generated;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.IndexSearcher;
import org.opensearch.Version;
import org.opensearch.index.mapper.MappedFieldType;
import org.opensearch.knn.index.SpaceType;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil;
import org.opensearch.knn.index.mapper.KNNVectorFieldType;
import org.opensearch.knn.index.query.KNNWeight;
import org.opensearch.knn.plugin.script.KNNScoreScript;
import org.opensearch.knn.plugin.script.KNNScoringSpaceUtil;
import org.opensearch.knn.plugin.script.KNNScoringUtil;
import org.opensearch.script.ScoreScript;
import org.opensearch.search.lookup.SearchLookup;

/**
 * A scoring space used by the k-NN score script. Each implementation converts a user-supplied query
 * into a typed query value plus a scoring function, and builds the {@link ScoreScript} that applies
 * that function to each document's stored value.
 */
public interface KNNScoringSpace {

    /**
     * Creates the score script for this space.
     *
     * @param params   script parameters passed through to the script
     * @param field    name of the field being scored
     * @param lookup   search lookup handed to the script
     * @param ctx      leaf reader context handed to the script
     * @param searcher index searcher handed to the script
     * @return a score script bound to this space's processed query and scoring method
     * @throws IOException if constructing the script fails
     */
    ScoreScript getScoreScript(
        Map<String, Object> params,
        String field,
        SearchLookup lookup,
        LeafReaderContext ctx,
        IndexSearcher searcher
    ) throws IOException;

    /**
     * Hamming-bit space for {@code long} or binary (non-knn_vector) fields.
     * Score is {@code 1 / (1 + calculateHammingBit(query, value))}.
     */
    class HammingBit implements KNNScoringSpace {
        // A Long for long fields or a BigInteger for binary fields; selects which script type is built.
        Object processedQuery;
        // Parameterized to match processedQuery: BiFunction<Long, Long, Float> or BiFunction<BigInteger, BigInteger, Float>.
        BiFunction<?, ?, Float> scoringMethod;

        /**
         * @param query     raw query value, parsed according to the field type
         * @param fieldType mapped type of the target field; must be a long or binary field
         * @throws IllegalArgumentException if the field is neither a long nor a binary field
         */
        public HammingBit(Object query, MappedFieldType fieldType) {
            if (KNNScoringSpaceUtil.isLongFieldType(fieldType)) {
                this.processedQuery = KNNScoringSpaceUtil.parseToLong(query);
                // Explicit parameter types are required: against a BiFunction<?, ?, Float> target an
                // implicitly typed lambda would get Object parameters and miss the typed overload.
                this.scoringMethod = (Long q, Long v) -> 1.0f / (1.0f + KNNScoringUtil.calculateHammingBit(q, v));
            } else if (KNNScoringSpaceUtil.isBinaryFieldType(fieldType)) {
                this.processedQuery = KNNScoringSpaceUtil.parseToBigInteger(query);
                this.scoringMethod = (BigInteger q, BigInteger v) -> 1.0f / (1.0f + KNNScoringUtil.calculateHammingBit(q, v));
            } else {
                throw new IllegalArgumentException("Incompatible field_type for hammingbit space. The field type must of type long or binary.");
            }
        }

        @Override
        @SuppressWarnings("unchecked") // scoringMethod's type parameters are fixed together with processedQuery in the constructor
        public ScoreScript getScoreScript(
            Map<String, Object> params,
            String field,
            SearchLookup lookup,
            LeafReaderContext ctx,
            IndexSearcher searcher
        ) throws IOException {
            if (this.processedQuery instanceof Long) {
                return new KNNScoreScript.LongType(
                    params,
                    (Long) this.processedQuery,
                    field,
                    (BiFunction<Long, Long, Float>) this.scoringMethod,
                    lookup,
                    ctx,
                    searcher
                );
            }
            return new KNNScoreScript.BigIntegerType(
                params,
                (BigInteger) this.processedQuery,
                field,
                (BiFunction<BigInteger, BigInteger, Float>) this.scoringMethod,
                lookup,
                ctx,
                searcher
            );
        }
    }

    /**
     * Hamming space for knn_vector fields whose data type is {@link VectorDataType#BINARY}.
     * Score is {@code 1 / (1 + calculateHammingBit(query, value))} over {@code byte[]} vectors.
     */
    class Hamming extends KNNFieldSpace {
        private static final Set<VectorDataType> DATA_TYPES_HAMMING = Set.of(VectorDataType.BINARY);

        public Hamming(Object query, MappedFieldType fieldType) {
            super(query, fieldType, "hamming", DATA_TYPES_HAMMING);
        }

        @Override
        public BiFunction<?, ?, Float> getScoringMethod(Object processedQuery) {
            // Only BINARY fields are accepted, so the processed query is always a byte[].
            return (byte[] q, byte[] v) -> 1.0f / (1.0f + KNNScoringUtil.calculateHammingBit(q, v));
        }
    }

    /**
     * Inner-product space. Score is {@code KNNWeight.normalizeScore(-innerProduct(query, value))}.
     */
    class InnerProd extends KNNFieldSpace {
        public InnerProd(Object query, MappedFieldType fieldType) {
            super(query, fieldType, "innerproduct");
        }

        @Override
        public BiFunction<?, ?, Float> getScoringMethod(Object processedQuery) {
            if (processedQuery instanceof float[]) {
                return (float[] q, float[] v) -> KNNWeight.normalizeScore(-KNNScoringUtil.innerProduct(q, v));
            }
            return (byte[] q, byte[] v) -> KNNWeight.normalizeScore(-KNNScoringUtil.innerProduct(q, v));
        }
    }

    /**
     * L-infinity space. Score is {@code 1 / (1 + lInfNorm(query, value))}.
     */
    class LInf extends KNNFieldSpace {
        public LInf(Object query, MappedFieldType fieldType) {
            super(query, fieldType, "l-inf");
        }

        @Override
        public BiFunction<?, ?, Float> getScoringMethod(Object processedQuery) {
            if (processedQuery instanceof float[]) {
                return (float[] q, float[] v) -> 1.0f / (1.0f + KNNScoringUtil.lInfNorm(q, v));
            }
            return (byte[] q, byte[] v) -> 1.0f / (1.0f + KNNScoringUtil.lInfNorm(q, v));
        }
    }

    /**
     * L1 space. Score is {@code 1 / (1 + l1Norm(query, value))}.
     */
    class L1 extends KNNFieldSpace {
        public L1(Object query, MappedFieldType fieldType) {
            super(query, fieldType, "l1");
        }

        @Override
        public BiFunction<?, ?, Float> getScoringMethod(Object processedQuery) {
            if (processedQuery instanceof float[]) {
                return (float[] q, float[] v) -> 1.0f / (1.0f + KNNScoringUtil.l1Norm(q, v));
            }
            return (byte[] q, byte[] v) -> 1.0f / (1.0f + KNNScoringUtil.l1Norm(q, v));
        }
    }

    /**
     * Cosine space. The query vector is validated for {@link SpaceType#COSINESIMIL}. The float formula
     * depends on the index-created version (see {@link #getScoringMethod(Object, Version)}).
     */
    class CosineSimilarity extends KNNFieldSpace {
        public CosineSimilarity(Object query, MappedFieldType fieldType) {
            super(query, fieldType, "cosine");
        }

        /** Builds the scoring method as if the index were created on {@link Version#CURRENT}. */
        @Override
        public BiFunction<?, ?, Float> getScoringMethod(Object processedQuery) {
            return getScoringMethod(processedQuery, Version.CURRENT);
        }

        /**
         * Validates the query vector and builds the cosine scoring method.
         * <ul>
         *   <li>float query, index created on/after 2.19.0: {@code max((1 + s) / 2, 0)}</li>
         *   <li>float query, older index: {@code 1 + s}</li>
         *   <li>byte query: {@code 1 + cosinesimil(q, v)}</li>
         * </ul>
         * Here {@code s} is the value returned by {@code cosinesimilOptimized} with the precomputed squared
         * query magnitude.
         */
        @Override
        protected BiFunction<?, ?, Float> getScoringMethod(Object processedQuery, Version indexCreatedVersion) {
            if (processedQuery instanceof float[]) {
                float[] queryVector = (float[]) processedQuery;
                SpaceType.COSINESIMIL.validateVector(queryVector);
                // Computed once here so it is not recomputed for every scored document.
                float qVectorSquaredMagnitude = KNNScoringSpaceUtil.getVectorMagnitudeSquared(queryVector);
                if (indexCreatedVersion.onOrAfter(Version.V_2_19_0)) {
                    return (float[] q, float[] v) -> Math.max(
                        (1.0f + KNNScoringUtil.cosinesimilOptimized(q, v, qVectorSquaredMagnitude)) / 2.0f,
                        0.0f
                    );
                }
                return (float[] q, float[] v) -> 1.0f + KNNScoringUtil.cosinesimilOptimized(q, v, qVectorSquaredMagnitude);
            }
            SpaceType.COSINESIMIL.validateVector((byte[]) processedQuery);
            // NOTE(review): the byte branch is not version-gated like the float branch; confirm this is intended.
            return (byte[] q, byte[] v) -> 1.0f + KNNScoringUtil.cosinesimil(q, v);
        }
    }

    /**
     * L2 space. Score is {@code 1 / (1 + l2Squared(query, value))}.
     */
    class L2 extends KNNFieldSpace {
        public L2(Object query, MappedFieldType fieldType) {
            super(query, fieldType, "l2");
        }

        @Override
        public BiFunction<?, ?, Float> getScoringMethod(Object processedQuery) {
            if (processedQuery instanceof float[]) {
                return (float[] q, float[] v) -> 1.0f / (1.0f + KNNScoringUtil.l2Squared(q, v));
            }
            return (byte[] q, byte[] v) -> 1.0f / (1.0f + KNNScoringUtil.l2Squared(q, v));
        }
    }

    /**
     * Base class for spaces that score knn_vector fields. It validates the field, parses the query into a
     * {@code float[]} or {@code byte[]}, and builds the scoring method once at construction.
     */
    abstract class KNNFieldSpace implements KNNScoringSpace {
        /** Vector data types accepted when a subclass does not supply its own set. */
        public static final Set<VectorDataType> DATA_TYPES_DEFAULT = Set.of(VectorDataType.FLOAT, VectorDataType.BYTE);
        // float[] for FLOAT fields, byte[] otherwise (see getProcessedQuery).
        private final Object processedQuery;
        // Parameterized to match processedQuery's array type.
        private final BiFunction<?, ?, Float> scoringMethod;

        /** Creates a space that accepts {@link #DATA_TYPES_DEFAULT}. */
        public KNNFieldSpace(Object query, MappedFieldType fieldType, String spaceName) {
            this(query, fieldType, spaceName, DATA_TYPES_DEFAULT);
        }

        /**
         * @param query                     raw query value
         * @param fieldType                 mapped type of the target field; must be a knn_vector field
         * @param spaceName                 space name used in error messages
         * @param supportingVectorDataTypes vector data types this space accepts
         * @throws IllegalArgumentException if the field is not knn_vector or its data type is not supported
         */
        public KNNFieldSpace(Object query, MappedFieldType fieldType, String spaceName, Set<VectorDataType> supportingVectorDataTypes) {
            KNNVectorFieldType knnVectorFieldType = toKNNVectorFieldType(fieldType, spaceName, supportingVectorDataTypes);
            this.processedQuery = getProcessedQuery(query, knnVectorFieldType);
            // getScoringMethod is overridable and runs before any subclass constructor body, so
            // implementations must not depend on subclass instance fields.
            this.scoringMethod = getScoringMethod(this.processedQuery, knnVectorFieldType.getKnnMappingConfig().getIndexCreatedVersion());
        }

        /**
         * @throws IllegalStateException if the processed query is neither {@code float[]} nor {@code byte[]}
         */
        @Override
        @SuppressWarnings("unchecked") // scoringMethod is built for the same array type as processedQuery
        public ScoreScript getScoreScript(
            Map<String, Object> params,
            String field,
            SearchLookup lookup,
            LeafReaderContext ctx,
            IndexSearcher searcher
        ) throws IOException {
            if (this.processedQuery instanceof float[]) {
                return new KNNScoreScript.KNNFloatVectorType(
                    params,
                    (float[]) this.processedQuery,
                    field,
                    (BiFunction<float[], float[], Float>) this.scoringMethod,
                    lookup,
                    ctx,
                    searcher
                );
            }
            if (this.processedQuery instanceof byte[]) {
                return new KNNScoreScript.KNNByteVectorType(
                    params,
                    (byte[]) this.processedQuery,
                    field,
                    (BiFunction<byte[], byte[], Float>) this.scoringMethod,
                    lookup,
                    ctx,
                    searcher
                );
            }
            throw new IllegalStateException(
                "Unexpected type for processedQuery. Expected float[] or byte[], but got: " + this.processedQuery.getClass().getName()
            );
        }

        /** Returns the field's vector data type, treating an unset ({@code null}) type as {@link VectorDataType#FLOAT}. */
        private static VectorDataType resolveVectorDataType(KNNVectorFieldType knnVectorFieldType) {
            VectorDataType vectorDataType = knnVectorFieldType.getVectorDataType();
            return vectorDataType == null ? VectorDataType.FLOAT : vectorDataType;
        }

        /** Checks that {@code fieldType} is a knn_vector field with a supported data type and returns it cast. */
        private static KNNVectorFieldType toKNNVectorFieldType(
            MappedFieldType fieldType,
            String spaceName,
            Set<VectorDataType> supportingVectorDataTypes
        ) {
            if (!KNNScoringSpaceUtil.isKNNVectorFieldType(fieldType)) {
                throw new IllegalArgumentException(
                    String.format(Locale.ROOT, "Incompatible field_type for %s space. The field type must be knn_vector.", spaceName)
                );
            }
            KNNVectorFieldType knnVectorFieldType = (KNNVectorFieldType) fieldType;
            VectorDataType vectorDataType = resolveVectorDataType(knnVectorFieldType);
            if (!supportingVectorDataTypes.contains(vectorDataType)) {
                throw new IllegalArgumentException(
                    String.format(
                        Locale.ROOT,
                        "Incompatible field_type for %s space. The data type should be %s but got %s",
                        spaceName,
                        supportingVectorDataTypes,
                        vectorDataType
                    )
                );
            }
            return knnVectorFieldType;
        }

        /**
         * Parses the query into a {@code float[]} for FLOAT (or unset) fields and a {@code byte[]} for all
         * other data types. The expected length comes from the field mapping.
         */
        protected Object getProcessedQuery(Object query, KNNVectorFieldType knnVectorFieldType) {
            int expectedLength = KNNVectorFieldMapperUtil.getExpectedVectorLength(knnVectorFieldType);
            // NOTE(review): the raw (possibly null) data type is forwarded, matching prior behavior; the
            // branch decision above uses the null->FLOAT default.
            if (resolveVectorDataType(knnVectorFieldType) == VectorDataType.FLOAT) {
                return KNNScoringSpaceUtil.parseToFloatArray(query, expectedLength, knnVectorFieldType.getVectorDataType());
            }
            return KNNScoringSpaceUtil.parseToByteArray(query, expectedLength, knnVectorFieldType.getVectorDataType());
        }

        /**
         * Builds the scoring function for the already-processed query.
         *
         * @param processedQuery the {@code float[]} or {@code byte[]} returned by {@link #getProcessedQuery}
         * @return a function of {@code (query, documentValue)} typed to match {@code processedQuery}
         */
        public abstract BiFunction<?, ?, Float> getScoringMethod(Object processedQuery);

        /**
         * Version-aware hook used by the constructor. By default it ignores the version and delegates to
         * {@link #getScoringMethod(Object)}.
         */
        protected BiFunction<?, ?, Float> getScoringMethod(Object processedQuery, Version indexCreatedVersion) {
            return getScoringMethod(processedQuery);
        }

        /** Returns the scoring method built at construction time. */
        @Generated
        public BiFunction<?, ?, Float> getScoringMethod() {
            return this.scoringMethod;
        }
    }
}

