/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.neuralsearch.sparse.query;

import com.google.common.annotations.VisibleForTesting;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;
import lombok.Generated;
import org.apache.commons.collections4.MapUtils;
import org.apache.commons.lang3.builder.EqualsBuilder;
import org.apache.commons.lang3.builder.HashCodeBuilder;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.FieldInfos;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.Query;
import org.opensearch.core.ParseField;
import org.opensearch.core.common.ParsingException;
import org.opensearch.core.common.io.stream.NamedWriteable;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.xcontent.ToXContent;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.index.mapper.MappedFieldType;
import org.opensearch.index.query.AbstractQueryBuilder;
import org.opensearch.index.query.QueryBuilder;
import org.opensearch.index.query.QueryRewriteContext;
import org.opensearch.index.query.QueryShardContext;
import org.opensearch.neuralsearch.sparse.data.SparseVector;
import org.opensearch.neuralsearch.sparse.mapper.SparseVectorFieldType;
import org.opensearch.neuralsearch.sparse.quantization.ByteQuantizationUtil;
import org.opensearch.neuralsearch.sparse.quantization.ByteQuantizer;
import org.opensearch.neuralsearch.sparse.query.SparseQueryContext;
import org.opensearch.neuralsearch.sparse.query.SparseVectorQuery;
import org.opensearch.neuralsearch.stats.events.EventStatName;
import org.opensearch.neuralsearch.stats.events.EventStatsManager;

/**
 * Query builder that turns {@code neural_sparse} method parameters into a {@link SparseVectorQuery} over a
 * {@code sparse_vector} field.
 *
 * <p>Parsed parameters:
 * <ul>
 *   <li>{@code top_n}: the number of highest-weighted query tokens kept. Defaults to 10.</li>
 *   <li>{@code k}: defaults to 10; 0 is also treated as 10.</li>
 *   <li>{@code heap_factor}: defaults to 1.0.</li>
 *   <li>{@code filter}: an optional filter query.</li>
 * </ul>
 * The field name, the query tokens and the fallback query are set programmatically.
 *
 * <p>Only {@code top_n}, {@code k}, {@code heap_factor} and {@code filter} are written to the transport stream and
 * take part in {@code equals}/{@code hashCode}. {@code fieldName}, {@code queryTokens} and {@code fallbackQuery} are
 * not, so they are unset on an instance read from a {@link StreamInput}.
 */
public class SparseAnnQueryBuilder
extends AbstractQueryBuilder<SparseAnnQueryBuilder> {
    @Generated
    private static final Logger log = LogManager.getLogger(SparseAnnQueryBuilder.class);
    public static final String NAME = "neural_sparse";
    @VisibleForTesting
    public static final ParseField TOP_N_FIELD = new ParseField("top_n");
    @VisibleForTesting
    public static final ParseField TOP_K_FIELD = new ParseField("k");
    @VisibleForTesting
    public static final ParseField HEAP_FACTOR_FIELD = new ParseField("heap_factor");
    @VisibleForTesting
    public static final ParseField METHOD_PARAMETERS_FIELD = new ParseField("method_parameters");
    @VisibleForTesting
    public static final ParseField FILTER_FIELD = new ParseField("filter");
    private String fieldName;
    private Integer queryCut;
    private Integer k;
    private Float heapFactor;
    private QueryBuilder filter;
    private Query fallbackQuery;
    // Token id (as a decimal string) -> weight, already normalized by preprocessQueryTokens.
    private Map<String, Float> queryTokens;
    private static final int DEFAULT_TOP_K = 10;
    private static final int DEFAULT_QUERY_CUT = 10;
    private static final float DEFAULT_HEAP_FACTOR = 1.0f;
    // Used when no segment exposes FieldInfo for the field or the lookup fails.
    private static final float DEFAULT_QUANTIZATION_CEIL_SEARCH = 16.0f;

    /**
     * Creates a fully populated builder. {@code queryTokens} are normalized through {@link #preprocessQueryTokens}.
     *
     * @throws IllegalArgumentException if a query token is not a non-negative integer
     */
    public SparseAnnQueryBuilder(String fieldName, Integer queryCut, Integer k, Float heapFactor, QueryBuilder filter, Query fallbackQuery, Map<String, Float> queryTokens) {
        this.fieldName = fieldName;
        this.queryCut = queryCut;
        this.k = k;
        this.heapFactor = heapFactor;
        this.filter = filter;
        this.fallbackQuery = fallbackQuery;
        this.queryTokens = SparseAnnQueryBuilder.preprocessQueryTokens(queryTokens);
    }

    /**
     * Reads a builder from the transport stream. The read order mirrors {@link #doWriteTo}. {@code fieldName},
     * {@code queryTokens} and {@code fallbackQuery} are not on the wire and stay {@code null}.
     */
    public SparseAnnQueryBuilder(StreamInput in) throws IOException {
        super(in);
        this.queryCut = in.readOptionalInt();
        this.k = in.readOptionalInt();
        this.heapFactor = in.readOptionalFloat();
        this.filter = in.readOptionalNamedWriteable(QueryBuilder.class);
    }

    /**
     * Replaces the query tokens after normalizing them through {@link #preprocessQueryTokens}.
     *
     * @return this builder
     * @throws IllegalArgumentException if a query token is not a non-negative integer
     */
    public SparseAnnQueryBuilder queryTokens(Map<String, Float> queryTokens) {
        this.queryTokens = SparseAnnQueryBuilder.preprocessQueryTokens(queryTokens);
        return this;
    }

    /**
     * Parses the {@code method_parameters} object of a {@code neural_sparse} query and increments the
     * {@code SEISMIC_QUERY_REQUESTS} event stat.
     *
     * <p>Scalar fields: {@code top_n} and {@code k} must be positive integers, and {@code heap_factor} must be a
     * positive float. The {@code filter} field is parsed as an inner query.
     *
     * @param parser positioned on the {@code START_OBJECT} token of the parameters object
     * @return the parsed builder
     * @throws ParsingException if the input is not an object, or holds an unknown field or an invalid value
     */
    public static SparseAnnQueryBuilder fromXContent(XContentParser parser) throws IOException {
        EventStatsManager.increment(EventStatName.SEISMIC_QUERY_REQUESTS);
        XContentParser.Token token = parser.currentToken();
        if (token != XContentParser.Token.START_OBJECT) {
            throw new ParsingException(parser.getTokenLocation(), String.format(Locale.ROOT, "[%s] %s must be an object", NAME, METHOD_PARAMETERS_FIELD.getPreferredName()));
        }
        SparseAnnQueryBuilderBuilder builder = SparseAnnQueryBuilder.builder();
        String currentFieldName = "";
        while ((token = parser.nextToken()) != XContentParser.Token.END_OBJECT) {
            if (token == XContentParser.Token.FIELD_NAME) {
                currentFieldName = parser.currentName();
            } else if (token.isValue()) {
                parseScalarParameter(parser, currentFieldName, builder);
            } else if (FILTER_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
                builder.filter(parseInnerQueryBuilder(parser));
            } else {
                throw new ParsingException(parser.getTokenLocation(), String.format(Locale.ROOT, "[%s] unknown token [%s] after [%s]", NAME, token, currentFieldName));
            }
        }
        return builder.build();
    }

    /** Validates the scalar value under {@code fieldName} and stores it in {@code builder}. */
    private static void parseScalarParameter(XContentParser parser, String fieldName, SparseAnnQueryBuilderBuilder builder) throws IOException {
        if (TOP_N_FIELD.match(fieldName, parser.getDeprecationHandler())) {
            int queryCut = parser.intValue();
            if (queryCut <= 0) {
                throw new ParsingException(parser.getTokenLocation(), String.format(Locale.ROOT, "[%s] %s must be a positive integer", NAME, TOP_N_FIELD.getPreferredName()));
            }
            builder.queryCut(queryCut);
        } else if (TOP_K_FIELD.match(fieldName, parser.getDeprecationHandler())) {
            int topK = parser.intValue();
            if (topK <= 0) {
                throw new ParsingException(parser.getTokenLocation(), String.format(Locale.ROOT, "[%s] %s must be a positive integer", NAME, TOP_K_FIELD.getPreferredName()));
            }
            builder.k(topK);
        } else if (HEAP_FACTOR_FIELD.match(fieldName, parser.getDeprecationHandler())) {
            float heapFactor = parser.floatValue();
            // Written as !(x > 0) rather than x <= 0 so that NaN is rejected too.
            if (!(heapFactor > 0.0f)) {
                throw new ParsingException(parser.getTokenLocation(), String.format(Locale.ROOT, "[%s] %s must be a positive float", NAME, HEAP_FACTOR_FIELD.getPreferredName()));
            }
            builder.heapFactor(heapFactor);
        } else {
            throw new ParsingException(parser.getTokenLocation(), String.format(Locale.ROOT, "[%s] unknown field [%s]", NAME, fieldName));
        }
    }

    /** Writes the serializable parameters. The order must match {@link #SparseAnnQueryBuilder(StreamInput)}. */
    @Override
    protected void doWriteTo(StreamOutput out) throws IOException {
        out.writeOptionalInt(this.queryCut);
        out.writeOptionalInt(this.k);
        out.writeOptionalFloat(this.heapFactor);
        out.writeOptionalNamedWriteable((NamedWriteable) this.filter);
    }

    /** Emits each parameter that is set, using the same field names {@link #fromXContent} accepts. */
    @Override
    protected void doXContent(XContentBuilder xContentBuilder, ToXContent.Params params) throws IOException {
        if (Objects.nonNull(this.queryCut)) {
            xContentBuilder.field(TOP_N_FIELD.getPreferredName(), this.queryCut);
        }
        if (Objects.nonNull(this.k)) {
            xContentBuilder.field(TOP_K_FIELD.getPreferredName(), this.k);
        }
        if (Objects.nonNull(this.heapFactor)) {
            xContentBuilder.field(HEAP_FACTOR_FIELD.getPreferredName(), this.heapFactor);
        }
        if (Objects.nonNull(this.filter)) {
            xContentBuilder.field(FILTER_FIELD.getPreferredName(), (ToXContent) this.filter);
        }
    }

    /**
     * Rewrites the filter, if one is set.
     *
     * <p>Returns {@code this} when the filter is absent or unchanged, so the rewrite loop can terminate. Otherwise it
     * returns a copy that carries the rewritten filter and every other field, including the query tokens.
     *
     * @throws UncheckedIOException if rewriting the filter fails with an {@link IOException}
     */
    @Override
    protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) {
        if (this.filter == null) {
            return this;
        }
        QueryBuilder rewrittenFilter;
        try {
            rewrittenFilter = this.filter.rewrite(queryRewriteContext);
        }
        catch (IOException e) {
            throw new UncheckedIOException("Failed to rewrite filter of [" + NAME + "] query", e);
        }
        if (rewrittenFilter == this.filter) {
            return this;
        }
        SparseAnnQueryBuilder rewritten = new SparseAnnQueryBuilder().fieldName(this.fieldName).queryCut(this.queryCut).k(this.k).heapFactor(this.heapFactor).filter(rewrittenFilter).fallbackQuery(this.fallbackQuery);
        // The tokens are already preprocessed, so assign them directly instead of re-running preprocessQueryTokens.
        rewritten.queryTokens = this.queryTokens;
        return rewritten;
    }

    /** Returns the query tokens, or an empty map when none were set (for example after stream deserialization). */
    private Map<String, Float> queryTokensOrEmpty() {
        return this.queryTokens == null ? Collections.emptyMap() : this.queryTokens;
    }

    /**
     * Builds the search context from the current parameters.
     *
     * <p>Keeps the {@code top_n} highest-weighted tokens. Falls back to the defaults for unset parameters, and for
     * {@code k == 0}.
     */
    private SparseQueryContext constructSparseQueryContext() {
        Map<String, Float> tokens = this.queryTokensOrEmpty();
        int tokenLimit = Math.min(this.queryCut == null ? DEFAULT_QUERY_CUT : this.queryCut, tokens.size());
        List<String> topTokens = tokens.entrySet().stream().sorted(Map.Entry.<String, Float>comparingByValue().reversed()).limit(tokenLimit).map(Map.Entry::getKey).toList();
        float effectiveHeapFactor = this.heapFactor == null ? DEFAULT_HEAP_FACTOR : this.heapFactor;
        int effectiveK = this.k == null || this.k == 0 ? DEFAULT_TOP_K : this.k;
        return SparseQueryContext.builder().tokens(topTokens).heapFactor(effectiveHeapFactor).k(effectiveK).build();
    }

    /**
     * Builds a {@link SparseVectorQuery} for {@code fieldName}.
     *
     * <p>The query vector holds all query tokens, quantized with the field's search-time ceiling. The search context
     * holds only the {@code top_n} highest-weighted tokens.
     *
     * @throws IllegalArgumentException if the field is missing or is not a {@code sparse_vector} field
     */
    @Override
    public Query doToQuery(QueryShardContext context) throws IOException {
        MappedFieldType fieldType = context.fieldMapper(this.fieldName);
        SparseAnnQueryBuilder.validateFieldType(fieldType);
        SparseQueryContext sparseQueryContext = this.constructSparseQueryContext();
        float quantizationCeilSearch = SparseAnnQueryBuilder.getQuantizationCeilSearch(context, this.fieldName);
        Query filterQuery = this.filter == null ? null : this.filter.toQuery(context);
        Map<String, Float> tokens = this.queryTokensOrEmpty();
        HashMap<Integer, Float> integerTokens = new HashMap<Integer, Float>();
        for (Map.Entry<String, Float> entry : tokens.entrySet()) {
            integerTokens.put(Integer.parseInt(entry.getKey()), entry.getValue());
        }
        return new SparseVectorQuery.SparseVectorQueryBuilder().fieldName(this.fieldName).queryContext(sparseQueryContext).queryVector(new SparseVector(integerTokens, new ByteQuantizer(quantizationCeilSearch))).fallbackQuery(this.fallbackQuery).filter(filterQuery).build();
    }

    /**
     * Checks that the target field exists and is a sparse-vector field.
     *
     * @throws IllegalArgumentException if {@code fieldType} is null or not a sparse-vector type
     */
    public static void validateFieldType(MappedFieldType fieldType) {
        if (Objects.isNull(fieldType) || !SparseVectorFieldType.isSparseVectorType(fieldType.typeName())) {
            throw new IllegalArgumentException("[neural_sparse] query with [" + METHOD_PARAMETERS_FIELD.getPreferredName() + "] only works on [sparse_vector] fields");
        }
    }

    /** Compares only the serialized parameters: {@code top_n}, {@code heap_factor}, {@code k} and {@code filter}. */
    @Override
    protected boolean doEquals(SparseAnnQueryBuilder obj) {
        if (this == obj) {
            return true;
        }
        if (obj == null || this.getClass() != obj.getClass()) {
            return false;
        }
        return new EqualsBuilder().append(this.queryCut, obj.queryCut).append(this.heapFactor, obj.heapFactor).append(this.k, obj.k).append(this.filter, obj.filter).isEquals();
    }

    /** Hashes the same fields that {@link #doEquals} compares. */
    @Override
    protected int doHashCode() {
        return new HashCodeBuilder().append(this.queryCut).append(this.heapFactor).append(this.k).append(this.filter).toHashCode();
    }

    @Override
    public String getWriteableName() {
        return NAME;
    }

    /**
     * Normalizes raw query tokens.
     *
     * <p>Each key is parsed as a non-negative integer and mapped through
     * {@link SparseVector#prepareTokenForShortType(int)}. Tokens that map to the same id keep the larger weight.
     *
     * @param tokens token id (decimal string) to weight; may be null or empty
     * @return the normalized token id (decimal string) to weight, or an empty immutable map for null/empty input
     * @throws IllegalArgumentException if a key is not an integer, is negative, or has a null weight
     */
    private static Map<String, Float> preprocessQueryTokens(Map<String, Float> tokens) {
        if (MapUtils.isEmpty(tokens)) {
            return Collections.emptyMap();
        }
        Map<Integer, Float> intTokens = new HashMap<>();
        try {
            for (Map.Entry<String, Float> entry : tokens.entrySet()) {
                int token = Integer.parseInt(entry.getKey());
                if (token < 0) {
                    throw new IllegalArgumentException("Query tokens should be non-negative integer!");
                }
                Float weight = entry.getValue();
                if (weight == null) {
                    throw new IllegalArgumentException("Query token weights should not be null");
                }
                intTokens.merge(SparseVector.prepareTokenForShortType(token), weight, Float::max);
            }
        }
        catch (NumberFormatException ex) {
            throw new IllegalArgumentException("Query tokens should be valid integer", ex);
        }
        return intTokens.entrySet().stream().collect(Collectors.toMap(e -> String.valueOf(e.getKey()), Map.Entry::getValue));
    }

    /**
     * Returns the search-time quantization ceiling for {@code fieldName}, taken from the first leaf that has
     * {@link FieldInfo} for the field.
     *
     * <p>This is best-effort: any lookup failure is logged and the default ceiling is returned.
     */
    private static float getQuantizationCeilSearch(QueryShardContext context, String fieldName) {
        float quantizationCeilSearch = DEFAULT_QUANTIZATION_CEIL_SEARCH;
        try {
            for (LeafReaderContext leafContext : context.searcher().getTopReaderContext().leaves()) {
                FieldInfos fieldInfos = leafContext.reader().getFieldInfos();
                FieldInfo fieldInfo = fieldInfos.fieldInfo(fieldName);
                if (fieldInfo != null) {
                    quantizationCeilSearch = ByteQuantizationUtil.getCeilingValueSearch(fieldInfo);
                    break;
                }
            }
        }
        catch (Exception e) {
            log.error("Failed to get quantization ceiling search value for field [{}]", fieldName, e);
        }
        return quantizationCeilSearch;
    }

    @Generated
    public static SparseAnnQueryBuilderBuilder builder() {
        return new SparseAnnQueryBuilderBuilder();
    }

    @Generated
    public String fieldName() {
        return this.fieldName;
    }

    @Generated
    public Integer queryCut() {
        return this.queryCut;
    }

    @Generated
    public Integer k() {
        return this.k;
    }

    @Generated
    public Float heapFactor() {
        return this.heapFactor;
    }

    @Generated
    public QueryBuilder filter() {
        return this.filter;
    }

    @Generated
    public Query fallbackQuery() {
        return this.fallbackQuery;
    }

    @Generated
    public Map<String, Float> queryTokens() {
        return this.queryTokens;
    }

    @Generated
    public SparseAnnQueryBuilder fieldName(String fieldName) {
        this.fieldName = fieldName;
        return this;
    }

    @Generated
    public SparseAnnQueryBuilder queryCut(Integer queryCut) {
        this.queryCut = queryCut;
        return this;
    }

    @Generated
    public SparseAnnQueryBuilder k(Integer k) {
        this.k = k;
        return this;
    }

    @Generated
    public SparseAnnQueryBuilder heapFactor(Float heapFactor) {
        this.heapFactor = heapFactor;
        return this;
    }

    @Generated
    public SparseAnnQueryBuilder filter(QueryBuilder filter) {
        this.filter = filter;
        return this;
    }

    @Generated
    public SparseAnnQueryBuilder fallbackQuery(Query fallbackQuery) {
        this.fallbackQuery = fallbackQuery;
        return this;
    }

    /** Creates an empty builder. All fields, including {@code queryTokens}, start out {@code null}. */
    @Generated
    public SparseAnnQueryBuilder() {
    }

    /**
     * Fluent builder for {@link SparseAnnQueryBuilder}. {@link #queryTokens(Map)} normalizes tokens eagerly, and
     * {@link #build()} normalizes them again through the all-args constructor.
     */
    public static class SparseAnnQueryBuilderBuilder {
        @Generated
        private String fieldName;
        @Generated
        private Integer queryCut;
        @Generated
        private Integer k;
        @Generated
        private Float heapFactor;
        @Generated
        private QueryBuilder filter;
        @Generated
        private Query fallbackQuery;
        @Generated
        private Map<String, Float> queryTokens;

        /**
         * Sets the query tokens after normalizing them.
         *
         * @throws IllegalArgumentException if a query token is not a non-negative integer
         */
        public SparseAnnQueryBuilderBuilder queryTokens(Map<String, Float> queryTokens) {
            this.queryTokens = SparseAnnQueryBuilder.preprocessQueryTokens(queryTokens);
            return this;
        }

        @Generated
        SparseAnnQueryBuilderBuilder() {
        }

        @Generated
        public SparseAnnQueryBuilderBuilder fieldName(String fieldName) {
            this.fieldName = fieldName;
            return this;
        }

        @Generated
        public SparseAnnQueryBuilderBuilder queryCut(Integer queryCut) {
            this.queryCut = queryCut;
            return this;
        }

        @Generated
        public SparseAnnQueryBuilderBuilder k(Integer k) {
            this.k = k;
            return this;
        }

        @Generated
        public SparseAnnQueryBuilderBuilder heapFactor(Float heapFactor) {
            this.heapFactor = heapFactor;
            return this;
        }

        @Generated
        public SparseAnnQueryBuilderBuilder filter(QueryBuilder filter) {
            this.filter = filter;
            return this;
        }

        @Generated
        public SparseAnnQueryBuilderBuilder fallbackQuery(Query fallbackQuery) {
            this.fallbackQuery = fallbackQuery;
            return this;
        }

        @Generated
        public SparseAnnQueryBuilder build() {
            return new SparseAnnQueryBuilder(this.fieldName, this.queryCut, this.k, this.heapFactor, this.filter, this.fallbackQuery, this.queryTokens);
        }

        @Generated
        public String toString() {
            return "SparseAnnQueryBuilder.SparseAnnQueryBuilderBuilder(fieldName=" + this.fieldName + ", queryCut=" + this.queryCut + ", k=" + this.k + ", heapFactor=" + this.heapFactor + ", filter=" + String.valueOf(this.filter) + ", fallbackQuery=" + String.valueOf(this.fallbackQuery) + ", queryTokens=" + String.valueOf(this.queryTokens) + ")";
        }
    }
}

