/*******************************************************************************
 * Copyright (c) 2014 IBM Corp.
 *
 * All rights reserved. This program and the accompanying materials
 * are made available under the terms of the Eclipse Public License v1.0
 * and Eclipse Distribution License v1.0 which accompany this distribution.
 *
 * The Eclipse Public License is available at
 *    http://www.eclipse.org/legal/epl-v10.html
 * and the Eclipse Distribution License is available at
 *   http://www.eclipse.org/org/documents/edl-v10.php.
 *
 * Contributors:
 *    Allan Stockdill-Mander/Ian Craggs - initial API and implementation and/or initial documentation
 *******************************************************************************/

#include "MQTTClient.h"
#include "os_port.h"
#include "mbedtls/ssl.h"
//extern MQTT_LOGGER mqtt_logger;
//static int reconnect = 0;
/* Re-arm interval for forced PINGREQs, used once a regular keep-alive
 * (40, set in MQTTConnect) expired without a PINGRESP. Passed to countdown(),
 * presumably in seconds like keepAliveInterval.
 * NOTE(review): an earlier comment said 10 s, but the value is 5 - confirm. */
#define FORCE_KEEP_ALIVE_INTERVAL 5
/* Upper bound on forced PINGREQs before the connection is treated as dead
 * (also forced past immediately when the server closes the connection). */
#define MAXIMUM_FORCE_KEEP_ALIVE_TIMES 4

/* Printable names of the MQTT control packet types, indexed by
 * (packet type - 1); used for log output such as in waitfor(). */
char msgTypeMessages[][32] = {
    "CONNECT",      /*  1 */
    "CONNACK",      /*  2 */
    "PUBLISH",      /*  3 */
    "PUBACK",       /*  4 */
    "PUBREC",       /*  5 */
    "PUBREL",       /*  6 */
    "PUBCOMP",      /*  7 */
    "SUBSCRIBE",    /*  8 */
    "SUBACK",       /*  9 */
    "UNSUBSCRIBE",  /* 10 */
    "UNSUBACK",     /* 11 */
    "PINGREQ",      /* 12 */
    "PINGRESP",     /* 13 */
    "DISCONNECT"    /* 14 */
};

/* Fill a MessageData wrapper for a handler callback. Both pointers are stored
 * as given; nothing is copied, so they must outlive the callback. */
void NewMessageData(MessageData* md, MQTTString* aTopicName, MQTTMessage* aMessgage)
{
    md->message = aMessgage;
    md->topicName = aTopicName;
}


/* Advance the client's packet identifier and return the new value.
 * On reaching MAX_PACKET_ID the counter wraps to 1, never to 0. */
int getNextPacketId(Client *c)
{
    if (c->next_packetid == MAX_PACKET_ID)
        c->next_packetid = 1;
    else
        c->next_packetid++;

    return c->next_packetid;
}


/* Write the first `length` bytes of c->buf to the network.
 *
 * The write is retried until every byte has gone out, the transport reports
 * an error (negative return), or `timer` expires. If the timer has already
 * expired on entry, it is re-armed for 1000 ms so the packet still gets one
 * chance to be sent.
 *
 * Returns SUCCESS when all `length` bytes were written, FAILURE otherwise.
 * The transport's own error code is not propagated to the caller.
 */
int sendPacket(Client* c, int length, Timer* timer)
{
    int rc = FAILURE;
    int sent = 0;

    if (expired(timer))
    {
        hlog_info("timer expired, need extend to send\n");
        countdown_ms(timer, 1000);
    }
    while (sent < length && !expired(timer))
    {
        /* Request only the bytes still unsent. After a partial write, asking for
         * `length` bytes starting at &buf[sent] would read past the packet. */
        rc = c->ipstack->mqttwrite(c->ipstack, &c->buf[sent], length - sent, left_ms(timer));
        if (rc < 0)  // there was an error writing the data
            break;
        sent += rc;
    }

    /* The PING timer is deliberately not reset here. keepalive() re-arms it
     * only after a PINGREQ has been sent successfully. */
    rc = (sent == length) ? SUCCESS : FAILURE;

    hlog_info("sendPacket Done with return code = %d\n",rc);
    return rc;
}


/* Initialise a Client structure before its first use.
 *
 * Binds the network transport and the caller-owned send/receive buffers
 * (they are not copied), clears every message-handler slot, and puts the
 * keep-alive and packet-id state into a known starting point. Every field the
 * rest of this file reads before assigning it is initialised here, so a
 * Client living on the stack does not start with garbage.
 */
void MQTTClient(Client* c, Network* network, unsigned int command_timeout_ms, unsigned char* buf, size_t buf_size, unsigned char* readbuf, size_t readbuf_size)
{
    int i;
    c->ipstack = network;

    for (i = 0; i < MAX_MESSAGE_HANDLERS; ++i)
    {
        c->messageHandlers[i].topicFilter = 0;
        c->messageHandlers[i].fp = NULL;
    }
    c->command_timeout_ms = command_timeout_ms;
    c->buf = buf;
    c->buf_size = buf_size;
    c->readbuf = readbuf;
    c->readbuf_size = readbuf_size;
    c->isconnected = 0;
    c->ping_outstanding = 0;
    /* Keep-alive stays disabled (keepalive() returns early on 0) until
     * MQTTConnect() sets the real interval. */
    c->keepAliveInterval = 0;
    /* getNextPacketId() pre-increments, so the first identifier used is
     * non-zero, as MQTT requires. */
    c->next_packetid = 1;
    c->defaultMessageHandler = NULL;
    InitTimer(&c->ping_timer);
}


/* Read the variable-length "remaining length" field of an MQTT fixed header
 * from the network, one byte at a time.
 *
 * Each byte contributes its low 7 bits; bit 7 set means another byte follows.
 * The decoded value is accumulated into *value (reset to 0 first).
 * Returns the number of length bytes attempted. This count includes a read
 * that failed, and is 5 if a 5th byte would have been needed. Errors are not
 * otherwise signalled, and *value then holds the partial sum.
 */
int decodePacket(Client* c, int* value, int timeout)
{
    const int MAX_NO_OF_REMAINING_LENGTH_BYTES = 4;
    unsigned char encodedByte = 0;
    int multiplier = 1;
    int bytesTried = 0;

    *value = 0;
    for (;;)
    {
        ++bytesTried;
        if (bytesTried > MAX_NO_OF_REMAINING_LENGTH_BYTES)
            break;  /* malformed: more than 4 length bytes */
        if (c->ipstack->mqttread(c->ipstack, &encodedByte, 1, timeout) != 1)
            break;  /* read failed or timed out */

        *value += (encodedByte & 127) * multiplier;
        multiplier *= 128;

        if ((encodedByte & 128) == 0)
            break;  /* continuation bit clear: last length byte */
    }
    return bytesTried;
}


/* Read one complete MQTT control packet from the network into c->readbuf.
 *
 * The layout stored in readbuf is: header byte, re-encoded remaining length,
 * then the packet body. The remaining length comes from the peer and is
 * untrusted, so it is checked against readbuf_size before the body is read.
 *
 * Returns the packet type from the header on success. Returns FAILURE if the
 * header byte or the body could not be read in time, or if the packet would
 * not fit into c->readbuf. In that last case the oversized body is left
 * unread on the connection.
 */
int readPacket(Client* c, Timer* timer)
{
    int rc = FAILURE;
    MQTTHeader header = {0};
    int len = 0;
    int rem_len = 0;

    /* 1. read the header byte.  This has the packet type in it */
    if (c->ipstack->mqttread(c->ipstack, c->readbuf, 1, left_ms(timer)) != 1)
        goto exit;

    len = 1;
    /* 2. read the remaining length.  This is variable in itself */
    decodePacket(c, &rem_len, left_ms(timer));
    len += MQTTPacket_encode(c->readbuf + 1, rem_len); /* put the original remaining length back into the buffer */

    /* 3. refuse packets whose advertised size would overrun readbuf,
     *    instead of writing past the end of the buffer */
    if ((size_t)len > c->readbuf_size || (size_t)rem_len > c->readbuf_size - (size_t)len)
    {
        hlog_error("Incoming MQTT packet exceeds read buffer, dropping");
        goto exit;
    }

    /* 4. read the rest of the buffer using a callback to supply the rest of the data */
    if (rem_len > 0 && (c->ipstack->mqttread(c->ipstack, c->readbuf + len, rem_len, left_ms(timer)) != rem_len))
        goto exit;

    header.byte = c->readbuf[0];
    rc = header.bits.type;
exit:
    return rc;
}


/* Match a length-delimited topic name against a NUL-terminated topic filter
 * that may contain the MQTT wildcards '+' (one level) and '#' (the rest).
 * Inputs are assumed well-formed: '#' appears only at the end, and wildcards
 * appear only next to '/' separators.
 * Returns non-zero only if both the whole name and the whole filter were
 * consumed.
 */
char isTopicMatched(char* topicFilter, MQTTString* topicName)
{
    char* filter = topicFilter;
    char* name = topicName->lenstring.data;
    char* const nameEnd = name + topicName->lenstring.len;

    for (; *filter != '\0' && name < nameEnd; ++filter, ++name)
    {
        /* a level separator in the name must line up with one in the filter */
        if (*name == '/' && *filter != '/')
            break;

        if (*filter == '+')
        {
            /* single level: advance to the last character before the next '/' */
            while (name + 1 < nameEnd && name[1] != '/')
                ++name;
        }
        else if (*filter == '#')
        {
            /* multi level: swallow the remainder of the name */
            name = nameEnd - 1;
        }
        else if (*filter != *name)
        {
            break;
        }
    }

    return (name == nameEnd) && (*filter == '\0');
}


/* Hand an incoming PUBLISH to every registered handler whose filter equals,
 * or wildcard-matches, the topic name. Handlers with a NULL callback are
 * skipped. If no registered callback ran, defaultMessageHandler (when set)
 * receives the message.
 * Returns SUCCESS if any callback was invoked, FAILURE otherwise.
 */
int deliverMessage(Client* c, MQTTString* topicName, MQTTMessage* message)
{
    int delivered = 0;
    int idx;

    for (idx = 0; idx < MAX_MESSAGE_HANDLERS; ++idx)
    {
        char* filter = (char*)c->messageHandlers[idx].topicFilter;

        if (filter == 0)
            continue;   /* empty slot */
        if (!MQTTPacket_equals(topicName, filter) && !isTopicMatched(filter, topicName))
            continue;   /* topic does not match this filter */
        if (c->messageHandlers[idx].fp == NULL)
            continue;   /* matched, but nothing to call */

        {
            MessageData md;
            NewMessageData(&md, topicName, message);
            c->messageHandlers[idx].fp(&md);
            delivered = 1;
        }
    }

    if (!delivered && c->defaultMessageHandler != NULL)
    {
        MessageData md;
        NewMessageData(&md, topicName, message);
        c->defaultMessageHandler(&md);
        delivered = 1;
    }

    return delivered ? SUCCESS : FAILURE;
}

/* Keep-alive state used by keepalive() and by cycle(). cycle() resets
 * intCountExpire and forceKeepAlive when a PINGRESP arrives. These are
 * file-scope statics, so the state is shared by all Client instances. */
static int intCountExpire = 0;  /* restarts requested by keepalive() since the last PINGRESP / TIMEOUT */
static int intPingFailed = 0;   /* NOTE(review): not referenced in the visible code of this file */
static int forceKeepAlive = 0;  /* 0 = normal keep-alive; >0 = forced-ping mode (attempt counter) */
/**@brief       Send PINGREQs and escalate when PINGRESPs stop arriving
 * @details     Does nothing (returns SUCCESS) when c->keepAliveInterval is 0, and
 *              does nothing (returns FAILURE) while c->ping_timer has not expired.
 *              On expiry:
 *              - no ping outstanding and forced-ping budget not used up: send a
 *                PINGREQ. In normal mode, re-arm the timer for keepAliveInterval
 *                and mark the ping outstanding. In forced mode, count the attempt,
 *                re-arm for FORCE_KEEP_ALIVE_INTERVAL and leave ping_outstanding at
 *                0, so the next expiry pings again.
 *              - previous ping still outstanding: switch to forced mode and return.
 *              - forced-ping budget (MAXIMUM_FORCE_KEEP_ALIVE_TIMES) used up: request
 *                MQTTS_CMD_RESTART. Once more than 10 restarts have been requested,
 *                reset the counters and return TIMEOUT instead.
 * @param[in]   c  client whose keep-alive is serviced
 * @return      SUCCESS if keep-alive is disabled or forced mode was just entered;
 *              the sendPacket() result after a ping attempt (FAILURE if the
 *              PINGREQ could not be serialized); TIMEOUT when the restart budget is
 *              exhausted; FAILURE otherwise.
 */
int keepalive(Client* c)
{
    int rc = FAILURE;

    if (c->keepAliveInterval == 0)
    {
        rc = SUCCESS;
        goto exit;
    }
    if (expired(&c->ping_timer))
    {
        // Ping unless a PINGREQ is still awaiting its PINGRESP or the forced-ping budget is used up
        if (!c->ping_outstanding && forceKeepAlive <= MAXIMUM_FORCE_KEEP_ALIVE_TIMES)
        {
            hlog_info("Sending keepAlive packet");
            Timer timer;
            InitTimer(&timer);
            countdown_ms(&timer, 1000);
            int len = MQTTSerialize_pingreq(c->buf, c->buf_size);
            if (len > 0 && (rc = sendPacket(c, len, &timer)) == SUCCESS) // send the ping packet
            {
                // In forced-ping mode, count every successfully sent PINGREQ.
                if(forceKeepAlive > 0)
                {
                    hlog_info("Force Ping, retries = %d",forceKeepAlive);
                    forceKeepAlive += 1;
                    // Short retry interval; ping_outstanding stays 0 so the next expiry sends again
                    countdown(&c->ping_timer, FORCE_KEEP_ALIVE_INTERVAL);
                    c->ping_outstanding = 0;
                }
                else
                {
                countdown(&c->ping_timer, c->keepAliveInterval); //Reset PING cooldown
                c->ping_outstanding = 1;
                }
            }
            else if (rc == MBEDTLS_ERR_SSL_PEER_CLOSE_NOTIFY)
            {
                // NOTE(review): sendPacket() in this file only returns SUCCESS or FAILURE,
                // so the transport's error code never reaches here and this branch looks
                // unreachable - confirm, and propagate the write error if it is needed.
                // If the server closed the connection, exhaust the forced-ping budget so
                // the next expiry proceeds to RESTART.
                forceKeepAlive = MAXIMUM_FORCE_KEEP_ALIVE_TIMES + 1;
                hlog_error("Server closed connection, Proceed to restart");

            }
        }
        else
        {
            if(forceKeepAlive <= MAXIMUM_FORCE_KEEP_ALIVE_TIMES)
            {
                // No PINGRESP within the keep-alive interval: switch to forced pings,
                // re-sent every FORCE_KEEP_ALIVE_INTERVAL.
                hlog_error("Couldn't receive Ping Response, Proceed to Force Ping");
                c->ping_outstanding = 0;
                forceKeepAlive += 1;
                rc = SUCCESS;
                return rc;
            }
            /* Forced-ping budget exhausted: reset the MQTT connection here */
            if(intCountExpire > 10)
            {
                // Too many restarts already: clear the keep-alive state and report TIMEOUT to the caller
                hlog_error("Send KAL but not response, Disconnect MQTT and Taking Server's URL");
#if (SUPPORT_MQTT_IDLE == 1)
                if(check_network_connection(__FILE__,__FUNCTION__,__LINE__) == 0)
                {
                    hlog_fatal(ANALYTICS_SPECIFIER"Send KAL but can't receive response; Disconnect MQTT and Taking Server's URL",
                            MQTT_failedReceivePING,MQTT_Running_keepAlive,ANALYTICS_LOGSTATUS_ERROR,ANALYTICS_LOGSTATUS_ERROR);
                }
                else
                    wait_network_connection(__FILE__,__FUNCTION__,__LINE__);
#endif
                c->ping_outstanding = 0;
                forceKeepAlive = 0;
                intCountExpire = 0;
                rc = TIMEOUT;
            }
            else
            {
                // Request a restart of the MQTT service and count it; rc stays FAILURE
                hlog_error("Send KAL but not response, Re-Init MQTT and Re-Subscribe");
                forceKeepAlive = 0;
#if (SUPPORT_MQTT_IDLE == 1)
                if(check_network_connection(__FILE__, __FUNCTION__, __LINE__) == 0)
                {
                    hlog_fatal(ANALYTICS_SPECIFIER"Send KAL but can't receive response; Re-Init MQTT and Re-Subscribe",
                            MQTT_failedReceivePING,MQTT_Running_keepAlive,ANALYTICS_LOGSTATUS_ERROR,ANALYTICS_LOGSTATUS_ERROR);
                }
                else
                    wait_network_connection(__FILE__, __FUNCTION__, __LINE__);
#endif
                mqtts_control(MQTTS_CMD_RESTART);
                intCountExpire = intCountExpire +1;
            }
        }

    }

exit:
    return rc;
}


/**@brief       Read and process at most one incoming packet, then service keep-alive
 * @details     Reads one packet with readPacket() and handles it by type.
 *              PUBLISH is delivered to the message handlers and acknowledged
 *              (PUBACK for QoS1, PUBREC for QoS2). PUBREC is answered with
 *              PUBREL. PINGRESP clears the outstanding ping and the keep-alive
 *              failure counters. Other types need no action here.
 *              messageHandle() and keepalive() then run, unless an earlier step
 *              jumped to exit.
 * @param[in]   c      client to service
 * @param[in]   timer  deadline for reading the packet and sending any reply
 * @return      The packet type read, when nothing failed. FAILURE when no packet
 *              could be read (readPacket() returned FAILURE) or a reply could not
 *              be sent. TIMEOUT when keepalive() reports TIMEOUT.
 */ 
int cycle(Client* c, Timer* timer)
{
    // read the socket, see what work is due
    int packet_type = readPacket(c, timer);

    int len = 0,
        rc = SUCCESS;

    // No default case: unhandled types (and FAILURE from readPacket) fall through untouched
    switch (packet_type)
    {
        case CONNACK:
        case PUBACK:
        case SUBACK:
            break;
        case PUBLISH:
            {
                MQTTString topicName;
                MQTTMessage msg;
                // On a decode failure rc is still SUCCESS, so PUBLISH is returned and
                // messageHandle()/keepalive() are skipped
                if (MQTTDeserialize_publish((unsigned char*)&msg.dup, (int*)&msg.qos, (unsigned char*)&msg.retained, (unsigned short*)&msg.id, &topicName,
                            (unsigned char**)&msg.payload, (int*)&msg.payloadlen, c->readbuf, c->readbuf_size) != 1)
                    goto exit;
                // Set topic name and Message to Client handler
                deliverMessage(c, &topicName, &msg);
                if (msg.qos != QOS0)
                {
                    if (msg.qos == QOS1)
                    {
                        hlog_info("sendPacket PUBACK to server");
                        len = MQTTSerialize_ack(c->buf, c->buf_size, PUBACK, 0, msg.id);
                    }
                    else if (msg.qos == QOS2)
                    {
                        hlog_info("sendPacket PUBREC to server");
                        len = MQTTSerialize_ack(c->buf, c->buf_size, PUBREC, 0, msg.id);
                    }
                    if (len <= 0)
                        rc = FAILURE;
                    else
                        rc = sendPacket(c, len, timer);
                    if (rc == FAILURE)
                    {
                        goto exit; // there was a problem
                    }
                }
                break;
            }
        case PUBREC:
            {
                // Second step of an outgoing QoS2 publish: answer with PUBREL
                unsigned short mypacketid;
                unsigned char dup, type;
                if (MQTTDeserialize_ack(&type, &dup, &mypacketid, c->readbuf, c->readbuf_size) != 1)
                    rc = FAILURE;
                else if ((len = MQTTSerialize_ack(c->buf, c->buf_size, PUBREL, 0, mypacketid)) <= 0)
                    rc = FAILURE;
                else if ((rc = sendPacket(c, len, timer)) != SUCCESS) // send the PUBREL packet
                    rc = FAILURE; // there was a problem
                if (rc == FAILURE)
                    goto exit; // there was a problem
                break;
            }
        case PUBCOMP:
            break;
        case PINGRESP:
            // Server is alive: leave forced-ping mode and clear the restart counter
            hlog_info("Received PING RESP!!!");
            c->ping_outstanding = 0;
            intCountExpire = 0;
            forceKeepAlive = 0;
            break;
    }
    // messageHandle() is defined outside this file; presumably it processes queued application work - confirm
    messageHandle();
    if(keepalive(c) == TIMEOUT)
    {
        rc = TIMEOUT;
    }
exit:
    // On success report the packet type (which is FAILURE when nothing was read)
    if (rc == SUCCESS)
        rc = packet_type;
    return rc;
}


/* Service the connection (incoming packets and keep-alive) for up to
 * timeout_ms milliseconds.
 *
 * cycle() runs once per iteration and its single result is inspected:
 *   - FAILURE: stop and return FAILURE.
 *   - TIMEOUT: keep-alive gave up, so stop the MQTT service
 *     (mqtts_stop_state() + MQTTS_CMD_STOP) and return FAILURE.
 * Returns SUCCESS if the time budget elapsed without either condition.
 */
int MQTTYield(Client* c, int timeout_ms)
{
    int rc = SUCCESS;
    Timer timer;

    InitTimer(&timer);
    countdown_ms(&timer, timeout_ms);
    while (!expired(&timer))
    {
        /* Evaluate one cycle() result for both checks. Calling cycle() again to
         * test for TIMEOUT would consume another packet and lose a TIMEOUT
         * returned by the first call (keepalive() resets its counters when it
         * returns TIMEOUT). */
        int cycle_rc = cycle(c, &timer);

        if (cycle_rc == FAILURE)
        {
            rc = FAILURE;
            break;
        }
        if (cycle_rc == TIMEOUT)
        {
            rc = FAILURE;
            mqtts_stop_state();
            mqtts_control(MQTTS_CMD_STOP);
            break;
        }
    }
    return rc;
}


/* Pump cycle() until a packet of type packet_type has been processed or the
 * timer expires. Meant for single-threaded use, with one command in flight.
 * packet_type is also used to index msgTypeMessages for logging, so it must be
 * a valid packet type (1..14).
 * Returns packet_type on success. Otherwise returns the last cycle() result,
 * or FAILURE if the timer had already expired on entry.
 */
int waitfor(Client* c, int packet_type, Timer* timer)
{
    int rc = FAILURE;

    hlog_info("Waitfor %s packet",msgTypeMessages[packet_type - 1]);
    while (!expired(timer))
    {
        rc = cycle(c, timer);
        if (rc == packet_type)
            break;
    }
    return rc;
}


/* Send CONNECT and block (up to command_timeout_ms) until the CONNACK arrives.
 *
 * options may be NULL, in which case MQTTPacket_connectData_initializer
 * defaults are used. The keep-alive interval is fixed at 40 here
 * (options->keepAliveInterval is ignored), and the ping timer is armed as soon
 * as CONNECT has been sent.
 *
 * Returns the CONNACK return code, and marks the client connected when that
 * code equals SUCCESS. Returns FAILURE when the client is already connected,
 * serialization fails, no CONNACK arrives, or it cannot be decoded. Returns
 * the sendPacket() result if sending fails.
 */
int MQTTConnect(Client* c, MQTTPacket_connectData* options)
{
    MQTTPacket_connectData default_options = MQTTPacket_connectData_initializer;
    Timer connect_timer;
    int rc = FAILURE;
    int len = 0;

    InitTimer(&connect_timer);
    countdown_ms(&connect_timer, c->command_timeout_ms);

    /* never send a second CONNECT on a connected client */
    if (c->isconnected)
        return rc;

    if (options == 0)
        options = &default_options;

    len = MQTTSerialize_connect(c->buf, c->buf_size, options);
    if (len <= 0)
        return rc;

    hlog_info("sendPacket MQTTConnect");
    rc = sendPacket(c, len, &connect_timer);
    if (rc != SUCCESS)
        return rc;

    c->keepAliveInterval = 40;  /* hard-coded; options->keepAliveInterval/2 was the upstream choice */
    countdown(&c->ping_timer, c->keepAliveInterval);

    /* blocking: wait for the CONNACK */
    if (waitfor(c, CONNACK, &connect_timer) != CONNACK)
        return FAILURE;

    {
        unsigned char connack_rc = 255;
        char sessionPresent = 0;

        rc = (MQTTDeserialize_connack((unsigned char*)&sessionPresent, &connack_rc, c->readbuf, c->readbuf_size) == 1)
                ? connack_rc
                : FAILURE;
    }

    if (rc == SUCCESS)
        c->isconnected = 1;
    return rc;
}


/* Subscribe to topicFilter and, unless the broker refuses it (0x80), register
 * messageHandler in the first free handler slot.
 *
 * The topicFilter pointer itself is stored, not copied, so the string must
 * stay valid while the subscription is in use.
 * Returns:
 *   0       - SUBACK received and the handler registered;
 *   0x80    - broker refused the subscription;
 *   granted QoS - SUBACK received but no free handler slot;
 *   FAILURE - not connected, serialization/sending failed, or no SUBACK.
 * If the SUBACK cannot be decoded, the failure is only logged. The handler is
 * still registered (original behaviour, intentionally kept).
 */
int MQTTSubscribe(Client* c, const char* topicFilter, enum QoS qos, messageHandler messageHandler)
{
    MQTTString topic = MQTTString_initializer;
    Timer timer;
    int rc = FAILURE;
    int len = 0;
    int slot;

    topic.cstring = (char *)topicFilter;
    InitTimer(&timer);
    countdown_ms(&timer, c->command_timeout_ms);

    if (!c->isconnected)
        return FAILURE;

    hlog_info("sendPacket MQTTSubscribe");
    len = MQTTSerialize_subscribe(c->buf, c->buf_size, 0, getNextPacketId(c), 1, &topic, (int*)&qos);
    if (len <= 0)
        return FAILURE;

    rc = sendPacket(c, len, &timer);   /* send the SUBSCRIBE packet */
    if (rc != SUCCESS)
        return rc;

    if (waitfor(c, SUBACK, &timer) != SUBACK)
        return FAILURE;

    {
        int count = 0, grantedQoS = -1;
        unsigned short mypacketid;

        if (MQTTDeserialize_suback(&mypacketid, 1, &count, &grantedQoS, c->readbuf, c->readbuf_size) == 1)
        {
            hlog_info("grantedQoS %d", grantedQoS);
            rc = grantedQoS; /* 0, 1, 2 or 0x80 */
        }
        else
        {
            /* decode failure is only logged; rc keeps the sendPacket() result */
            hlog_warn("Decode Subscribe package failure");
        }
    }

    if (rc == 0x80)   /* subscription refused: do not register a handler */
        return rc;

    for (slot = 0; slot < MAX_MESSAGE_HANDLERS; ++slot)
    {
        if (c->messageHandlers[slot].topicFilter == 0)
        {
            c->messageHandlers[slot].topicFilter = topicFilter;
            c->messageHandlers[slot].fp = messageHandler;
            return 0;
        }
    }
    return rc;
}


/* Unsubscribe from topicFilter and wait (up to command_timeout_ms) for the
 * UNSUBACK.
 *
 * Once a valid UNSUBACK is decoded, every handler slot registered for the same
 * filter string is released. Afterwards, incoming PUBLISHes are no longer
 * routed to that handler, and a later MQTTSubscribe() of the same filter does
 * not occupy a second slot (which would cause duplicate delivery).
 *
 * Returns 0 on success. Returns FAILURE if not connected, serialization or
 * sending failed, or no UNSUBACK arrived. If an UNSUBACK arrives but cannot be
 * decoded, the sendPacket() result is returned and the handlers are kept.
 */
int MQTTUnsubscribe(Client* c, const char* topicFilter)
{
    int rc = FAILURE;
    Timer timer;
    MQTTString topic = MQTTString_initializer;
    topic.cstring = (char *)topicFilter;
    int len = 0;

    InitTimer(&timer);
    countdown_ms(&timer, c->command_timeout_ms);

    if (!c->isconnected)
        goto exit;

    if ((len = MQTTSerialize_unsubscribe(c->buf, c->buf_size, 0, getNextPacketId(c), 1, &topic)) <= 0)
        goto exit;

    hlog_info("sendPacket MQTTUnSubscribe");
    if ((rc = sendPacket(c, len, &timer)) != SUCCESS) // send the unsubscribe packet
        goto exit; // there was a problem

    if (waitfor(c, UNSUBACK, &timer) == UNSUBACK)
    {
        unsigned short mypacketid;  // should be the same as the packetid above
        if (MQTTDeserialize_unsuback(&mypacketid, c->readbuf, c->readbuf_size) == 1)
        {
            int i;

            rc = 0;
            /* Release the handler slot(s) for this filter. Compare by content:
             * the caller may pass a different pointer than the one given to
             * MQTTSubscribe(). */
            for (i = 0; i < MAX_MESSAGE_HANDLERS; ++i)
            {
                if (c->messageHandlers[i].topicFilter != 0 &&
                        MQTTPacket_equals(&topic, (char*)c->messageHandlers[i].topicFilter))
                {
                    c->messageHandlers[i].topicFilter = 0;
                    c->messageHandlers[i].fp = NULL;
                }
            }
        }
    }
    else
        rc = FAILURE;

exit:
    return rc;
}


/* Publish message on topicName and, for QoS1/QoS2, block until the flow
 * completes or command_timeout_ms elapses.
 *
 * For QoS1/QoS2 a fresh packet identifier is written into message->id. For
 * QoS0 the caller's message->id is sent unchanged. A QoS1 publish completes on
 * PUBACK. A QoS2 publish completes on PUBCOMP; the intermediate PUBREC is
 * answered with PUBREL by cycle().
 * Returns SUCCESS when the publish (and any acknowledgement) succeeded,
 * otherwise FAILURE or the failing sendPacket() result.
 */
int MQTTPublish(Client* c, const char* topicName, MQTTMessage* message)
{
    MQTTString topic = MQTTString_initializer;
    Timer timer;
    int rc = FAILURE;
    int len = 0;

    topic.cstring = (char *)topicName;
    InitTimer(&timer);
    countdown_ms(&timer, c->command_timeout_ms);

    if (!c->isconnected)
        return FAILURE;

    if (message->qos == QOS1 || message->qos == QOS2)
        message->id = getNextPacketId(c);

    len = MQTTSerialize_publish(c->buf, c->buf_size, 0, message->qos, message->retained, message->id,
            topic, (unsigned char*)message->payload, message->payloadlen);
    if (len <= 0)
        return FAILURE;

    rc = sendPacket(c, len, &timer);   /* send the PUBLISH packet */
    if (rc != SUCCESS)
        return rc;

    if (message->qos == QOS1 || message->qos == QOS2)
    {
        int expected = (message->qos == QOS1) ? PUBACK : PUBCOMP;
        unsigned short mypacketid;
        unsigned char dup, type;

        /* the ack must both arrive in time and decode cleanly */
        if (waitfor(c, expected, &timer) != expected ||
                MQTTDeserialize_ack(&type, &dup, &mypacketid, c->readbuf, c->readbuf_size) != 1)
            rc = FAILURE;
    }

    return rc;
}


/* Send a DISCONNECT packet (best effort) and mark the client disconnected.
 * The client is marked disconnected whether or not the packet could be sent.
 * Returns the sendPacket() result, or FAILURE if the packet could not be
 * serialized.
 */
int MQTTDisconnect(Client* c)
{
    Timer timer;
    int rc = FAILURE;
    int len = MQTTSerialize_disconnect(c->buf, c->buf_size);

    InitTimer(&timer);
    countdown_ms(&timer, c->command_timeout_ms);
    hlog_info("sendPacket MQTTDisconnect");
    rc = (len > 0) ? sendPacket(c, len, &timer) : FAILURE;   /* send the disconnect packet */

    c->isconnected = 0;
    return rc;
}

