/* 
 * Copyright (C) 1999-2001 Peter T. Breuer <ptb@it.uc3m.es>
 */


#include <stdlib.h>
#include <string.h>
#include <sys/time.h>
#include <sys/types.h>
#include <unistd.h>
#include <setjmp.h>
#include <sys/errno.h>
#include <sys/socket.h>
#include <netdb.h>
#include <syslog.h>
#include <stdio.h>
#include <stdio.h>
#include <netinet/in.h>

#ifdef USING_SSL
#include "sslincl.h"
#endif

#include "alarm.h"
#include "stream.h"
#include "cliserv.h"
#include "socket.h"
#include "select.h"

#ifndef MY_NAME
#define MY_NAME "stream"
#endif

extern int debug_level;

#ifdef USING_SSL

/*
 * Classify the result of an SSL_read()/SSL_write() call on `con`.
 *
 * Returns -1 when SSL_get_error() reports SSL_ERROR_SSL, SSL_ERROR_SYSCALL,
 * SSL_ERROR_ZERO_RETURN, SSL_ERROR_WANT_X509_LOOKUP or SSL_ERROR_WANT_WRITE;
 * returns 0 for every other code (e.g. SSL_ERROR_NONE, SSL_ERROR_WANT_READ),
 * leaving the caller to interpret `res` itself.
 */
inline static int
ssl_check(SSL * con, int res) {
      const int reason = SSL_get_error(con, res);

      switch (reason) {
      case SSL_ERROR_SSL:
      case SSL_ERROR_SYSCALL:
      case SSL_ERROR_ZERO_RETURN:
      case SSL_ERROR_WANT_X509_LOOKUP:
      case SSL_ERROR_WANT_WRITE:
            return -1;
      default:
            return 0;
      }
}
#endif

/*
 * Allocate a zero-filled scratch buffer of `size` bytes.
 *
 * For size <= 0 a pointer to a private static byte is returned instead of
 * heap memory; callers must never free() that pointer.  If calloc() fails
 * it is retried up to ten more times, sleeping one second before each
 * retry, and NULL is returned if every attempt fails.
 */
static char *
mkbuf(int size) {

      /* ISO C forbids zero-length arrays; one byte is enough for a
       * non-NULL sentinel that is never dereferenced. */
      static char trivbuf[1];
      int attempt;

      if (size <= 0)
          return trivbuf;

      /* A plain local counter: the old function-static one was reset on
       * every exit anyway, so it only made the function non-reentrant. */
      for (attempt = 0; ; attempt++) {
          char *mbuf = calloc(1, size);

          if (mbuf)
              return mbuf;
          if (attempt > 9)
              return NULL;
          // PTB sleep 1s
          microsleep(1000000);
      }
}


/*
 *  Send or receive a packet of `size` bytes on socket `sock`.
 *  If buf is NULL a private static scratch buffer is used instead, so
 *  received data is dropped.
 *
 *  Before each send/recv the socket is waited on with microselect() for up
 *  to `timeout` seconds, and the call itself is bracketed by
 *  catch_alarm()/uncatch_alarm().  With ENBD_STREAM_NONBLOCK in flags only
 *  one fragment is attempted.  A negative send/recv result is retried after
 *  a 1s sleep, giving up on the fifth consecutive failure.
 *
 *  Returns the number of bytes transferred; 0 if size is 0 or send/recv
 *  returned 0 (other side closed); -ENOMEM if no scratch buffer could be
 *  made; -ETIMEDOUT if the socket did not become ready or catch_alarm()
 *  reported the alarm; otherwise the last negative send/recv result.
 *  Partial progress is not reported when the transfer ends badly.
 *
 * Not reentrant because of timeout handling with longjmp and the static
 * scratch buffer and error counter.
 */
static int
sock_xmit (int op, int sock, char *buf, int size, int timeout, unsigned flags)
{
    int result = 0;
    int off = 0;
    int len = size;
    int tot = 0;
    int count = 0;
    static short errorcount;
    static char *mbuf;
    static int msize;
    char *tbuf;

    // PTB cope with getting a null buffer for dropping requests. Make one
    if (!buf) {
        if (mbuf && msize < size) {
           free(mbuf);
           msize = 0;
           mbuf  = NULL;
        }
        // PTB the size > 0 check stops us dying with NOMEM for no reason
        if (!mbuf && size > 0) {
            mbuf = mkbuf(size);
            if (!mbuf)
                    return -ENOMEM;
            msize = size;
        }
        tbuf = mbuf;
    } else {
        tbuf = buf;
    }

    while (len > 0) {

        // PTB only one try if we are not blocking
        if (count++ > 0 && (flags & ENBD_STREAM_NONBLOCK)) {
            goto exit;
        }

        switch (op) {
          case  WRITE:
            {
               fd_set fds, xfds;
               fd_set * wfds = & fds, * efds = & xfds;

               // presumably static so they survive the alarm longjmp
               static struct alarm my_jmp;
               static char * addr;
               addr = tbuf + off;

               FD_ZERO(efds); FD_SET(sock,efds);
               FD_ZERO(wfds); FD_SET(sock,wfds);

               // long arithmetic: 1000000*timeout overflows int past ~2147s
               result =
                   microselect(sock+1,NULL,wfds,efds,1000000L*timeout);
               if (result <= 0)  {
                   result = -ETIMEDOUT;
                   goto exit;
               }
               if (catch_alarm(&my_jmp, timeout)) {
                 // leave via exit (not return) so errorcount is reset
                 result = -ETIMEDOUT;
                 goto exit;
               } else {
                 const unsigned sendflags
#if defined(MSG_DONTWAIT) && defined(MSG_NOSIGNAL)
                    = MSG_DONTWAIT|MSG_NOSIGNAL;
#else
                    = 0;
#endif
                 result = send (sock, addr, len, sendflags);
                 uncatch_alarm(&my_jmp);
               }
               // PTB deal with error below
            }
            break;

          default:
          case READ:
            {
               fd_set fds, xfds;
               fd_set * rfds = & fds, * efds = & xfds;

               // presumably static so they survive the alarm longjmp
               static struct alarm my_jmp;
               static char * addr;
               addr = tbuf + off;

               FD_ZERO(efds); FD_SET(sock,efds);
               FD_ZERO(rfds); FD_SET(sock,rfds);

               // long arithmetic: 1000000*timeout overflows int past ~2147s
               result =
                   microselect(sock+1,rfds,NULL,efds,1000000L*timeout);
               if (result <= 0)  {
                   result = -ETIMEDOUT;
                   goto exit;
               }
               if (catch_alarm(&my_jmp, timeout)) {
                 // leave via exit (not return) so errorcount is reset
                 result = -ETIMEDOUT;
                 goto exit;
               } else {
                 result = recv (sock, addr, len, 0);
                 uncatch_alarm(&my_jmp);
               }
               // PTB deal with error below
            }
            break;
        } // endsw

        // PTB handle success from net
        if (result > 0) {
          len  -= result;
          off  += result;
          tot  += result;
          errorcount = 0;
          continue;
        }

        // PTB other side closed
        if (result == 0)
          goto exit;

        // PTB unknown error: give up after repeated consecutive failures
        if (errorcount++ > 3)
          goto exit;

        // PTB sleep 1s and go round the loop again
        microsleep(1000000);
    }

  exit:
    errorcount = 0;
    if (result > 0)
        return tot;
    return result;
}



#ifdef USING_SSL
/*
 *  Send or receive a packet of `size` bytes over the SSL connection `con`.
 *  If buf is NULL a private static scratch buffer is used instead, so
 *  received data is dropped.
 *
 *  With ENBD_STREAM_NONBLOCK in flags only one SSL_read/SSL_write is
 *  attempted.  A result that ssl_check() rejects aborts with -EINVAL; other
 *  negative results are retried after a 1s sleep, giving up on the fifth
 *  consecutive failure.  `timeout` is currently unused.
 *
 *  Returns the number of bytes transferred; 0 if size is 0 or an SSL call
 *  returned 0 that ssl_check() accepted; -ENOMEM if no scratch buffer could
 *  be made; -EINVAL on an ssl_check() failure; otherwise the last negative
 *  SSL result.  Partial progress is not reported when the transfer ends
 *  badly.
 *
 * Not reentrant: the scratch buffer and error counter are static.
 */
static int
SSL_xmit (int op, SSL* con, char *buf, int size, int timeout, unsigned flags)
{
    int result = 0;
    int off = 0;
    int len = size;
    int tot = 0;
    int count = 0;
    static short errorcount;
    static char *mbuf;
    static int msize;          // int, not short: larger sizes were truncated
    char *tbuf;

    (void) timeout;            // no timeout handling on the SSL path yet

    // PTB cope with getting a null buffer for dropping requests. Make one
    if (!buf) {
        if (mbuf && msize < size) {
           free(mbuf);
           msize = 0;
           mbuf  = NULL;       // was "tbuf = NULL", leaving mbuf dangling
        }
        // size > 0: mkbuf(0) hands back static storage we must never free
        if (!mbuf && size > 0) {
            mbuf = mkbuf(size);
            if (!mbuf)
                    return -ENOMEM;
            msize = size;
        }
        tbuf = mbuf;
    } else {
        tbuf = buf;
    }

    while (len > 0) {

        // PTB only one try if we are not blocking
        if (count++ > 0 && (flags & ENBD_STREAM_NONBLOCK)) {
            goto exit;
        }

        switch (op) {
          case WRITE:
              result = SSL_write(con, tbuf + off, len);
              break;
          default:
          case READ:
              result = SSL_read(con, tbuf + off, len);
              break;
        } // endsw

        if (ssl_check(con, result) < 0) {
            result = -EINVAL;
            goto exit;
        }

        // PTB handle success from net
        if (result > 0) {
          len  -= result;
          off  += result;
          tot  += result;
          errorcount = 0;
          continue;
        }

        // PTB other side closed
        if (result == 0)
          goto exit;

        // PTB unknown error: give up after repeated consecutive failures
        if (errorcount++ > 3)
          goto exit;

        // PTB sleep 1s and go round the loop again
        microsleep(1000000);
    }

  exit:
    errorcount = 0;
    if (result > 0)
        return tot;
    return result;
}
#endif

#ifdef USING_SSL
static void 
closessl(struct nbd_stream * self) {
   if (self->con) SSL_free(self->con);
   self->con = NULL;
}
#endif

/*
 * Close the plain socket of a stream and mark the stream closed.
 *
 * Returns -1 if self has bad magic.  Otherwise returns the close(2) result,
 * or 0 if no socket was open.  The descriptor is forgotten and
 * ENBD_STREAM_OPEN is cleared even when close(2) fails.  Per the Linux
 * close(2) manual the descriptor should not be reused after a failed
 * close, and keeping it could later close an unrelated, recycled fd.
 */
static int 
disconnect_sock (struct nbd_stream * self)
{
    int err = 0;

    if (self->magic != ENBD_STREAM_MAGIC) {
       PERR("stream %p has bad magic\n", self);
       return -1;
    }

    if (self->sock >= 0)
      err = close (self->sock);

    self->sock = -1;
    self->flags &= ~ENBD_STREAM_OPEN;

    return err;
}

#ifdef USING_SSL
/*
 * Shut down the SSL session, close the socket and release the SSL handle.
 *
 * Returns -1 if self has bad magic.  Otherwise it returns the close(2)
 * result if a socket was open, else the SSL_shutdown() result if an SSL
 * handle existed, else 0.  In every non-bad-magic case the stream is left
 * closed: sock is -1, ENBD_STREAM_OPEN is cleared and con is freed.  This
 * holds even if close(2) fails, because the descriptor must not be reused
 * after a failed close (Linux close(2) manual).
 */
static int 
disconnect_ssl (struct nbd_stream * self)
{
    int err = 0;

    if (self->magic != ENBD_STREAM_MAGIC) {
       // %p: "%x" with (unsigned)self truncated the pointer on 64-bit hosts
       PERR("stream %p has bad magic\n", self);
       return -1;
    }

    if (self->con)
       err = SSL_shutdown(self->con);

    if (self->sock >= 0)
       err = close (self->sock);

    self->sock = -1;
    self->flags &= ~ENBD_STREAM_OPEN;

    if (self->con)
       SSL_free(self->con);
    self->con = NULL;
    return err;
}

/*
 * Drain hook for SSL streams: does nothing and always reports success.
 */
static int
clear_ssl(struct nbd_stream *self) {
    (void) self;    // PTB too much bother.
    return 0;
}
#endif

/*
 * Discard input already pending on the socket, for reopen.
 *
 * Waits up to 10ms for the socket to become readable (or flag an exception)
 * and, if so, reads and discards a single byte.
 * NOTE(review): only one byte is consumed per call, although the intent
 * ("eat incoming chars") suggests draining everything pending - confirm.
 *
 * Returns -ENOSTR for bad magic, -8 for a closed socket, otherwise 0
 * (timeouts, EOF and read errors are ignored).
 */
static int
clear_sock(struct nbd_stream *self) {
    const long utimeout = 10000;
    fd_set readable, failed;
    char junk;

    if (self->magic != ENBD_STREAM_MAGIC) {
       PERR("stream %p has bad magic\n", self);
       return  -ENOSTR;
    }
    if (self->sock < 0) {
       PERR("stream %p has closed socket\n", self);
       return  -8;
    }

    // mop up any pending stuff on the socket
    FD_ZERO(&failed);   FD_SET(self->sock, &failed);
    FD_ZERO(&readable); FD_SET(self->sock, &readable);

    if (microselect(self->sock + 1, &readable, NULL, &failed, utimeout) > 0)
        (void) read(self->sock, &junk, 1);

    return 0;
}

/*
 * Make a TCP connection to the given hostname and port.
 *
 * On success the socket is stored in self->sock, the endpoint is recorded in
 * self->port and self->hostname, and ENBD_STREAM_OPEN is set.  The hostname
 * pointer is kept by reference, not copied.  The socket descriptor is
 * returned.
 *
 * Errors:
 *   -ENOSTR        bad magic
 *   -EINVAL        port outside 0..65535 or hostname NULL
 *   -EDESTADDRREQ  lookup failed or gave no IPv4 address
 *   -ENOTSOCK      socket() failed
 *   -ENOTCONN      connect() failed; the stream is closed via self->close
 */
static int 
connect_sock(struct nbd_stream * self, int port, char *hostname)
{
    int sock;                 /* socket descriptor */
    struct sockaddr_in xaddrin;
    struct hostent *hostn;

    if (self->magic != ENBD_STREAM_MAGIC) {
       PERR("stream %p has bad magic\n", self);
       return  -ENOSTR;
    }

    // ports are 16 bits; htons() would silently wrap anything larger
    if (port < 0 || port > 65535 || !hostname) {
       return -EINVAL;
    }

    hostn = gethostbyname (hostname);
    if (!hostn) {
        PERR ("client gethostname %s failed %m\n", hostname);
        return -EDESTADDRREQ;
    }
    // we build an IPv4 sockaddr, so insist on an IPv4 address of that size
    if (hostn->h_addrtype != AF_INET
        || hostn->h_length != (int) sizeof xaddrin.sin_addr
        || !hostn->h_addr_list || !hostn->h_addr_list[0]) {
        PERR ("client host %s has no IPv4 address\n", hostname);
        return -EDESTADDRREQ;
    }

    // Build the address before anything else can overwrite the static
    // hostent.  memset clears sin_zero; memcpy replaces the old misaligned,
    // type-punned *(int *) load.
    memset (&xaddrin, 0, sizeof xaddrin);
    xaddrin.sin_family = AF_INET;
    xaddrin.sin_port = htons (port);
    memcpy (&xaddrin.sin_addr, hostn->h_addr_list[0], sizeof xaddrin.sin_addr);

    if ((sock = socket (AF_INET, SOCK_STREAM, IPPROTO_TCP)) < 0) {
        PERR ("client open socket failed %m\n");
        return -ENOTSOCK;
    }

    self->sock = sock;
    DEBUG ("client opened socket %d\n", sock);

    if (connect (sock, (struct sockaddr *) &xaddrin, sizeof xaddrin) < 0) {
        PERR ("client socket connect failed %m on port %d\n", port);
        self->close(self);
        return -ENOTCONN;
    }

    self->port = port;
    self->hostname = hostname;

    // %p: "%x" with (unsigned)self truncated the pointer on 64-bit hosts
    DEBUG ("client stream %p connected on socket %d to port %d\n",
       self, sock, port);

    setmysockopt (sock);

    self->flags |= ENBD_STREAM_OPEN;

    return sock;
}

#ifdef USING_SSL
/*
 * Make an SSL connection to the given hostname and port.
 *
 * Opens the TCP connection with connect_sock(), then attaches an SSL handle
 * created from *self->ctx to the socket and performs the client handshake.
 * An existing handle is freed first, unless TRY_SSL_RECONNECT is defined, in
 * which case it is reused.
 *
 * Returns the socket descriptor from connect_sock() on success.  On failure
 * it returns connect_sock()'s error, or one of:
 *   -ENOSTR    bad magic
 *   -EINVAL    bad port or hostname
 *   -ENOMEM    SSL_new failed
 *   -ENOTSOCK  SSL_set_fd failed
 *   -ENOTCONN  handshake failed
 * After an SSL-level failure the stream is closed via self->close.
 */
static int 
connect_ssl(struct nbd_stream * self, int port, char *hostname)
{

    int err = 0;

    if (self->magic != ENBD_STREAM_MAGIC) {
       // %p: "%x" with (unsigned)self truncated the pointer on 64-bit hosts
       PERR("stream %p has bad magic\n", self);
       return  -ENOSTR;
    }

    if (port < 0 || !hostname) {
       return -EINVAL;
    }

    err = connect_sock(self,port,hostname);
    if (err < 0)
      return err;

    // braced: DEBUG/closessl expansions must stay under the if
    if (self->con) {
#ifdef TRY_SSL_RECONNECT
      DEBUG("Reusing existing SSL connection...");  
#else
      closessl(self);
#endif
    }
    if (!self->con)
      self->con = SSL_new(* self->ctx);
    if (!self->con) {
      PERR("SSL_new failed\n");
      self->close(self);
      return -ENOMEM;
    }

    if (!SSL_set_fd(self->con,self->sock)) {
      PERR("SSL connection failed\n");  
      self->close(self);
      return -ENOTSOCK;
    }

    // SSL_connect() returns 1 only on success; 0 (controlled handshake
    // failure) is a failure too, not just negative results
    if (SSL_connect(self->con) != 1) {
      PERR("SSL connection failed\n");  
      self->close(self);
      return -ENOTCONN;
    }

    self->flags |= ENBD_STREAM_OPEN;
    DEBUG("SSL connection established\n");
    return err;
}
#endif


/*
 * Return the first sizeof(unsigned) bytes of buf as an unsigned, for debug
 * output only.  Returns 0 if buf is NULL or fewer than that many bytes are
 * valid.  The old *(unsigned *)buf read past short buffers; memcpy also
 * avoids a misaligned, type-punned load.
 */
static unsigned
debug_word (const char *buf, int valid)
{
  unsigned word = 0;

  if (buf && valid >= (int) sizeof word)
      memcpy (&word, buf, sizeof word);
  return word;
}

/*
 * Read size bytes from a plain socket stream into buf (NULL drops them).
 * Returns -7 for bad magic, -8 for a closed socket, otherwise the
 * sock_xmit() result.
 */
static int
read_sock (struct nbd_stream *self, char *buf, int size) {

  int res;

  if (self->magic != ENBD_STREAM_MAGIC) {
       PERR("stream %p has bad magic\n", self);
       return  -7;
  }
  if (self->sock < 0) {
       PERR("stream %p has closed socket\n", self);
       return  -8;
  }
  res = sock_xmit (READ, self->sock, buf, size, self->timeout, self->flags);
  // only the first res bytes of buf hold freshly read data
  DEBUG("stream %p reads %dB (%#x..) to buf %p\n",
          self, size, debug_word(buf, res), buf);
  return res;
}

/*
 * Write size bytes from buf to a plain socket stream.
 * Returns -7 for bad magic, -8 for a closed socket, otherwise the
 * sock_xmit() result.
 */
static int
write_sock (struct nbd_stream * self, char *buf, int size) {
  if (self->magic != ENBD_STREAM_MAGIC) {
       PERR("stream %p has bad magic\n", self);
       return  -7;
  }
  if (self->sock < 0) {
       PERR("stream %p has closed socket\n", self);
       return  -8;
  }
  DEBUG("stream %p writes %dB (%#x..) from buf %p\n",
          self, size, debug_word(buf, size), buf);
  return sock_xmit (WRITE, self->sock, buf, size, self->timeout, self->flags);
}

#ifdef USING_SSL
/*
 * Read size bytes from an SSL stream into buf (NULL drops them).
 * Returns -7 for bad magic, -9 with no SSL connection, 0 for size 0,
 * -EINVAL for a negative size, otherwise the SSL_xmit() result.
 */
static int
read_ssl (struct nbd_stream *self, char *buf, int size) {
  // %p: "%x" with (unsigned)self truncated the pointer on 64-bit hosts
  if (self->magic != ENBD_STREAM_MAGIC) {
      PERR("stream %p has bad magic\n", self);
      return  -7;
  }
  if (!self->con) {
      PERR("stream %p has closed ssl connection\n", self);
      return  -9;
  }
  if (size == 0)  // PTB fight new openssl bug
      return 0;
  if (size < 0)
      return -EINVAL;
  return SSL_xmit (READ, self->con, buf, size, self->timeout, self->flags);
}
/*
 * Write size bytes from buf to an SSL stream.
 * Returns -7 for bad magic, -9 with no SSL connection, 0 for size 0,
 * -EINVAL for a negative size, otherwise the SSL_xmit() result.
 */
static int
write_ssl (struct nbd_stream *self, char *buf, int size) {
  // %p: "%x" with (unsigned)self truncated the pointer on 64-bit hosts
  if (self->magic != ENBD_STREAM_MAGIC) {
      PERR("stream %p has bad magic\n", self);
      return  -7;
  }
  if (!self->con) {
      PERR("stream %p has closed ssl connection\n", self);
      return  -9;
  }
  if (size == 0)  // PTB fight new openssl bug
      return 0;
  if (size < 0)
      return -EINVAL;
  return SSL_xmit (WRITE, self->con, buf, size, self->timeout, self->flags);
}
#endif

/*
 * Tear the stream down and reconnect it to the port/hostname it last used,
 * then run its clear hook.  Returns the negative error from self->open on
 * failure, else 0; the results of close and clear are ignored.
 */
static int
reopen(struct nbd_stream *self) {
  /* Remember the endpoint before the connection is torn down. */
  char * const host = self->hostname;
  const int where = self->port;
  int rc;

  self->close(self);
  rc = self->open(self, where, host);
  if (rc >= 0) {
      self->clear(self);
      rc = 0;
  }
  return rc;
}

/*
 * Initialize a stream object.
 *
 * Stores the timeout, marks the stream as unconnected (sock and port -1, no
 * hostname, no errors), and installs the method table.  The SSL variants
 * are used when built with USING_SSL and ctx is non-NULL; the plain socket
 * variants otherwise.  The ctx pointer is stored, not copied.  Finally the
 * magic and ENBD_STREAM_INITIALIZED are set.  Always returns 0.
 */
int
#ifdef USING_SSL
initstream(struct nbd_stream *self, int timeout, SSL_CTX **ctx) {
#else
initstream(struct nbd_stream *self, int timeout, void *dummy) {
#endif

#ifndef USING_SSL
    (void) dummy;   // only the SSL build uses the third argument
#endif

    self->timeout = timeout;
    self->sock = -1;
    self->port = -1;
    self->hostname = NULL;
    self->errs = 0;
    self->flags = 0;

#ifdef USING_SSL
    self->con = NULL;
    self->ctx = ctx;
    if (ctx) {
      self->close = disconnect_ssl;
      self->read  = read_ssl;
      self->write = write_ssl;
      self->open  = connect_ssl;
      self->clear = clear_ssl;
    } else
#endif
    {
      self->close = disconnect_sock;
      self->read  = read_sock;
      self->write = write_sock;
      self->open  = connect_sock;
      self->clear = clear_sock;
    }
    self->reopen = reopen;
    self->magic  = ENBD_STREAM_MAGIC;
    self->flags |= ENBD_STREAM_INITIALIZED;
    // %p: "%x" with (unsigned)self truncated the pointer on 64-bit hosts
    DEBUG("initialized stream %p\n", self);

    return 0;
}

