/*
 * Decompiled with CFR 0.152.
 */
package org.springframework.ai.chat.model;

import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Consumer;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.chat.messages.AssistantMessage;
import org.springframework.ai.chat.metadata.ChatGenerationMetadata;
import org.springframework.ai.chat.metadata.ChatResponseMetadata;
import org.springframework.ai.chat.metadata.EmptyRateLimit;
import org.springframework.ai.chat.metadata.PromptMetadata;
import org.springframework.ai.chat.metadata.RateLimit;
import org.springframework.ai.chat.metadata.Usage;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.model.Generation;
import org.springframework.util.CollectionUtils;
import org.springframework.util.StringUtils;
import reactor.core.publisher.Flux;

public class MessageAggregator {
    private static final Logger logger = LoggerFactory.getLogger(MessageAggregator.class);

    public Flux<ChatResponse> aggregate(Flux<ChatResponse> fluxChatResponse, Consumer<ChatResponse> onAggregationComplete) {
        AtomicReference<StringBuilder> messageTextContentRef = new AtomicReference<StringBuilder>(new StringBuilder());
        AtomicReference messageMetadataMapRef = new AtomicReference();
        AtomicReference toolCallsRef = new AtomicReference(new ArrayList());
        AtomicReference<ChatGenerationMetadata> generationMetadataRef = new AtomicReference<ChatGenerationMetadata>(ChatGenerationMetadata.NULL);
        AtomicReference<Integer> metadataUsagePromptTokensRef = new AtomicReference<Integer>(0);
        AtomicReference<Integer> metadataUsageGenerationTokensRef = new AtomicReference<Integer>(0);
        AtomicReference<Integer> metadataUsageTotalTokensRef = new AtomicReference<Integer>(0);
        AtomicReference<PromptMetadata> metadataPromptMetadataRef = new AtomicReference<PromptMetadata>(PromptMetadata.empty());
        AtomicReference<EmptyRateLimit> metadataRateLimitRef = new AtomicReference<EmptyRateLimit>(new EmptyRateLimit());
        AtomicReference<String> metadataIdRef = new AtomicReference<String>("");
        AtomicReference<String> metadataModelRef = new AtomicReference<String>("");
        return fluxChatResponse.doOnSubscribe(subscription -> {
            messageTextContentRef.set(new StringBuilder());
            messageMetadataMapRef.set(new HashMap());
            toolCallsRef.set(new ArrayList());
            metadataIdRef.set("");
            metadataModelRef.set("");
            metadataUsagePromptTokensRef.set(0);
            metadataUsageGenerationTokensRef.set(0);
            metadataUsageTotalTokensRef.set(0);
            metadataPromptMetadataRef.set(PromptMetadata.empty());
            metadataRateLimitRef.set(new EmptyRateLimit());
        }).doOnNext(chatResponse -> {
            if (chatResponse.getResult() != null) {
                AssistantMessage outputMessage;
                if (chatResponse.getResult().getMetadata() != null && chatResponse.getResult().getMetadata() != ChatGenerationMetadata.NULL) {
                    generationMetadataRef.set(chatResponse.getResult().getMetadata());
                }
                if (chatResponse.getResult().getOutput().getText() != null) {
                    ((StringBuilder)messageTextContentRef.get()).append(chatResponse.getResult().getOutput().getText());
                }
                if (chatResponse.getResult().getOutput().getMetadata() != null) {
                    ((Map)messageMetadataMapRef.get()).putAll(chatResponse.getResult().getOutput().getMetadata());
                }
                if (!CollectionUtils.isEmpty((outputMessage = chatResponse.getResult().getOutput()).getToolCalls())) {
                    ((List)toolCallsRef.get()).addAll(outputMessage.getToolCalls());
                }
            }
            if (chatResponse.getMetadata() != null) {
                Object toolCallsFromMetadata;
                if (chatResponse.getMetadata().getUsage() != null) {
                    Usage usage = chatResponse.getMetadata().getUsage();
                    metadataUsagePromptTokensRef.set(usage.getPromptTokens() > 0 ? usage.getPromptTokens() : (Integer)metadataUsagePromptTokensRef.get());
                    metadataUsageGenerationTokensRef.set(usage.getCompletionTokens() > 0 ? usage.getCompletionTokens() : (Integer)metadataUsageGenerationTokensRef.get());
                    metadataUsageTotalTokensRef.set(usage.getTotalTokens() > 0 ? usage.getTotalTokens() : (Integer)metadataUsageTotalTokensRef.get());
                }
                if (chatResponse.getMetadata().getPromptMetadata() != null && chatResponse.getMetadata().getPromptMetadata().iterator().hasNext()) {
                    metadataPromptMetadataRef.set(chatResponse.getMetadata().getPromptMetadata());
                }
                if (chatResponse.getMetadata().getRateLimit() != null && !(metadataRateLimitRef.get() instanceof EmptyRateLimit)) {
                    metadataRateLimitRef.set((EmptyRateLimit)chatResponse.getMetadata().getRateLimit());
                }
                if (StringUtils.hasText((String)chatResponse.getMetadata().getId())) {
                    metadataIdRef.set(chatResponse.getMetadata().getId());
                }
                if (StringUtils.hasText((String)chatResponse.getMetadata().getModel())) {
                    metadataModelRef.set(chatResponse.getMetadata().getModel());
                }
                if ((toolCallsFromMetadata = chatResponse.getMetadata().get("toolCalls")) instanceof List) {
                    List toolCallsList = (List)toolCallsFromMetadata;
                    ((List)toolCallsRef.get()).addAll(toolCallsList);
                }
            }
        }).doOnComplete(() -> {
            DefaultUsage usage = new DefaultUsage((Integer)metadataUsagePromptTokensRef.get(), (Integer)metadataUsageGenerationTokensRef.get(), (Integer)metadataUsageTotalTokensRef.get());
            ChatResponseMetadata chatResponseMetadata = ChatResponseMetadata.builder().id((String)metadataIdRef.get()).model((String)metadataModelRef.get()).rateLimit((RateLimit)metadataRateLimitRef.get()).usage(usage).promptMetadata((PromptMetadata)metadataPromptMetadataRef.get()).build();
            List collectedToolCalls = (List)toolCallsRef.get();
            AssistantMessage finalAssistantMessage = !CollectionUtils.isEmpty((Collection)collectedToolCalls) ? new AssistantMessage(((StringBuilder)messageTextContentRef.get()).toString(), (Map)messageMetadataMapRef.get(), collectedToolCalls) : new AssistantMessage(((StringBuilder)messageTextContentRef.get()).toString(), (Map)messageMetadataMapRef.get());
            onAggregationComplete.accept(new ChatResponse(List.of(new Generation(finalAssistantMessage, (ChatGenerationMetadata)generationMetadataRef.get())), chatResponseMetadata));
            messageTextContentRef.set(new StringBuilder());
            messageMetadataMapRef.set(new HashMap());
            toolCallsRef.set(new ArrayList());
            metadataIdRef.set("");
            metadataModelRef.set("");
            metadataUsagePromptTokensRef.set(0);
            metadataUsageGenerationTokensRef.set(0);
            metadataUsageTotalTokensRef.set(0);
            metadataPromptMetadataRef.set(PromptMetadata.empty());
            metadataRateLimitRef.set(new EmptyRateLimit());
        }).doOnError(e -> logger.error("Aggregation Error", e));
    }

    public record DefaultUsage(Integer promptTokens, Integer completionTokens, Integer totalTokens) implements Usage
    {
        @Override
        public Integer getPromptTokens() {
            return this.promptTokens();
        }

        @Override
        public Integer getCompletionTokens() {
            return this.completionTokens();
        }

        @Override
        public Integer getTotalTokens() {
            return this.totalTokens();
        }

        @Override
        public Map<String, Integer> getNativeUsage() {
            HashMap<String, Integer> usage = new HashMap<String, Integer>();
            usage.put("promptTokens", this.promptTokens());
            usage.put("completionTokens", this.completionTokens());
            usage.put("totalTokens", this.totalTokens());
            return usage;
        }
    }
}

