/******************************************************************************
*                                                                             *
*  TakTuk, a middleware for adaptive large scale parallel remote executions   *
*  deployment. Perl implementation, copyright(C) 2006 Guillaume Huard.        *
*                                                                             *
*  This program is free software; you can redistribute it and/or modify       *
*  it under the terms of the GNU General Public License as published by       *
*  the Free Software Foundation; either version 2 of the License, or          *
*  (at your option) any later version.                                        *
*                                                                             *
*  This program is distributed in the hope that it will be useful,            *
*  but WITHOUT ANY WARRANTY; without even the implied warranty of             *
*  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the              *
*  GNU General Public License for more details.                               *
*                                                                             *
*  You should have received a copy of the GNU General Public License          *
*  along with this program; if not, write to the Free Software                *
*  Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA  02110-1301  USA *
*                                                                             *
*  Contact: Guillaume.Huard@imag.fr                                           *
*           ENSIMAG - Laboratoire ID                                          *
*           51 avenue Jean Kuntzmann                                          *
*           38330 Montbonnot Saint Martin                                     *
*                                                                             *
******************************************************************************/

#include  "taktuk.h"
#include <arpa/inet.h>
#include <errno.h>
#include <pthread.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>

/* Debug tracing. Uncomment the definition of DEBUG below to make debug()
 * print its printf-like message on standard output, prefixed by the number of
 * the line it is called from. When DEBUG is not defined, debug() expands to
 * nothing, so its arguments are not evaluated.
 */
/*#define DEBUG*/
#ifdef DEBUG
#define debug(f, ...) printf("Line %d : " f, __LINE__, ##__VA_ARGS__)
#else
#define debug(f, ...)
#endif

/* Code of the last error detected by the functions of this file (one of the
 * TAKTUK_E* values). It is only written when a function fails and can be
 * translated into a message by taktuk_error_msg(). This is a plain global:
 * it is shared by all the threads of the process.
 */
int taktuk_error;

const char *taktuk_error_msg(int msg_code)
  {
    switch(msg_code)
      {
      case TAKTUK_ESWRIT:
        return "write failed";
      case TAKTUK_EFCLSD:
        return "TakTuk engine closed the communication channel";
      case TAKTUK_ESREAD:
        return "read error";
      case TAKTUK_ETMOUT:
        return "timeouted";
      case TAKTUK_EINVST:
        return "invalid destination set specification";
      case TAKTUK_EALLOC:
        return "memory allocation failed";
      case TAKTUK_EIBUFF:
        return "invalid buffer";
      case TAKTUK_EINTRN:
        return "internal error";
      default:
        return "Unknown error";
      }
  }

/* Size in bytes of the on-stack buffer in which send_header builds a header;
 * larger headers are built in dynamically allocated memory
 */
#define MAX_LOCAL_HEADER 64
/* Size in bytes of the on-stack scratch buffer used by recv_header */
#define MAX_RECV_BUFFER 10
/* Number of bytes used to encode a message code in the protocol */
#define MSG_LENGTH 1

/* Stores value at address pos as four bytes in network byte order.
 * The copy is done bytewise, so pos does not have to be aligned.
 *
 * Returns the address of the byte following the stored value.
 */
static void *put_uint32(void *pos, uint32_t value)
  {
    uint32_t wire_value = htonl(value);

    memcpy(pos, &wire_value, sizeof(wire_value));
    return ((char *) pos) + sizeof(wire_value);
  }

/* Copies length bytes from mem to address pos.
 *
 * Returns the address of the byte following the copied data.
 */
static void *put_bytes(void *pos, const void *mem, size_t length)
  {
    char *destination = (char *) pos;

    memcpy(destination, mem, length);
    destination += length;
    return destination;
  }

/* Reads at address pos four bytes holding an integer in network byte order
 * and stores the decoded integer into *value.
 * The copy is done bytewise, so pos does not have to be aligned.
 *
 * Returns the address of the byte following the decoded value.
 */
static void *get_uint32(const void *pos, uint32_t *value)
  {
    uint32_t wire_value;

    memcpy(&wire_value, pos, sizeof(wire_value));
    *value = ntohl(wire_value);
    return ((char *) pos) + sizeof(wire_value);
  }

/* Copies length bytes from address pos into mem.
 *
 * Returns the address of the byte following the data that has been read.
 */
static void *get_bytes(const void *pos, void *mem, size_t length)
  {
    const char *source = (const char *) pos;

    memcpy(mem, source, length);
    return (char *) (source + length);
  }

/* Writes exactly size bytes taken from buffer to descriptor fd, calling
 * write as many times as needed to cope with partial writes.
 *
 * Returns size (converted to int) when everything has been written, or -1
 * after setting taktuk_error to:
 *   TAKTUK_ESWRIT if write failed,
 *   TAKTUK_EINTRN if write reported that nothing could be written.
 */
static int insistent_write(int fd, const void *buffer, size_t size)
  {
    const char *cursor = (const char *) buffer;
    size_t remaining = size;
    /* write returns a ssize_t, do not truncate it to int */
    ssize_t written;

    while (remaining > 0)
      {
        written = write(fd, cursor, remaining);
        if (written < 0)
          {
            /* A call interrupted by a signal has transferred nothing and is
             * not a failure of the channel: just try again
             */
            if (errno == EINTR)
                continue;
            taktuk_error = TAKTUK_ESWRIT;
            return -1;
          }
        if (written == 0)
          {
            taktuk_error = TAKTUK_EINTRN;
            return -1;
          }
        cursor += written;
        remaining -= (size_t) written;
      }
    return (int) size;
  }

/* Reads exactly size bytes from descriptor fd into buffer, calling read as
 * many times as needed to cope with partial reads.
 *
 * Returns size (converted to int) when everything has been read, or -1 after
 * setting taktuk_error to:
 *   TAKTUK_EFCLSD if the end of the stream is reached before size bytes,
 *   TAKTUK_ESREAD if read failed.
 */
static int insistent_read(int fd, void *buffer, size_t size)
  {
    char *cursor = (char *) buffer;
    size_t remaining = size;
    /* read returns a ssize_t, do not truncate it to int */
    ssize_t received;

    while (remaining > 0)
      {
        received = read(fd, cursor, remaining);
        if (received < 0)
          {
            /* A call interrupted by a signal has transferred nothing and is
             * not a failure of the channel: just try again
             */
            if (errno == EINTR)
                continue;
            taktuk_error = TAKTUK_ESREAD;
            return -1;
          }
        if (received == 0)
          {
            taktuk_error = TAKTUK_EFCLSD;
            return -1;
          }
        cursor += received;
        remaining -= (size_t) received;
        /* Casts make the arguments match the format specifiers */
        debug("Read %d bytes, remaining %ld to read\n", (int) received,
              (long) remaining);
      }
    return (int) size;
  }

/* Returns the descriptor given by the TAKTUK_CONTROL_READ environment
 * variable, or -1 when this variable is not set.
 * The environment is looked up again on each call as long as the remembered
 * value is -1; afterwards the remembered value is returned directly.
 */
static int get_taktuk_read_fd()
  {
    static int cached_fd = -1;

    if (cached_fd == -1)
      {
        const char *setting = getenv("TAKTUK_CONTROL_READ");

        if (setting != NULL)
            cached_fd = atoi(setting);
      }
    return cached_fd;
  }

/* Returns the descriptor given by the TAKTUK_CONTROL_WRITE environment
 * variable, or -1 when this variable is not set.
 * The environment is looked up again on each call as long as the remembered
 * value is -1; afterwards the remembered value is returned directly.
 */
static int get_taktuk_write_fd()
  {
    static int cached_fd = -1;

    if (cached_fd == -1)
      {
        const char *setting = getenv("TAKTUK_CONTROL_WRITE");

        if (setting != NULL)
            cached_fd = atoi(setting);
      }
    return cached_fd;
  }

/* Returns the process-wide mutex used in this file to serialize the writes
 * made on the TakTuk control channel.
 *
 * The mutex is statically initialized: initializing it lazily behind an
 * unprotected flag would let two threads calling this function for the first
 * time at the same moment both initialize it, or let one of them lock it
 * while the other one is (re)initializing it.
 */
static pthread_mutex_t *get_taktuk_mutex()
  {
    static pthread_mutex_t mutex = PTHREAD_MUTEX_INITIALIZER;

    return &mutex;
  }

static int send_header(const char *dest, size_t body_length)
  {
    char buffer[MAX_LOCAL_HEADER];
    char *allocated_buffer = NULL;
    void *header = buffer;
    void *current;
    uint32_t header_size;
    uint32_t dest_size;
    int result;
    char send_to = TAKTUK_SEND_TO;
    char message = TAKTUK_MESSAGE;
    static char *taktuk_from = NULL;
    static int taktuk_from_size;
    int taktuk_write_fd;

    if (taktuk_from == NULL)
      {
        taktuk_from = getenv("TAKTUK_RANK");
        if (taktuk_from != NULL)
            taktuk_from_size = strlen(taktuk_from);
        else
            taktuk_from_size = 0;
      }
    dest_size = strlen(dest);
    /* Do not take into account the first four bytes encoding the total size
     * Unfortunately I have to compute size first to allocate my memory if
     * needed ...
     */
    header_size = sizeof(uint32_t) +
                  MSG_LENGTH +
                  sizeof(uint32_t) + dest_size +
                  MSG_LENGTH +
                  sizeof(uint32_t) + taktuk_from_size;

    if (header_size > MAX_LOCAL_HEADER)
      {
        allocated_buffer = (char *) malloc(header_size);
        if (allocated_buffer == NULL)
          {
            taktuk_error = TAKTUK_EALLOC;
            return -1;
          }
        header = allocated_buffer;
      } 

    current = header;
    /* The first four bytes of the header encode its own size not including
     * the four bytes themselves
     */
    current = put_uint32(current, header_size-sizeof(uint32_t) + body_length);
    current = put_bytes(current, &send_to, MSG_LENGTH);
    current = put_uint32(current, dest_size);
    current = put_bytes(current, dest, dest_size);
    current = put_bytes(current, &message, MSG_LENGTH);
    current = put_uint32(current, taktuk_from_size);
    current = put_bytes(current, taktuk_from, taktuk_from_size);

    taktuk_write_fd = get_taktuk_write_fd();
    debug ("Taktuk write FD : %d\n", taktuk_write_fd);
    if (taktuk_write_fd != -1)
        result = insistent_write(taktuk_write_fd, header, header_size);
    else
      {
        taktuk_error = TAKTUK_ENOCON;
        result = -1;
      }
    if (allocated_buffer != NULL)
        free(allocated_buffer);
    return result;
  }

/* Asks the TakTuk engine for the next message and reads its header.
 *
 * A TAKTUK_WAIT_MESSAGE request is written on the control channel (under the
 * TakTuk mutex), followed by the decimal value of timeout when timeout > 0.
 * The answer is then read: its length, its code and, for a TAKTUK_MESSAGE
 * answer, the name of the sender.
 * NOTE(review): only the write is made under the mutex, reads are not
 * serialized - confirm that a single thread receives at a time.
 *
 * from     : if not NULL, receives the sender name followed by a '\0'. Its
 *            capacity is not known here: the caller has to provide enough
 *            room for any sender name plus the terminator.
 *            If NULL the sender name is read and discarded.
 * msg_code : receives the code of the answer.
 * length   : receives the number of body bytes remaining to be read.
 * timeout  : forwarded to the engine when strictly positive.
 *
 * Returns a non negative value when a header has been read, or -1 after
 * setting taktuk_error (TAKTUK_ENOCON if a control descriptor is missing,
 * TAKTUK_ETMOUT on a TAKTUK_TIMEOUT answer, TAKTUK_EINTRN on an inconsistent
 * answer, or the error set by insistent_write / insistent_read).
 */
static int recv_header(char *from, char *msg_code, size_t *length, int timeout)
  {
    /* The request holds the length, the code and the decimal form of the
     * timeout: 3 characters per byte of an int plus a sign and the final
     * '\0' is always enough, whereas MAX_RECV_BUFFER overflows as soon as
     * the timeout has five digits
     */
    char request[sizeof(uint32_t) + MSG_LENGTH + 3*sizeof(int) + 2];
    /* Scratch buffer for the fixed size fields of the answer (at most
     * sizeof(uint32_t)+MSG_LENGTH bytes, which MAX_RECV_BUFFER covers)
     */
    char buffer[MAX_RECV_BUFFER];
    char *current;
    char wait = TAKTUK_WAIT_MESSAGE;
    int buffer_size;
    int written;
    size_t available;
    size_t chunk;
    int result;
    uint32_t from_size, value, left;
    int taktuk_write_fd;
    int taktuk_read_fd;
    pthread_mutex_t *mutex;

    taktuk_write_fd = get_taktuk_write_fd();
    if (taktuk_write_fd == -1)
      {
        taktuk_error = TAKTUK_ENOCON;
        return -1;
      }
    taktuk_read_fd = get_taktuk_read_fd();
    if (taktuk_read_fd == -1)
      {
        taktuk_error = TAKTUK_ENOCON;
        return -1;
      }
    /* The length is only known once the request is built: skip its four
     * bytes for now
     */
    current = request + sizeof(uint32_t);
    current = put_bytes(current, &wait, MSG_LENGTH);
    buffer_size = MSG_LENGTH;
    if (timeout > 0)
      {
        available = sizeof(request) - (size_t) (current - request);
        written = snprintf(current, available, "%d", timeout);
        if ((written < 0) || ((size_t) written >= available))
          {
            taktuk_error = TAKTUK_EINTRN;
            return -1;
          }
        buffer_size += written;
      }
    put_uint32(request, buffer_size);
    buffer_size += sizeof(uint32_t);

    mutex = get_taktuk_mutex();
    pthread_mutex_lock(mutex);
    result = insistent_write(taktuk_write_fd, request, buffer_size);
    pthread_mutex_unlock(mutex);
    if (result > 0)
        result =
            insistent_read(taktuk_read_fd, buffer, sizeof(uint32_t)+MSG_LENGTH);
    if (result <= 0)
        return result;

    current = buffer;
    current = get_uint32(current, &value);
    get_bytes(current, msg_code, MSG_LENGTH);
    /* The announced length includes the message code already read: a
     * smaller value would make the remaining length wrap around
     */
    if (value < MSG_LENGTH)
      {
        taktuk_error = TAKTUK_EINTRN;
        return -1;
      }
    *length = value - MSG_LENGTH;
    debug("Read message %c of length %ld\n", *msg_code, (long) *length);
    if (*msg_code == TAKTUK_MESSAGE)
      {
        result = insistent_read(taktuk_read_fd, buffer, sizeof(uint32_t));
        if (result > 0)
          {
            get_uint32(buffer, &from_size);
            /* The sender name and its size are part of the announced
             * length: reject answers in which they do not fit
             */
            if ((*length < sizeof(uint32_t)) ||
                (*length - sizeof(uint32_t) < from_size))
              {
                taktuk_error = TAKTUK_EINTRN;
                return -1;
              }
            *length -= sizeof(uint32_t);
            if (from != NULL)
              {
                result = insistent_read(taktuk_read_fd, from, from_size);
                if (result > -1)
                    from[from_size] = '\0';
              }
            else
              {
                /* Nobody wants the sender name: read it piece by piece into
                 * the scratch buffer rather than allocating a size chosen by
                 * the peer
                 */
                left = from_size;
                result = 0;
                while ((left > 0) && (result > -1))
                  {
                    chunk = (left < sizeof(buffer)) ? left : sizeof(buffer);
                    result = insistent_read(taktuk_read_fd, buffer, chunk);
                    left -= (uint32_t) chunk;
                  }
              }
            *length -= from_size;
          }
      }
    else if (*msg_code == TAKTUK_TIMEOUT)
      {
        /* in this case, the remaining length should be zero
         */
        if (*length != 0)
            taktuk_error = TAKTUK_EINTRN;
        else
            taktuk_error = TAKTUK_ETMOUT;
        result = -1;
      }
    return result;
  }

/* Sends the length bytes found at buffer to the destination set dest.
 * The header and the body are written while holding the TakTuk mutex so that
 * they are not interleaved with what other threads write.
 *
 * Returns the value returned by insistent_write for the body (length
 * converted to int) or -1 if the header or the body could not be written, in
 * which case taktuk_error holds the reason.
 */
int taktuk_send(const char *dest, const void *buffer, size_t length)
  {
    pthread_mutex_t *lock = get_taktuk_mutex();
    int fd = get_taktuk_write_fd();
    int status;

    pthread_mutex_lock(lock);
    status = send_header(dest, length);
    debug("Header sent, result: %d\n", status);
    /* The body only makes sense behind a complete header */
    if (status > -1)
        status = insistent_write(fd, buffer, length);
    pthread_mutex_unlock(lock);
    return status;
  }

/* Waits for a message and stores its body into buffer.
 *
 * from    : passed to recv_header; if not NULL it receives the name of the
 *           sender as a '\0' terminated string.
 * buffer  : receives the body of the message.
 * length  : capacity of buffer in bytes.
 * timeout : passed to recv_header, which forwards it to the engine when it is
 *           strictly positive.
 *
 * Returns the size of the body (converted to int) or -1 after setting
 * taktuk_error. When the body does not fit into buffer, it is read and
 * thrown away so that the control channel remains usable, buffer is left
 * untouched and the error is TAKTUK_EIBUFF.
 */
int taktuk_recv(char *from, void *buffer, size_t length, int timeout)
  {
    /* Receives the bytes of a message too large for the user buffer */
    char scratch[256];
    size_t chunk;
    int result;
    char msg_code;
    size_t msg_length;
    int taktuk_read_fd;

    taktuk_read_fd = get_taktuk_read_fd();
    result = recv_header(from, &msg_code, &msg_length, timeout);
    if (result < 0)
        return result;

    if (msg_length <= length)
        return insistent_read(taktuk_read_fd, buffer, msg_length);

    /* We still need to read the data contained in the message.
     * The user buffer cannot be used for that: it might be empty (length is
     * 0), in which case reading into it would never make any progress.
     */
    while ((msg_length > 0) && (result > -1))
      {
        chunk = (msg_length < sizeof(scratch)) ? msg_length : sizeof(scratch);
        result = insistent_read(taktuk_read_fd, scratch, chunk);
        msg_length -= chunk;
      }
    taktuk_error = TAKTUK_EIBUFF;
    return -1;
  }

/* Duplicates the iovcnt first elements of the array iov into newly allocated
 * memory. Only the descriptors are copied, not the memory they point to.
 *
 * Returns the copy, to be released with free by the caller, or NULL if the
 * allocation failed.
 */
static struct iovec *copy_iovec(const struct iovec *iov, int iovcnt)
  {
    struct iovec *duplicate;
    int index;

    duplicate = (struct iovec *) malloc(sizeof(struct iovec)*iovcnt);
    if (duplicate == NULL)
        return NULL;
    for (index = 0; index < iovcnt; index++)
        duplicate[index] = iov[index];
    return duplicate;
  }

/* Removes the count first bytes from the memory described by the *iovcnt
 * elements of the array *iov.
 * Elements that are completely consumed are skipped by moving *iov forward
 * and decreasing *iovcnt; an element that is partially consumed has its base
 * and its length adjusted in place.
 * No check is made against *iovcnt: count must not exceed the number of
 * bytes described by the array.
 */
static void remove_bytes(struct iovec **iov, int *iovcnt, int count)
  {
    struct iovec *head = *iov;
    int elements = *iovcnt;

    while (count > 0)
      {
        if (count < head->iov_len)
          {
            /* The removal ends inside this element */
            head->iov_base = ((char *) head->iov_base) + count;
            head->iov_len -= count;
            break;
          }
        /* This element is entirely consumed */
        count -= head->iov_len;
        head++;
        elements--;
      }
    *iov = head;
    *iovcnt = elements;
  }

/* Sends to the destination set dest a message whose body is the
 * concatenation of the iovcnt memory areas described by iov.
 * The header and the body are written while holding the TakTuk mutex so that
 * they are not interleaved with what other threads write.
 *
 * Returns the number of body bytes written (0 for an empty body, as
 * taktuk_send does) or -1 after setting taktuk_error to:
 *   the error set by send_header if the header could not be written,
 *   TAKTUK_ESWRIT if writev failed,
 *   TAKTUK_EINTRN if writev reported an impossible count,
 *   TAKTUK_EALLOC if the memory needed after a partial write is missing.
 */
int taktuk_sendv(const char *dest, const struct iovec *iov, int iovcnt)
  {
    int i, result, count, total=0;
    size_t length = 0;
    /* Start of the allocated copy, kept untouched to be given to free */
    struct iovec *allocated_iov = NULL;
    /* Position reached in the copy, moved forward by remove_bytes */
    struct iovec *current_iov = NULL;
    /* What remains to be written: iov itself until a write is partial */
    const struct iovec *pending;
    ssize_t written;
    pthread_mutex_t *mutex;
    int taktuk_write_fd;

    debug("Initiating send\n");
    for (i=0; i<iovcnt; i++)
        length += iov[i].iov_len;
    mutex = get_taktuk_mutex();
    taktuk_write_fd = get_taktuk_write_fd();
    pthread_mutex_lock(mutex);
    result = send_header(dest, length);
    /* Nothing more to write for an empty body, which is not an error */
    if ((result > -1) && (length > 0))
      {
        pending = iov;
        count = iovcnt;
        while (length > 0)
          {
            written = writev(taktuk_write_fd, pending, count);
            debug ("Sent %d\n", (int) written);
            /* The sign has to be tested before any comparison with length:
             * converted to an unsigned type, -1 is larger than any length
             */
            if (written < 0)
              {
                taktuk_error = TAKTUK_ESWRIT;
                result = -1;
                break;
              }
            if ((written == 0) || ((size_t) written > length))
              {
                taktuk_error = TAKTUK_EINTRN;
                result = -1;
                break;
              }
            total += (int) written;
            length -= (size_t) written;
            if (length == 0)
                break;
            /* Partial write: iov belongs to the caller, so the remaining
             * part is tracked in a private copy made only when needed
             */
            if (allocated_iov == NULL)
              {
                allocated_iov = copy_iovec(iov, iovcnt);
                if (allocated_iov == NULL)
                  {
                    taktuk_error = TAKTUK_EALLOC;
                    result = -1;
                    break;
                  }
                current_iov = allocated_iov;
              }
            remove_bytes(&current_iov, &count, (int) written);
            pending = current_iov;
          }
        free(allocated_iov);
      }
    pthread_mutex_unlock(mutex);
    debug("Send complete\n");
    return (result > -1)?total:result;
  }

int taktuk_recvv(char *from, const struct iovec *iov, int iovcnt, int timeout)
  {
    int result, count, total=0;
    char msg_code;
    size_t msg_length;
    struct iovec *allocated_iov;
    struct iovec *current_iov;
    int taktuk_read_fd;

    taktuk_read_fd = get_taktuk_read_fd();
    result = recv_header(from, &msg_code, &msg_length, timeout);
    if (result > -1)
      {
        result = readv(taktuk_read_fd, iov, iovcnt);
        total = result;
        if (result < 0)
            taktuk_error = TAKTUK_ESREAD;
        if (msg_length && !result)
          {
            taktuk_error = TAKTUK_EFCLSD;
            result = -1;
          }
        /* In the following case, the control channel is left unclean...
         * Not necessarily to be fixed as the protocol is probably broken
         */
        if (result > msg_length)
          {
            taktuk_error = TAKTUK_EIBUFF;
            result = -1;
          }
        if (result < msg_length)
          {
            allocated_iov = copy_iovec(iov, iovcnt);
            count = iovcnt;
            if (allocated_iov != NULL)
              {
                current_iov = allocated_iov;
                while ((result > 0) && (result < msg_length))
                  {
                    remove_bytes(&current_iov, &count, result);
                    msg_length -= result;
                    result = readv(taktuk_read_fd, current_iov, count);
                    total += result;
                    if (result > msg_length)
                      {
                        taktuk_error = TAKTUK_EIBUFF;
                        result = -1;
                      }
                    else if (result == 0)
                      {
                        taktuk_error = TAKTUK_EFCLSD;
                        result = -1;
                      }
                    else if (result < 0)
                        taktuk_error = TAKTUK_ESREAD;
                  }
                free(allocated_iov);
              }
            else
              {
                taktuk_error = TAKTUK_EALLOC;
                result = -1;
              }
          }
      }
    return (result < 0)?result:total;
  }
