/***************************************************************************
 *
 * Copyright (c) 2000, 2001, 2002, 2003, 2004 BalaBit IT Ltd, Budapest, Hungary
 *
 * This program is free software; you can redistribute it and/or modify it
 * under the terms of the GNU General Public License version 2 as published
 * by the Free Software Foundation.
 *
 * Note that this permission is granted for only version 2 of the GPL.
 *
 * As an additional exemption you are allowed to compile & link against the
 * OpenSSL libraries as published by the OpenSSL project. See the file
 * COPYING for details.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program; if not, write to the Free Software
 * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
 *
 * $Id: packstream.c,v 1.23 2004/07/22 08:09:40 sasa Exp $
 *
 * Author  : yeti
 * Auditor : bazsi
 * Last audited version:
 * Notes:
 *
 ***************************************************************************/

#include <zorp/zorp.h>
#include <zorp/packet.h>
#include <zorp/stream.h>
#include <zorp/log.h>

#include <stdio.h>
#include <string.h>
#include <errno.h>
#include <sys/types.h>
#include <sys/socket.h>
#include <sys/poll.h>
#include <assert.h>

/* Private state of one end of a packet stream pair. */
typedef struct _ZStreamPacket
{
  ZStream super;

  /* packet partially consumed by read_method(); NULL when none is pending */
  ZPacket *partial_packet;
  /* number of bytes of partial_packet already handed out to readers */
  guint partial_read;
  
  /* incoming packets; the peer pushes into it via z_stream_packet_feed() */
  GAsyncQueue *queue;
  /* direction state, cleared (never set again) by the shutdown method */
  gboolean can_read, can_write;
  /* the peer end; a reference is held and dropped in close/free */
  struct _ZStreamPacket *other;
} ZStreamPacket;

/* class descriptor, defined further below in this file */
extern ZClass ZStreamPacket__class;

/* forward declaration: z_stream_packet_pair_new() calls it before its definition */
ZStream *
z_stream_packet_new(gchar *name);

/**
 * z_stream_packet_readable:
 * @self: #ZStreamPacket instance
 *
 * Tell whether a read on @self can make progress without waiting. This is
 * the case when the read side has been shut down (a fetch then reports
 * EOF), when a partially consumed packet is pending, or when at least one
 * packet sits in the input queue.
 *
 * Returns: TRUE if @self should be reported as readable.
 */

static inline gboolean
z_stream_packet_readable(ZStreamPacket *self)
{
  if (!self->can_read)
    return TRUE;

  if (self->partial_packet != NULL)
    return TRUE;

  return g_async_queue_length(self->queue) > 0;
}


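/*
 * The following three functions are the prepare, check and dispatch
 * callbacks used for the stream's event source. A packet stream is always
 * writable. It is readable when z_stream_packet_readable() says so.
 */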

static gboolean
z_stream_packet_watch_prepare(ZStream *s,
                              GSource *src G_GNUC_UNUSED,
                              gint *timeout)
{
  ZStreamPacket *self = (ZStreamPacket *) s;
  gboolean ready;

  /* writing never blocks, so a pending write interest is always ready */
  ready = self->super.want_write ||
          (self->super.want_read && z_stream_packet_readable(self));

  if (!ready)
    {
      /* nothing to do until someone calls g_main_context_wakeup() */
      *timeout = -1;
    }
  return ready;
}

static gboolean
z_stream_packet_watch_check(ZStream *s, GSource *src G_GNUC_UNUSED)
{
  ZStreamPacket *self = (ZStreamPacket *) s;
  gboolean ready;

  z_enter();
  /* same readiness rule as the prepare callback */
  ready = self->super.want_write ||
          (self->super.want_read && z_stream_packet_readable(self));
  z_leave();
  return ready;
}

static gboolean
z_stream_packet_watch_dispatch(ZStream *s,
                               GSource *src G_GNUC_UNUSED)
{
  ZStreamPacket *self = (ZStreamPacket *) s;
  GIOCondition cond = 0;
  gboolean keep = TRUE;

  z_enter();

  /* compute the condition once, before any callback runs */
  if (self->super.want_write)
    cond |= G_IO_OUT;
  if (self->super.want_read && z_stream_packet_readable(self))
    cond |= G_IO_IN;

  /* want_read/want_write are re-read here: a callback may change them */
  if ((cond & G_IO_IN) && self->super.want_read)
    keep &= self->super.read_cb(s, cond, self->super.user_data_read);

  if ((cond & G_IO_OUT) && self->super.want_write)
    keep &= self->super.write_cb(s, cond, self->super.user_data_write);

  z_leave();
  return keep;
}

/* Create the stream's source on first use and attach it to @context;
 * a stream that already has a source is left untouched. */
static void
z_stream_packet_attach_source_method(ZStream *stream, GMainContext *context)
{
  z_enter();
  if (stream->source == NULL)
    {
      GSource *src = z_stream_source_new(stream);

      stream->source = src;
      g_source_attach(src, context);
    }
  z_leave();
}

/* Destroy and release the stream's source, if any. The pointer in the
 * stream is cleared before the source is destroyed. */
static void
z_stream_packet_detach_source_method(ZStream *stream)
{
  GSource *old_source;

  z_enter();
  old_source = stream->source;
  if (old_source != NULL)
    {
      stream->source = NULL;
      g_source_destroy(old_source);
      g_source_unref(old_source);
    }
  z_leave();
}


/**
 * z_stream_packet_feed:
 * @self: the RECEIVING #ZStreamPacket instance (the writer's peer)
 * @pack: packet to be queued; on success it is owned by @self's queue
 * @error: error details are stored here
 *
 * Queue @pack for reading on @self. If the queue was empty and @self is
 * interested in reading through an attached source, that source's main
 * context is woken up.
 *
 * Returns: A #GIOStatus. Can only fail if @self has been shut down for
 * reading. Shutdown is propagated between the two ends, so this is also
 * the case once the writer has been shut down for writing.
 */
 
static GIOStatus
z_stream_packet_feed(ZStreamPacket *self, ZPacket *pack, GError **error)
{
  gboolean wasempty;
  
  z_enter();
  
  /* @self is the receiver. The write must fail when the receiver can no
   * longer read, which by shutdown propagation is equivalent to the
   * writer's write side being shut. Checking the receiver's can_write
   * instead tested an unrelated direction: writes failed after the
   * receiver only half-closed its own sending side, and went through
   * after the writer shut down writing. */
  if (!self->can_read)
    {
      /*LOG
	This message indicates an internal error, please contact your Zorp support for assistance.
       */
      z_log(self->super.name, CORE_ERROR, 3, "Write attempted on ZStreamPacket and write side had been shut down;");
      g_set_error(error, 
                  G_IO_CHANNEL_ERROR, G_IO_CHANNEL_ERROR_PIPE, 
                  "Write attempted and channel write side had been shut down");
      errno = EPIPE;
      z_leave();
      return G_IO_STATUS_ERROR;
    }
  
  wasempty = (g_async_queue_length(self->queue) == 0);
  g_async_queue_push(self->queue, pack);
  
  /* only the empty -> non-empty transition changes readability */
  if (wasempty &&
      self->super.want_read &&
      self->super.source != NULL)
    {
      g_main_context_wakeup(g_source_get_context(self->super.source));
    }

  z_leave();
  return G_IO_STATUS_NORMAL;
}

/**
 * z_stream_packet_send:
 * @s: the sending #ZStreamPacket instance
 * @pack: the packet to send as a whole
 * @error: error details are stored here
 *
 * Queue @pack on the peer of @s so that it can be read back there with
 * z_stream_packet_recv(), or byte-wise with the read method. On success
 * the packet belongs to the peer's queue. On failure this function does
 * not free @pack, so presumably the caller still owns it.
 *
 * Returns: %G_IO_STATUS_NORMAL on success, %G_IO_STATUS_ERROR (errno set
 * to EPIPE) if @s has been closed or shut down for writing.
 */
GIOStatus
z_stream_packet_send(ZStream *s, ZPacket *pack, GError **error)
{
  ZStreamPacket *self = Z_CAST(s, ZStreamPacket);
  GIOStatus rc;
  gsize length;
 
  /*LOG
    This message reports that data is sent out.
   */
  z_log(self->super.name, CORE_DEBUG, 7, "Sending to channel; count='%d'", pack->length);
  if (z_log_enabled(CORE_DEBUG, 9))
    z_data_dump(self->super.name, pack->data, pack->length);

  if (self->other == NULL || !self->can_write)
    {
      /* close() clears self->other, so it must not be dereferenced; also
       * refuse writes after a write shutdown, as the write method does */
      g_set_error(error,
                  G_IO_CHANNEL_ERROR, G_IO_CHANNEL_ERROR_PIPE,
                  "Write attempted and channel write side had been shut down");
      errno = EPIPE;
      rc = G_IO_STATUS_ERROR;
    }
  else
    {
      /* read the length before queueing: once queued, the packet may be
       * consumed (and freed) by the reader of the peer */
      length = pack->length;
      rc = z_stream_packet_feed(self->other, pack, error);
      if (rc == G_IO_STATUS_NORMAL)
        self->super.bytes_sent += length;
    }

  if (rc != G_IO_STATUS_NORMAL)
    {
      /*LOG
	This message indicates that an error occurred during sending data on the channel.
       */
      z_log(self->super.name, CORE_DEBUG, 7, "Sending to channel failed; error='%s'", (error != NULL && *error != NULL) ? ((*error)->message) : ("Unknown error"));
    }
  return rc;  
}

/**
 * z_stream_packet_write_method:
 *
 * @stream: stream instance
 * @buf: bytes to be written
 * @count: length of buffer
 * @bytes_written: store number of bytes written here (@count on success,
 *   0 on failure)
 * @error: store error here
 *
 * write method implementation. Write into the packet stream. One
 * write always generates one packet on the other side. The packet is
 * created here and freed here again if it could not be queued.
 *
 * Returns: A #GIOStatus.
 */
 
static GIOStatus
z_stream_packet_write_method(ZStream *stream,
			     const gchar *buf,
			     gsize count,
			     gsize *bytes_written,
			     GError **error)
{
  ZStreamPacket *self = (ZStreamPacket *) stream;
  ZPacket *pack;
  GIOStatus rc;

  z_enter();

  if (!self->can_write)
    {
      g_set_error(error, G_IO_CHANNEL_ERROR,
		  G_IO_CHANNEL_ERROR_FAILED,
		  "Channel already closed");
      errno = EPIPE;
      *bytes_written = 0;
      z_leave();
      return G_IO_STATUS_ERROR;
    }
  
  pack = z_packet_new();
  z_packet_set_data(pack, buf, count);
 
  /*LOG
   */
  z_log(self->super.name, CORE_DEBUG, 7, "Writing to packet channel; count='%d'", pack->length);
  if (z_log_enabled(CORE_DEBUG, 9))
    z_data_dump(self->super.name, pack->data, pack->length);

  rc = z_stream_packet_feed(self->other, pack, error);
  
  if (rc == G_IO_STATUS_NORMAL)
    {
      /* the whole buffer went out as a single packet */
      *bytes_written = count;
    }
  else
    {
      if (rc != G_IO_STATUS_AGAIN)
        {
          /*LOG
           */
          z_log(self->super.name, CORE_DEBUG, 7, "Writing to packet channel failed; error='%s'", (error != NULL && *error != NULL) ? ((*error)->message) : ("Unknown error"));
        }
      /* not queued: the packet is still ours, and nothing was written */
      z_packet_free(pack);
      *bytes_written = 0;
    }

  z_leave();
  return rc;
}


/**
 * z_stream_packet_fetch:
 * @self: #ZStreamPacket instance to read from
 * @pack: the dequeued packet is stored here (NULL if none)
 * @error: error details are stored here
 *
 * Dequeue the next complete packet from @self's input queue. The wait
 * depends on self->super.timeout: -2 never blocks, -1 blocks until a
 * packet arrives, and any other value is a timeout in milliseconds.
 *
 * NOTE: a blocking pop is only ended by a packet being pushed; a later
 * shutdown does not push anything and so does not interrupt it.
 *
 * Returns: %G_IO_STATUS_EOF if the read side has been shut down,
 * %G_IO_STATUS_AGAIN if no packet is available in non-blocking mode,
 * %G_IO_STATUS_ERROR on timeout or when a partially read packet is
 * pending, %G_IO_STATUS_NORMAL otherwise.
 */
static GIOStatus
z_stream_packet_fetch(ZStreamPacket *self, ZPacket **pack, GError **error)
{
  z_enter();

  *pack = NULL;
  if (!self->can_read)
    {
      z_leave();
      return G_IO_STATUS_EOF;
    }
  
  if (self->partial_packet != NULL)
    {
      /*LOG
	This message indicates an internal error, please contact your Zorp support for assistance.
       */
      z_log(self->super.name, CORE_ERROR, 3, "Mixed read/recv calls on ZStreamPacket and pending partial_packet present;");
      g_set_error(error, G_IO_CHANNEL_ERROR, G_IO_CHANNEL_ERROR_FAILED,
                  "Read attempted and pending partial packet present");
      errno = EINVAL;
      z_leave();
      return G_IO_STATUS_ERROR;
    }

  if (self->super.timeout == -2)
    {
      /* no timeout, no blocking */
      *pack = g_async_queue_try_pop(self->queue);
    } 
  else if (self->super.timeout == -1)
    {
      /* infinite timeout */
      *pack = g_async_queue_pop(self->queue);
    }
  else
    {
      GTimeVal deadline;

      /* timeout is in milliseconds, g_time_val_add() takes microseconds */
      g_get_current_time(&deadline);
      g_time_val_add(&deadline, ((guint64) self->super.timeout) * 1000);
      *pack = g_async_queue_timed_pop(self->queue, &deadline);
  
      if (*pack == NULL)
        {
          g_set_error(error, G_IO_CHANNEL_ERROR,
		      G_IO_CHANNEL_ERROR_FAILED,
		      "Channel read timed out");
          errno = ETIMEDOUT;
          z_leave();
          return G_IO_STATUS_ERROR;
        }
    }

  z_leave();
  return *pack == NULL ? G_IO_STATUS_AGAIN : G_IO_STATUS_NORMAL;
}

/**
 * z_stream_packet_recv:
 * @s: the #ZStreamPacket instance
 * @pack: the received packet is stored here (NULL if none was received)
 * @error: error details are stored here
 *
 * Read a complete packet from the stream. This must not be mixed with
 * the byte-oriented read method: if an earlier read left a partially
 * consumed packet behind, this call fails instead of returning it.
 *
 * This function assumes that the stream is a #ZStreamPacket; the
 * pointer is cast without any check, so catastrophic things will happen
 * if this is not the case. Beware!
 *
 * Returns: %G_IO_STATUS_NORMAL with *@pack set, %G_IO_STATUS_EOF (with
 * *@pack NULL) if the read side has been shut down, %G_IO_STATUS_AGAIN if
 * no packet is available in non-blocking mode, or %G_IO_STATUS_ERROR on
 * timeout or a pending partial packet.
 */
 
GIOStatus
z_stream_packet_recv(ZStream *s, ZPacket **pack, GError **error)
{
  ZStreamPacket *self = (ZStreamPacket *) s;  
  GIOStatus rc;
  
  z_enter();
  
  rc = z_stream_packet_fetch(self, pack, error);
  if (rc == G_IO_STATUS_NORMAL || rc == G_IO_STATUS_EOF)
    {
      /* *pack is NULL in the EOF case */
      self->super.bytes_recvd += *pack ? (*pack)->length : 0;
      /*LOG
       */
      z_log(self->super.name, CORE_DEBUG, 7, "Receiving on channel; count='%d'", *pack ? (*pack)->length : 0);
      if ((*pack) && (*pack)->length > 0)
        {
          if (z_log_enabled(CORE_DEBUG, 9))
            z_data_dump(self->super.name, (*pack)->data, (*pack)->length);
        }
    }
  else if (rc != G_IO_STATUS_AGAIN)
    {
      /*LOG
       */
      z_log(self->super.name, CORE_DEBUG, 7, "Receiving on channel failed; error='%s'", (error != NULL && *error != NULL) ? ((*error)->message) : ("Unknown error"));
    }
  z_leave();
  return rc;
}

/**
 * z_stream_packet_read_method:
 * @stream: stream instance
 * @buf: buffer to copy the data into
 * @count: number of bytes to read (at most the size of @buf)
 * @bytes_read: store bytes actually read here (0 if nothing was read)
 * @error: store error here
 *
 * read method implementation. Handles packet fragmentation caused by
 * reads smaller than the packet: the unread tail is remembered in
 * self->partial_packet and served by the following reads. Does not merge
 * packets, one call returns data from at most one packet.
 *
 * Returns: %G_IO_STATUS_NORMAL when data was copied into @buf, otherwise
 * the status returned by z_stream_packet_fetch() (EOF, AGAIN or ERROR).
 */
 
static GIOStatus
z_stream_packet_read_method(ZStream *stream,
			    gchar *buf,
			    gsize count,
			    gsize *bytes_read,
			    GError **error)
{
  guint shift;
  ZPacket *pack;
  ZStreamPacket *self = (ZStreamPacket *) stream;
  GIOStatus rc;
  
  /* checked before z_enter() so that a failing precondition, which
   * returns early, does not leave the call trace unbalanced */
  g_return_val_if_fail((error == NULL) || (*error == NULL),
		       G_IO_STATUS_ERROR);

  z_enter();

  if (self->partial_packet != NULL)
    {
      /* continue where the previous short read stopped */
      pack = self->partial_packet;
      shift = self->partial_read;

      g_assert(pack->length > shift);
    } 
  else 
    {
      shift = 0;
      
      rc = z_stream_packet_fetch(self, &pack, error);
      if (rc != G_IO_STATUS_NORMAL)
        {
          if (rc != G_IO_STATUS_EOF && rc != G_IO_STATUS_AGAIN)
            {
	      /*LOG
	       */
              z_log(self->super.name, CORE_DEBUG, 7, "Reading from packet channel failed; code='%d', error='%s'", rc, (error != NULL && *error != NULL) ? ((*error)->message) : ("Unknown error"));
            }
          *bytes_read = 0;
          z_leave();
          return rc;
        }
    }
  
  if (pack->length - shift <= count)
    {
      /* the rest of the packet fits: hand it out and drop the packet */
      *bytes_read = pack->length - shift;
      memcpy(buf, pack->data + shift, pack->length - shift);
      z_packet_free(pack);
      self->partial_packet = NULL;
      self->partial_read = 0;
    } 
  else 
    {
      /* only @count bytes fit: keep the packet and remember the offset */
      *bytes_read = count;
      memcpy(buf, pack->data + shift, count);
      self->partial_packet = pack;
      self->partial_read = count + shift;
    }

  self->super.bytes_recvd += *bytes_read;
  /*LOG
   */
  z_log(self->super.name, CORE_DEBUG, 7, "Reading from packet channel; count='%zd'", *bytes_read);
  if (z_log_enabled(CORE_DEBUG, 9))
    z_data_dump(self->super.name, buf, *bytes_read);

  z_leave();
  return G_IO_STATUS_NORMAL;
}


/**
 * z_stream_packet_shutdown_method:
 * @stream: stream instance
 * @i: shutdown direction (SHUT_RD, SHUT_WR or SHUT_RDWR)
 * @error: passed on to the shutdown of the peer
 *
 * shutdown method implementation. Clears the requested direction(s) of
 * @stream and mirrors the change on the peer: our read side pairs with
 * the peer's write side and vice versa. Directions already shut down are
 * left alone, so repeating a shutdown does nothing. If the read side was
 * closed here and a source is attached, its main context is woken up so
 * the EOF is noticed.
 *
 * Returns: %G_IO_STATUS_NORMAL.
 */

static GIOStatus
z_stream_packet_shutdown_method(ZStream *stream, int i, GError **error)
{
  ZStreamPacket *self = (ZStreamPacket *) stream;
  gboolean was_readable;
  gboolean propagate = FALSE;
  int peer_how = 0;

  z_enter();
  was_readable = self->can_read;

  if (i == SHUT_RD)
    {
      if (self->can_read)
        {
          self->can_read = FALSE;
          peer_how = SHUT_WR;
          propagate = TRUE;
        }
    }
  else if (i == SHUT_WR)
    {
      if (self->can_write)
        {
          self->can_write = FALSE;
          peer_how = SHUT_RD;
          propagate = TRUE;
        }
    }
  else if (i == SHUT_RDWR)
    {
      if (self->can_read || self->can_write)
        {
          self->can_read = FALSE;
          self->can_write = FALSE;
          peer_how = SHUT_RDWR;
          propagate = TRUE;
        }
    }

  /* our flags are already cleared, so when the peer mirrors the shutdown
   * back onto us there is nothing left to do and the recursion stops */
  if (propagate)
    z_stream_shutdown(&self->other->super, peer_how, error);

  /* only a read-side change needs a wakeup; the stream is always
   * writable, so a write shutdown does not change readiness */
  if (was_readable != self->can_read &&
      self->super.want_read &&
      self->super.source != NULL)
    {
      g_main_context_wakeup(g_source_get_context(self->super.source));
    }

  z_leave();
  return G_IO_STATUS_NORMAL;
}

/**
 * z_stream_packet_close_method:
 * @stream: stream instance
 * @error: passed on to the shutdown
 *
 * close method implementation. Shuts down both directions, which also
 * shuts down the peer, and drops the reference held on the peer. This
 * operation is idempotent: the peer reference is dropped only once.
 *
 * Returns: the result of the shutdown.
 */

static GIOStatus
z_stream_packet_close_method(ZStream *stream,
			     GError **error)
{
  ZStreamPacket *self = (ZStreamPacket *) stream;
  GIOStatus rc;

  z_enter();
  rc = z_stream_shutdown(stream, SHUT_RDWR, error);
  /* a previous close already cleared self->other; guard against a
   * NULL dereference / double unref on a repeated close */
  if (self->other != NULL)
    {
      z_stream_unref(&self->other->super);
      self->other = NULL;
    }
  z_leave();
  return rc;
}

/**
 * z_stream_packet_free_method:
 * @s: #ZStreamPacket instance
 *
 * Free the stream. Both directions must already be shut down. Every
 * packet still in the input queue, and the not-yet-read tail of a
 * partially read packet, is discarded and reported as 'unread' in the
 * accounting log message. The peer reference is dropped if still held.
 */
 
static void
z_stream_packet_free_method(ZObject *s)
{
  ZPacket *pack;
  ZStreamPacket *self = Z_CAST(s, ZStreamPacket);

  guint64 unread = 0;
  time_t time_close;

  z_enter();
  
  assert(!self->can_write && !self->can_read);
  
  time_close = time(NULL);

  while ((pack = (ZPacket *) g_async_queue_try_pop(self->queue)) != NULL) 
    {
      unread += pack->length;
      z_packet_free(pack);
    }

  if (self->partial_packet != NULL)
    {
      /* the first partial_read bytes were already returned by the read
       * method and counted in bytes_recvd; only the rest is unread */
      unread += self->partial_packet->length - self->partial_read;
      z_packet_free(self->partial_packet);
      self->partial_packet = NULL;
    }
  if (self->other != NULL)
    {
      z_stream_unref(&self->other->super);
      self->other = NULL;
    }
  g_async_queue_unref(self->queue);
  /*LOG
   */
  z_log(self->super.name, CORE_ACCOUNTING, 5,
        "accounting info; duration='%d', sent='%" G_GUINT64_FORMAT "', received='%" G_GUINT64_FORMAT "', unread='%" G_GUINT64_FORMAT "'", 
        (int) difftime(time_close, self->super.time_open),
        self->super.bytes_sent,
        self->super.bytes_recvd,
	unread);

  z_stream_free_method(s);
  z_leave();
}

/**
 * z_stream_packet_pair_new:
 * @session_id: session identifier used to name the streams; they are
 *   called "<session_id>/L" and "<session_id>/R" (truncated to
 *   MAX_SESSION_ID bytes)
 * @stream1: the "/L" stream is stored here
 * @stream2: the "/R" stream is stored here
 *
 * Create two connected packet streams, much like socketpair(2): data
 * written to one end comes out of the other. Whole packets can be moved
 * with z_stream_packet_send() and z_stream_packet_recv(). The regular
 * write method turns each write into one packet on the peer. No file
 * descriptors are consumed, which makes the pair suitable for passing
 * e.g. UDP packets inside Zorp.
 *
 * Each end holds a reference to its peer, dropped on close or free.
 * Shutting down one end also shuts down the matching direction of the
 * other, so packets cannot be written that could never be read back.
 */
 
void
z_stream_packet_pair_new(gchar *session_id, ZStream **stream1, ZStream **stream2)
{
  char name[MAX_SESSION_ID];
  ZStreamPacket *left;
  ZStreamPacket *right;

  z_enter();

  g_snprintf(name, sizeof(name), "%s/L", session_id);
  left = (ZStreamPacket *) z_stream_packet_new(name);
  g_snprintf(name, sizeof(name), "%s/R", session_id);
  right = (ZStreamPacket *) z_stream_packet_new(name);

  /* the extra references are the ones held through the 'other' links */
  z_stream_ref(&left->super);
  z_stream_ref(&right->super);
  left->other = right;
  right->other = left;

  *stream1 = &left->super;
  *stream2 = &right->super;

  z_leave();
}

/*
 * z_stream_packet_ctrl_method:
 *
 * ctrl method implementation. ZST_CTRL_GET_FD reports -1, since a packet
 * stream has no file descriptor. ZST_CTRL_SET_NONBLOCK is accepted but
 * changes nothing and is reported as not handled; read blocking is
 * governed by super.timeout (see z_stream_packet_fetch()). Everything
 * else is passed to the generic z_stream_ctrl_method().
 *
 * Returns: TRUE if the request was handled.
 */
static gboolean
z_stream_packet_ctrl_method(ZStream *s, guint function, gpointer value, guint vlen)
{
  gboolean handled = FALSE;

  z_enter();
  /* result unused; presumably kept for its type check on @s */
  Z_CAST(s, ZStreamPacket);

  switch (ZST_CTRL_MSG(function))
    {
    case ZST_CTRL_SET_NONBLOCK:
      g_assert(vlen == sizeof(gboolean));
      break;
    case ZST_CTRL_GET_FD:
      g_assert(vlen == sizeof(gint));
      *((gint *) value) = -1;
      handled = TRUE;
      break;
    default:
      if (z_stream_ctrl_method(s, function, value, vlen))
        handled = TRUE;
      break;
    }
  /* single exit so every path balances z_enter() */
  z_leave();
  return handled;
}


/* Method table of ZStreamPacket. The slot order follows the ZStreamFuncs
 * declaration, which is not visible in this file. */
ZStreamFuncs 
z_stream_packet_funcs =
{
  {
    Z_FUNCS_COUNT(ZStream),
    z_stream_packet_free_method,
  },
  z_stream_packet_read_method,
  z_stream_packet_write_method,
  NULL,		/* read_pri */
  NULL,		/* write_pri */
  z_stream_packet_shutdown_method,
  z_stream_packet_close_method,
  z_stream_packet_ctrl_method,
  z_stream_packet_attach_source_method,
  z_stream_packet_detach_source_method,
  z_stream_packet_watch_prepare,
  z_stream_packet_watch_check,
  z_stream_packet_watch_dispatch,
  /* NOTE(review): remaining slots are left unimplemented for packet
   * streams; confirm their meaning against ZStreamFuncs */
  NULL,
  NULL,
  NULL,
  NULL
};

/* Class descriptor: ZStreamPacket derives from ZStream, its instances
 * are sizeof(ZStreamPacket) bytes and it uses z_stream_packet_funcs as
 * its method table. */
ZClass ZStreamPacket__class =
{
  Z_CLASS_HEADER,
  &ZStream__class,
  "ZStreamPacket",
  sizeof(ZStreamPacket),
  &z_stream_packet_funcs.super,
};

/*
 * z_stream_packet_new:
 * @name: name of the stream
 *
 * Create one end of a packet stream, readable and writable, with an
 * empty input queue and no peer. Use z_stream_packet_pair_new() to get
 * two connected ends.
 *
 * Returns: the new stream.
 */
ZStream *
z_stream_packet_new(gchar *name)
{
  ZStream *base;
  ZStreamPacket *self;

  z_enter();
  base = z_stream_new(Z_CLASS(ZStreamPacket), name, NULL, Z_STREAM_FLAG_READ | Z_STREAM_FLAG_WRITE);
  self = Z_CAST(base, ZStreamPacket);

  self->partial_packet = NULL;
  self->partial_read = 0;
  self->can_read = TRUE;
  self->can_write = TRUE;
  self->queue = g_async_queue_new();

  z_leave();
  return &self->super;
}

