/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
package org.apache.pulsar.websocket;

import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.ObjectReader;
import com.google.common.base.Enums;
import org.apache.commons.lang3.StringUtils;
import org.apache.pulsar.broker.authentication.AuthenticationDataSource;
import org.apache.pulsar.client.api.*;
import org.apache.pulsar.client.impl.TypedMessageBuilderImpl;
import org.apache.pulsar.common.api.proto.KeyValue;
import org.apache.pulsar.common.naming.TopicName;
import org.apache.pulsar.common.util.ObjectMapperFactory;
import org.apache.pulsar.websocket.data.ProducerAck;
import org.apache.pulsar.websocket.data.ProducerMessage;
import org.apache.pulsar.websocket.service.WSSDummyMessageCryptoImpl;
import org.apache.pulsar.websocket.stats.StatsBuckets;
import org.eclipse.jetty.websocket.api.WebSocketException;
import org.eclipse.jetty.websocket.api.WriteCallback;
import org.eclipse.jetty.websocket.servlet.ServletUpgradeResponse;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.*;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicLongFieldUpdater;
import java.util.concurrent.atomic.LongAdder;
import java.util.stream.Collectors;

import static com.google.common.base.Preconditions.checkArgument;
import static org.apache.pulsar.common.api.EncryptionContext.EncryptionKey;
import static org.apache.pulsar.websocket.WebSocketError.PayloadEncodingError;
import static org.apache.pulsar.websocket.WebSocketError.UnknownError;

/**
 * WebSocket end-point handler that publishes notification messages sent by a client.
 * <p>
 * The client authenticates with a {@code token} query parameter; the token also determines the topic. Every text
 * frame the client sends is published to that topic through the broker, and an acknowledgement (or error) is
 * written back to the client.
 * </p>
 */

public class RelNotifProducerHandler extends AbstractWebSocketHandler implements ProducerHandlerInterface {

    private static final Logger log = LoggerFactory.getLogger(RelNotifProducerHandler.class);

    private static final AtomicLongFieldUpdater<RelNotifProducerHandler> MSG_PUBLISHED_COUNTER_UPDATER =
            AtomicLongFieldUpdater.newUpdater(RelNotifProducerHandler.class, "msgPublishedCounter");

    public static final List<Long> ENTRY_LATENCY_BUCKETS_USEC = Collections.unmodifiableList(Arrays.asList(
            500L, 1_000L, 5_000L, 10_000L, 20_000L, 50_000L, 100_000L, 200_000L, 1000_000L));

    private WebSocketService service;
    private Producer<byte[]> producer;
    private final LongAdder numMsgsSent;
    private final LongAdder numMsgsFailed;
    private final LongAdder numBytesSent;
    private final StatsBuckets publishLatencyStatsUSec;
    private volatile long msgPublishedCounter = 0;
    private boolean clientSideEncrypt;
    private final ObjectReader producerMessageReader =
            ObjectMapperFactory.getMapper().reader().forType(ProducerMessage.class);

    /** Verified request token; set by {@link #verifyToken(HttpServletRequest)} during construction. */
    private RelNotifToken token;
    /** Public key file used to verify the request token. */
    private final String keyFile;
    /** Used only to sniff a message key out of JSON payloads. */
    private final ObjectMapper mapper = new ObjectMapper();

    /**
     * Verifies the request token, resolves the target topic from it and creates the Pulsar producer.
     * <p>
     * If producer creation fails, an HTTP error is sent on {@code response} instead of throwing. Token
     * verification happens before that and is not caught here.
     * </p>
     *
     * @param service  the WebSocket service supplying config and the Pulsar client
     * @param request  the upgrade request; must carry exactly one {@code token} query parameter
     * @param response the upgrade response, used to report producer creation failures
     * @throws IllegalArgumentException if the request does not contain exactly one {@code token} parameter
     */
    public RelNotifProducerHandler(WebSocketService service, HttpServletRequest request, ServletUpgradeResponse response) {
        super(service, request, response);
        this.numMsgsSent = new LongAdder();
        this.numBytesSent = new LongAdder();
        this.numMsgsFailed = new LongAdder();
        this.publishLatencyStatsUSec = new StatsBuckets(ENTRY_LATENCY_BUCKETS_USEC);
        this.service = service;

        // Must be assigned before verifyToken(), which passes it to RelNotifToken.
        this.keyFile = service.getConfig().getWebSocketPublicKeyFile();

        verifyToken(request);
        this.topic = extractNotificationTopicName(request);

        // Access control is done through the token checks in verifyToken(); isAuthorized() therefore
        // always returns true and the generic checkAuth() flow is intentionally not used here.

        try {
            this.producer = getNotificationProducerBuilder(service.getPulsarClient()).topic(topic.toString()).create();
            if (clientSideEncrypt) {
                log.info("[{}] [{}] The producer session is created with param encryptionKeyValues, which means that"
                                + " message encryption will be done on the client side, then the server will skip "
                                + "batch message processing, message compression processing, and message encryption"
                                + " processing", producer.getTopic(), producer.getProducerName());
            }
            if (!this.service.addProducer(this)) {
                log.warn("[{}:{}] Failed to add producer handler for topic {}", request.getRemoteAddr(),
                        request.getRemotePort(), topic);
            }
        } catch (Exception e) {
            int errorCode = getErrorCode(e);
            boolean isKnownError = errorCode != HttpServletResponse.SC_INTERNAL_SERVER_ERROR;
            if (isKnownError) {
                log.warn("[{}:{}] Failed in creating producer on topic {}: {}", request.getRemoteAddr(),
                        request.getRemotePort(), topic, e.getMessage());
            } else {
                log.error("[{}:{}] Failed in creating producer on topic {}: {}", request.getRemoteAddr(),
                        request.getRemotePort(), topic, e.getMessage(), e);
            }

            try {
                response.sendError(errorCode, getErrorMessage(e));
            } catch (IOException e1) {
                log.warn("[{}:{}] Failed to send error: {}", request.getRemoteAddr(), request.getRemotePort(),
                        e1.getMessage(), e1);
            }
        }
    }

    /**
     * Unregisters this handler from the service and closes the producer asynchronously.
     * Close failures are only logged.
     */
    @Override
    public void close() throws IOException {
        if (producer != null) {
            if (!this.service.removeProducer(this)) {
                log.warn("[{}] Failed to remove producer handler", producer.getTopic());
            }
            producer.closeAsync().thenAccept(x -> {
                if (log.isDebugEnabled()) {
                    log.debug("[{}] Closed producer asynchronously", producer.getTopic());
                }
            }).exceptionally(exception -> {
                log.warn("[{}] Failed to close producer", producer.getTopic(), exception);
                return null;
            });
        }
    }

    /**
     * Publishes one client text frame.
     * <p>
     * Frame layout: header lines terminated by an empty line, followed by the payload (see
     * {@link #parseHeaders(String)}). The first header is the request context echoed back in the ack. An optional
     * second header is used as the message key. For binary namespaces the payload is Base64 text and is decoded
     * before publishing. For other namespaces the payload is sent as UTF-8. If no key header is given for a
     * non-binary payload, the key is taken from the JSON field {@code source.id} when present.
     * </p>
     *
     * @param message the raw text frame
     * @throws WebSocketException if the frame carries no context header
     */
    @Override
    public void onWebSocketText(String message) {
        if (log.isDebugEnabled()) {
            log.debug("[{}] Received new message from producer {} ", producer.getTopic(),
                    getRemote().getInetSocketAddress().toString());
        }
        ParseResult parseResult = parseHeaders(message);

        String requestContext = parseResult.contextHeader;
        if (requestContext == null) {
            throw new WebSocketException("Expected context id in message");
        }

        String msg = parseResult.message;
        boolean binaryPayload = RelNotifToken.isBinaryNamespace(token.getTopicName().getNamespacePortion());

        byte[] rawPayload;
        if (binaryPayload) {
            try {
                rawPayload = Base64.getDecoder().decode(msg);
            } catch (IllegalArgumentException e) {
                sendAckResponse(new ProducerAck(PayloadEncodingError, e.getMessage(), null, requestContext));
                return;
            }
        } else {
            rawPayload = msg.getBytes(StandardCharsets.UTF_8);
        }

        final long msgSize = rawPayload.length;
        TypedMessageBuilder<byte[]> builder = producer.newMessage();

        String sourceObjectID = null;
        if (!parseResult.notificationHeaders.isEmpty()) {
            // First (optional) notification header is the message key.
            sourceObjectID = parseResult.notificationHeaders.get(0);
        } else if (!binaryPayload) {
            sourceObjectID = extractSourceObjectIdFromJson(msg);
        }
        if (sourceObjectID != null && !sourceObjectID.isEmpty()) {
            builder.key(sourceObjectID);
        }

        try {
            builder.value(rawPayload);
        } catch (SchemaSerializationException e) {
            sendAckResponse(new ProducerAck(PayloadEncodingError, e.getMessage(), null, requestContext));
            return;
        }

        final long now = System.nanoTime();

        builder.sendAsync().thenAccept(msgId -> {
            if (log.isDebugEnabled()) {
                log.debug("[{}] Success fully write the message to broker with returned message ID {} from producer {}",
                        producer.getTopic(), msgId, getRemote().getInetSocketAddress().toString());
            }
            updateSentMsgStats(msgSize, TimeUnit.NANOSECONDS.toMicros(System.nanoTime() - now));
            if (isConnected()) {
                String messageId = Base64.getEncoder().encodeToString(msgId.toByteArray());
                sendAckResponse(new ProducerAck(messageId, requestContext));
            }
        }).exceptionally(exception -> {
            log.warn("[{}] Error occurred while producer handler was sending msg from {}", producer.getTopic(),
                    getRemote().getInetSocketAddress().toString(), exception);
            numMsgsFailed.increment();
            sendAckResponse(
                    new ProducerAck(UnknownError, exception.getMessage(), null, requestContext));
            return null;
        });
    }

    /**
     * Best-effort lookup of a message key in a JSON payload: the text of {@code source.id}, where {@code source}
     * is the first field with that name found anywhere in the document.
     *
     * @param json the payload text
     * @return the non-empty id text, or {@code null} if absent or the payload cannot be parsed
     */
    private String extractSourceObjectIdFromJson(String json) {
        try {
            JsonNode source = mapper.readTree(json).findValue("source");
            if (source == null) {
                return null;
            }
            JsonNode id = source.get("id");
            if (id == null) {
                return null;
            }
            String idStr = id.asText();
            return (idStr == null || idStr.isEmpty()) ? null : idStr;
        } catch (Exception exn) {
            // Keying is optional: a payload that is not JSON is still published, just without a key.
            log.warn("Failed to extract source object id from message: {}", exn.getMessage());
            return null;
        }
    }

    /** @return the underlying Pulsar producer, or {@code null} if creation failed */
    public Producer<byte[]> getProducer() {
        return this.producer;
    }

    /** @return the number of messages sent since the last call, then resets the counter */
    public long getAndResetNumMsgsSent() {
        return numMsgsSent.sumThenReset();
    }

    /** @return the number of payload bytes sent since the last call, then resets the counter */
    public long getAndResetNumBytesSent() {
        return numBytesSent.sumThenReset();
    }

    /** @return the number of failed sends since the last call, then resets the counter */
    public long getAndResetNumMsgsFailed() {
        return numMsgsFailed.sumThenReset();
    }

    /** Refreshes the latency buckets and returns their current contents. */
    public long[] getAndResetPublishLatencyStatsUSec() {
        publishLatencyStatsUSec.refresh();
        return publishLatencyStatsUSec.getBuckets();
    }

    /** @return the publish-latency statistics, in microseconds */
    public StatsBuckets getPublishLatencyStatsUSec() {
        return this.publishLatencyStatsUSec;
    }

    /** @return the total number of successfully published messages */
    public long getMsgPublishedCounter() {
        return msgPublishedCounter;
    }

    /** Always authorizes: access control is performed by the token checks in the constructor. */
    @Override
    protected Boolean isAuthorized(String authRole, AuthenticationDataSource authenticationData) throws Exception {
        return true;
    }

    /**
     * Writes an ack to the client: the request context alone on success, or the context followed by a newline and
     * the error message otherwise. Send failures are only logged.
     */
    private void sendAckResponse(ProducerAck response) {
        try {
            String msg = response.context;
            if (!"ok".equals(response.result)) {
                msg += "\n" + response.errorMsg;
            }
            getSession().getRemote().sendString(msg, new WriteCallback() {
                @Override
                public void writeFailed(Throwable th) {
                    log.warn("[{}] Failed to send ack: {}", producer.getTopic(), th.getMessage());
                }

                @Override
                public void writeSuccess() {
                    if (log.isDebugEnabled()) {
                        log.debug("[{}] Ack was sent successfully to {}", producer.getTopic(),
                                getRemote().getInetSocketAddress().toString());
                    }
                }
            });
        } catch (Exception e) {
            log.warn("[{}] Failed to send ack: {}", producer.getTopic(), e.getMessage());
        }
    }

    /** Records one successful publish of {@code msgSize} bytes that took {@code latencyUsec} microseconds. */
    private void updateSentMsgStats(long msgSize, long latencyUsec) {
        this.publishLatencyStatsUSec.addValue(latencyUsec);
        this.numBytesSent.add(msgSize);
        this.numMsgsSent.increment();
        MSG_PUBLISHED_COUNTER_UPDATER.getAndIncrement(this);
    }

    /**
     * Builds a producer configured from the generic WebSocket query parameters.
     * <p>
     * Not used by this handler's constructor, which uses {@link #getNotificationProducerBuilder(PulsarClient)}.
     * Kept for subclasses or callers that expect the generic configuration.
     * </p>
     */
    protected ProducerBuilder<byte[]> getProducerBuilder(PulsarClient client) {
        ProducerBuilder<byte[]> builder = client.newProducer()
            .enableBatching(false)
            .messageRoutingMode(MessageRoutingMode.SinglePartition);

        // Set to false to prevent the server thread from being blocked if a lot of messages are pending.
        builder.blockIfQueueFull(false);

        if (queryParams.containsKey("producerName")) {
            builder.producerName(queryParams.get("producerName"));
        }

        if (queryParams.containsKey("initialSequenceId")) {
            builder.initialSequenceId(Long.parseLong(queryParams.get("initialSequenceId")));
        }

        if (queryParams.containsKey("hashingScheme")) {
            builder.hashingScheme(HashingScheme.valueOf(queryParams.get("hashingScheme")));
        }

        if (queryParams.containsKey("sendTimeoutMillis")) {
            builder.sendTimeout(Integer.parseInt(queryParams.get("sendTimeoutMillis")), TimeUnit.MILLISECONDS);
        }

        if (queryParams.containsKey("messageRoutingMode")) {
            checkArgument(
                    Enums.getIfPresent(MessageRoutingMode.class, queryParams.get("messageRoutingMode")).isPresent(),
                    "Invalid messageRoutingMode %s", queryParams.get("messageRoutingMode"));
            MessageRoutingMode routingMode = MessageRoutingMode.valueOf(queryParams.get("messageRoutingMode"));
            if (!MessageRoutingMode.CustomPartition.equals(routingMode)) {
                builder.messageRoutingMode(routingMode);
            }
        }

        Map<String, EncryptionKey> encryptionKeyMap = tryToExtractJsonEncryptionKeys();
        if (encryptionKeyMap != null) {
            popularProducerBuilderForClientSideEncrypt(builder, encryptionKeyMap);
        } else {
            popularProducerBuilderForServerSideEncrypt(builder);
        }
        return builder;
    }

    /**
     * Interprets the {@code encryptionKeys} query parameter as Base64-encoded JSON of key name to
     * {@link EncryptionKey}.
     *
     * @return the parsed keys, or {@code null} when the parameter is absent, is not Base64/JSON, is empty, or its
     *         first key has no key value (in which case it is treated as a server-side key name list)
     */
    private Map<String, EncryptionKey> tryToExtractJsonEncryptionKeys() {
        if (!queryParams.containsKey("encryptionKeys")) {
            return null;
        }
        byte[] param;
        try {
            param = Base64.getDecoder().decode(StringUtils.trim(queryParams.get("encryptionKeys")));
        } catch (Exception base64DecodeEx) {
            // Not Base64: the caller falls back to treating the parameter as plain key names.
            return null;
        }
        try {
            Map<String, EncryptionKey> keys = ObjectMapperFactory.getMapper().getObjectMapper()
                    .readValue(param, new TypeReference<Map<String, EncryptionKey>>() {});
            if (keys.isEmpty()) {
                return null;
            }
            if (keys.values().iterator().next().getKeyValue() == null) {
                return null;
            }
            return keys;
        } catch (IOException ex) {
            return null;
        }
    }

    /**
     * Configures {@code builder} for client-side encryption: batching, compression, chunking and server-side
     * encryption are disabled. The supplied encryption keys are only attached to each message's metadata.
     *
     * @throws IllegalArgumentException if a key name is blank or a key value is missing
     */
    private void popularProducerBuilderForClientSideEncrypt(ProducerBuilder<byte[]> builder,
                                                            Map<String, EncryptionKey> encryptionKeyMap) {
        this.clientSideEncrypt = true;
        int keysLen = encryptionKeyMap.size();
        final String[] keyNameArray = new String[keysLen];
        final byte[][] keyValueArray = new byte[keysLen][];
        @SuppressWarnings("unchecked") // generic array creation; only ever holds List<KeyValue>
        final List<KeyValue>[] keyMetadataArray = new List[keysLen];
        // Format keys.
        int index = 0;
        for (Map.Entry<String, EncryptionKey> entry : encryptionKeyMap.entrySet()) {
            checkArgument(StringUtils.isNotBlank(entry.getKey()), "Empty param encryptionKeys.key");
            checkArgument(entry.getValue() != null, "Empty param encryptionKeys.value");
            checkArgument(entry.getValue().getKeyValue() != null, "Empty param encryptionKeys.value.keyValue");
            keyNameArray[index] = StringUtils.trim(entry.getKey());
            keyValueArray[index] = entry.getValue().getKeyValue();
            if (entry.getValue().getMetadata() == null) {
                keyMetadataArray[index] = Collections.emptyList();
            } else {
                keyMetadataArray[index] = entry.getValue().getMetadata().entrySet().stream()
                        .map(e -> new KeyValue().setKey(e.getKey()).setValue(e.getValue()))
                        .collect(Collectors.toList());
            }
            builder.addEncryptionKey(keyNameArray[index]);
            // Advance to the next slot; without this every key overwrote slot 0 and the rest stayed null.
            index++;
        }
        // Background: The order of message payload process during message sending:
        //  1. The Producer will composite several message payloads into a batched message payload if the producer is
        //    enabled batch;
        //  2. The Producer will compress the batched message payload to a compressed payload if enabled compression;
        //  3. After the previous two steps, the Producer encrypts the compressed payload to an encrypted payload.
        //
        // Since the order of producer operation for message payloads is "compression --> encryption", users need to
        // handle compression themselves if needed. We just disable server-side batch processing, server-side
        // compression and server-side encryption, and only set the message metadata.
        builder.enableBatching(false);
        // Disable server-side compression.
        builder.compressionType(CompressionType.NONE);
        // Disable server-side encryption; encryption attributes are injected into the metadata below.
        builder.cryptoKeyReader(DummyCryptoKeyReaderImpl.INSTANCE);
        // Set `enableChunking` to `false` explicitly (it is the default) to guard against future default changes.
        builder.enableChunking(false);
        // Inject encryption metadata decorator.
        builder.messageCrypto(new WSSDummyMessageCryptoImpl(msgMetadata -> {
            for (int i = 0; i < keyNameArray.length; i++) {
                msgMetadata.addEncryptionKey().setKey(keyNameArray[i]).setValue(keyValueArray[i])
                        .addAllMetadatas(keyMetadataArray[i]);
            }
        }));
        // Warn about parameters that are ignored in this mode.
        printLogIfSettingDiscardedBatchedParams();
        printLogIfSettingDiscardedCompressionParams();
    }

    /**
     * Configures batching, compression and server-side encryption on {@code builder} from the query parameters.
     *
     * @throws IllegalArgumentException if {@code compressionType} is not a valid {@link CompressionType}
     * @throws IllegalStateException    if encryption keys are requested but no crypto key reader is configured
     */
    private void popularProducerBuilderForServerSideEncrypt(ProducerBuilder<byte[]> builder) {
        this.clientSideEncrypt = false;
        if (queryParams.containsKey("batchingEnabled")) {
            boolean batchingEnabled = Boolean.parseBoolean(queryParams.get("batchingEnabled"));
            if (batchingEnabled) {
                builder.enableBatching(true);
                if (queryParams.containsKey("batchingMaxMessages")) {
                    builder.batchingMaxMessages(Integer.parseInt(queryParams.get("batchingMaxMessages")));
                }

                if (queryParams.containsKey("maxPendingMessages")) {
                    builder.maxPendingMessages(Integer.parseInt(queryParams.get("maxPendingMessages")));
                }

                if (queryParams.containsKey("batchingMaxPublishDelay")) {
                    builder.batchingMaxPublishDelay(Integer.parseInt(queryParams.get("batchingMaxPublishDelay")),
                            TimeUnit.MILLISECONDS);
                }
            } else {
                builder.enableBatching(false);
                printLogIfSettingDiscardedBatchedParams();
            }
        }

        if (queryParams.containsKey("compressionType")) {
            checkArgument(Enums.getIfPresent(CompressionType.class, queryParams.get("compressionType")).isPresent(),
                    "Invalid compressionType %s", queryParams.get("compressionType"));
            builder.compressionType(CompressionType.valueOf(queryParams.get("compressionType")));
        }

        if (queryParams.containsKey("encryptionKeys")) {
            builder.cryptoKeyReader(service.getCryptoKeyReader().orElseThrow(() -> new IllegalStateException(
                    "Can't add encryption key without configuring cryptoKeyReaderFactoryClassName")));
            String[] keys = queryParams.get("encryptionKeys").split(",");
            for (String key : keys) {
                builder.addEncryptionKey(key);
            }
        }
    }

    /** Logs an info message for each batching-related query parameter that will be ignored. */
    private void printLogIfSettingDiscardedBatchedParams() {
        if (clientSideEncrypt && queryParams.containsKey("batchingEnabled")) {
            log.info("Since clientSideEncrypt is true, the param batchingEnabled of producer will be ignored");
        }
        if (queryParams.containsKey("batchingMaxMessages")) {
            log.info("Since batchingEnabled is false, the param batchingMaxMessages of producer will be ignored");
        }
        if (queryParams.containsKey("maxPendingMessages")) {
            log.info("Since batchingEnabled is false, the param maxPendingMessages of producer will be ignored");
        }
        if (queryParams.containsKey("batchingMaxPublishDelay")) {
            log.info("Since batchingEnabled is false, the param batchingMaxPublishDelay of producer will be ignored");
        }
    }

    /** Logs an info message if {@code compressionType} is set while client-side encryption is active. */
    private void printLogIfSettingDiscardedCompressionParams() {
        if (clientSideEncrypt && queryParams.containsKey("compressionType")) {
            log.info("Since clientSideEncrypt is true, the param compressionType of producer will be ignored");
        }
    }

    /** No-op: the topic is derived from the token by {@code extractNotificationTopicName} instead. */
    @Override
    protected void extractTopicName(final HttpServletRequest request) {
        // Not needed; the topic is set explicitly from the verified token.
    }

    /** @return the topic named by the verified token, or {@code null} if no token has been verified */
    private TopicName extractNotificationTopicName(HttpServletRequest request) {
        if (token == null) {
            return null;
        }
        return token.getTopicName();
    }

    /**
     * Parses the single {@code token} query parameter and checks that it allows WebSocket access and writing.
     *
     * @throws IllegalArgumentException if the query string does not contain exactly one token
     */
    private void verifyToken(HttpServletRequest request) {
        String[] tokens = request.getParameterMap().get("token");
        if (tokens == null || tokens.length != 1) {
            throw new IllegalArgumentException("query string must contain one and only one token");
        }
        token = new RelNotifToken(tokens[0], keyFile);
        token.verifyWebSocketAllowed();
        token.verifyCanWrite();
    }

    /**
     * Builds the notification producer: batching disabled, single-partition routing, non-blocking when the
     * pending queue is full. Batch and pending-queue limits are applied when the token supplies them.
     */
    private ProducerBuilder<byte[]> getNotificationProducerBuilder(PulsarClient client) {
        ProducerBuilder<byte[]> builder = client.newProducer()
                .enableBatching(false)
                .messageRoutingMode(MessageRoutingMode.SinglePartition);

        // Set to false to prevent the server thread from being blocked if a lot of messages are pending.
        builder.blockIfQueueFull(false);

        token.getBatchingMaxMessages(queryParams).ifPresent(builder::batchingMaxMessages);
        // Previously this value was passed to batchingMaxMessages by mistake.
        token.getMaxPendingMessages(queryParams).ifPresent(builder::maxPendingMessages);

        return builder;
    }

    /** Result of {@link #parseHeaders(String)}. */
    private static class ParseResult {
        /** First header line (request context), or {@code null} if the frame had no header lines. */
        String contextHeader;
        /** Header lines after the context header; never {@code null}. */
        List<String> notificationHeaders = Collections.emptyList();
        /** Everything after the headers (payload). */
        String message;
    }

    /**
     * Splits a frame into header lines and payload. Headers are consumed line by line ({@code '\n'}-terminated)
     * until an empty line or until no further newline exists. The remainder is the payload.
     */
    private ParseResult parseHeaders(String message) {
        ParseResult results = new ParseResult();
        List<String> headers = new ArrayList<>(8);
        // Scan with an offset instead of re-substringing the remaining text on every header.
        int start = 0;
        while (true) {
            int newline = message.indexOf('\n', start);
            if (newline == -1) {
                break;
            }
            String header = message.substring(start, newline);
            start = newline + 1;
            if (header.isEmpty()) {
                break;
            }
            headers.add(header);
        }
        results.message = message.substring(start);
        if (headers.isEmpty()) {
            return results;
        }
        results.contextHeader = headers.get(0);
        results.notificationHeaders = headers.subList(1, headers.size());

        return results;
    }

}
