/*
    MiddleMan filtering proxy server
    Copyright (C) 2002-2004  Jason McLaughlin
    Copyright (C) 2003  Riadh Elloumi

    This program is free software; you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation; either version 2 of the License, or
    (at your option) any later version.

    This program is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with this program; if not, write to the Free Software
    Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
*/

#include <stdio.h>
#include <string.h>
#include <stdlib.h>
#include <stdarg.h>
#include <unistd.h>
#include <fcntl.h>
#include <sys/socket.h>
#include <sys/types.h>
#include <sys/time.h>
#include <netinet/in.h>
#include <arpa/inet.h>
#include "../libntlm/ntlm.h"
#include "proto.h"

extern TemplateSection *template_section;
extern RewriteSection *rewrite_section;
extern ExternalSection *external_section;
extern KeywordSection *keyword_section;
extern GeneralSection *general_section;
extern CacheSection *cache_section;
extern AntivirusSection *antivirus_section;
extern GLOBAL *global;

/*
 * Open (or reuse) the upstream connection for this request.
 *
 * For non-CONNECT requests a matching idle socket is first looked up in the
 * connection pool; on a hit it becomes connection->server, keepalive_server
 * is set to TRUE and CONNECTION_LOGGEDIN is set.  Otherwise a new TCP
 * connection is made according to connection->proxy_type (to the configured
 * upstream proxy for PROXY_NORMAL / PROXY_SOCKS4 / PROXY_CONNECT, directly to
 * the request's host otherwise) and, when CONNECTION_SSLCLIENT is set, SSL is
 * negotiated on it.
 *
 * Returns the socket descriptor on success or a negative value on failure;
 * on failure connection->server is NULL.  keepalive_server ends up TRUE only
 * when a pooled socket was reused.
 */
int protocol_start(CONNECTION * connection)
{
	int proxyfd, x;
	Socket *sock;

	connection->keepalive_server = FALSE;

	if (connection->header->type != HTTP_CONNECT) {
		sock = pool_find(global->pool, connection->header->proto, (connection->proxy_type == PROXY_NORMAL && connection->proxy_host != NULL) ? connection->proxy_host : connection->header->host, (connection->proxy_type == PROXY_NORMAL) ? connection->proxy_port : connection->header->port, connection->header->username, connection->header->password);
		if (sock != NULL) {
			connection->server = sock;
			connection->keepalive_server = TRUE;
			connection->flags |= CONNECTION_LOGGEDIN;

			return sock->fd;
		}
	}

	switch (connection->proxy_type) {
	case PROXY_NORMAL:
		proxyfd = net_connect(connection->proxy_host, connection->proxy_port, general_section->ctimeout_get());
		if (proxyfd < 0)
			putlog(MMLOG_NETWORK, "connection to proxy server failed");
		else
			connection->server = xnew Socket(proxyfd);

		break;
	case PROXY_SOCKS4:
		proxyfd = net_connect(connection->proxy_host, connection->proxy_port, general_section->ctimeout_get());

		if (proxyfd >= 0) {
			connection->server = xnew Socket(proxyfd);

			x = proxy_socks4(connection);
			if (x < 0) {
				putlog(MMLOG_NETWORK, "connection to socks4 server failed");

				xdelete connection->server;
				connection->server = NULL;

				return -1;
			}
		}

		break;
	case PROXY_CONNECT:
		proxyfd = net_connect(connection->proxy_host, connection->proxy_port, general_section->ctimeout_get());
		if (proxyfd >= 0) {
			connection->server = xnew Socket(proxyfd);

			x = proxy_connect(connection);
			if (x < 0) {
				putlog(MMLOG_NETWORK, "connection to proxy failed");

				xdelete connection->server;
				connection->server = NULL;

				return -1;
			}

		}

		break;
	default:
		proxyfd = net_connect(connection->header->host, connection->header->port, general_section->ctimeout_get());
		if (proxyfd >= 0)
			connection->server = xnew Socket(proxyfd);

		break;
	}

	/* Only negotiate SSL on a connection that was actually established: when
	   proxyfd < 0 nothing above assigned connection->server, so it may be NULL
	   or still point at a Socket that protocol_reconnect() already deleted. */
	if (proxyfd >= 0 && (connection->flags & CONNECTION_SSLCLIENT)) {
		if (!connection->server->Encrypt(Socket::SOCK_SSLCLIENT)) {
			xdelete connection->server;

			proxyfd = -1;
		}
	}

	if (proxyfd < 0)
		connection->server = NULL;

	return proxyfd;
}

/*
 * Drop the current upstream socket, flush the client Socket (if any) and open
 * a new upstream connection with protocol_start().
 *
 * Returns protocol_start()'s result: a descriptor on success, negative on
 * failure (connection->server is NULL in that case).
 */
int protocol_reconnect(CONNECTION * connection)
{
	xdelete connection->server;
	/* never leave a dangling pointer behind, even briefly: protocol_start()
	   only assigns connection->server on some of its paths */
	connection->server = NULL;

	if (connection->client != NULL)
		connection->client->Flush();

	return protocol_start(connection);
}


/*
 * Answer the Proxy-Authenticate challenge found in connection->rheader.
 *
 * A scheme equal to "NTLM" (case-insensitive) is delegated to
 * send_ntlm_response() when built with ENABLE_NTLM; a scheme beginning with
 * "Basic" is delegated to send_basic_response().  Returns the delegate's
 * result, or FALSE when there is no challenge or the scheme is unsupported.
 */
int proxy_authenticate(CONNECTION * connection)
{
	char *scheme = connection->rheader->proxy_authenticate;

	if (scheme == NULL)
		return FALSE;

	if (strcasecmp(scheme, "NTLM") == 0) {
#ifdef ENABLE_NTLM
		return send_ntlm_response(connection);
#else
		return FALSE;
#endif
	}

	if (strncasecmp(scheme, "Basic", 5) == 0)
		return send_basic_response(connection);

	return FALSE;
}

#ifdef ENABLE_NTLM
/* perform an NTLM handshake with the upstream proxy, the sequence of events
   looks like this:
    1: C -->  S   GET ...

    2: C <--  S   407 Proxy Authentication Required
                  Proxy-Authenticate: NTLM
    
    3: C  --> S   GET ...
                  Proxy-Authorization: NTLM <base64-encoded type-1-message>
    
    4: C <--  S   407 Proxy Authentication Required
                  Proxy-Authenticate: NTLM <base64-encoded type-2-message>
    
    5: C  --> S   GET ...
                  Proxy-Authorization: NTLM <base64-encoded type-3-message>
    
    6: C <--  S   200 OK

   On entry the step-2 reply is in connection->rheader with its body unread.
   This function discards that body, performs steps 3 and 4 (forcing
   keep-alive and a body-less GET), then restores the original method, header
   flags, content length and keepalive_server and sends the step-5 request.
   Reading the step-6 reply is left to the caller.

   Returns TRUE once the type-3 request has been sent.  Returns FALSE if the
   handshake fails; the saved request state is restored and
   connection->proxy_auth is cleared in that case.
*/
int send_ntlm_response(CONNECTION * connection)
{
	int ret, reused, oldkeep, oldcontentlength, oldflags;
	size_t b64len;
	unsigned char buf[4096], buf2[4096], *headbuf, *oldmethod;
	static pthread_mutex_t ntlmlock = PTHREAD_MUTEX_INITIALIZER;
	tSmbNtlmAuthRequest request;
	tSmbNtlmAuthChallenge challenge;
	tSmbNtlmAuthResponse response;
	HEADER *header;

	/* client must do keep-alive during authentication */
	oldkeep = connection->keepalive_server;
	connection->keepalive_server = TRUE;

	/* must use GET method for handshake */
	oldmethod = (unsigned char *) connection->header->method;
	connection->header->method = xstrdup("GET");
	oldcontentlength = connection->header->content_length;
	oldflags = connection->header->flags;
	connection->header->flags &= ~HEADER_CL;

	/* don't need the body of the 407 message */
	http_transfer_discard(connection, SERVER);

	/* send type-1 message */
	/* note: buildSmbNtlm* knows how to deal with NULL arguments */
	pthread_mutex_lock(&ntlmlock);
	buildSmbNtlmAuthRequest(&request, connection->proxy_username, connection->proxy_domain);
	pthread_mutex_unlock(&ntlmlock);

	to64frombits(buf, (unsigned char *) &request, SmbLength(&request));

	FREE_AND_NULL(connection->proxy_auth);
	snprintf((char *) buf2, sizeof(buf2), "NTLM %s", (char *) buf);
	connection->proxy_auth = xstrdup((char *) buf2);

	header_send(connection->header, connection, SERVER, HEADER_FORWARD);

	headbuf = (unsigned char *) header_get(connection, SERVER, general_section->timeout_get());
	if (headbuf == NULL) {
		/* squid and maybe some other proxies will disconnect after sending
		   the first 407 message regardless of Connection header */
		do {
			ret = protocol_reconnect(connection);
			/* protocol_start() can fail with any negative value, and every
			   failure leaves connection->server NULL */
			if (ret < 0)
				goto error;

			/* protocol_start() resets keepalive_server and only sets it TRUE
			   when it reused a pooled socket (which may be stale, so it is
			   worth retrying).  Remember which case this is, then force
			   keep-alive again so the handshake stays on this connection. */
			reused = connection->keepalive_server;
			connection->keepalive_server = TRUE;

			header_send(connection->header, connection, SERVER, HEADER_FORWARD);

			headbuf = (unsigned char *) header_get(connection, SERVER, general_section->timeout_get());
		} while (headbuf == NULL && reused == TRUE);

		if (headbuf == NULL)
			goto error;
	}

	header = http_header_parse_response((char *) headbuf);
	xfree(headbuf);
	if (header == NULL)
		goto error;

	http_header_free(connection->rheader);
	connection->rheader = header;

	if (header->proxy_authenticate == NULL || strncasecmp(header->proxy_authenticate, "NTLM ", 5))
		goto error;

	http_transfer_discard(connection, SERVER);

	/* parse type-2 message.  The challenge comes from the network: base64
	   decodes to at most 3 bytes per 4 input characters, so reject anything
	   that could overrun the challenge structure, and zero it first so a
	   short challenge doesn't leave uninitialised fields behind. */
	b64len = strlen(&header->proxy_authenticate[5]);
	if ((b64len + 3) / 4 * 3 >= sizeof(challenge))
		goto error;

	memset(&challenge, 0, sizeof(challenge));
	from64tobits((char *) &challenge, &header->proxy_authenticate[5]);

	pthread_mutex_lock(&ntlmlock);
	buildSmbNtlmAuthResponse(&challenge, &response, connection->proxy_username, connection->proxy_password);
	pthread_mutex_unlock(&ntlmlock);

	/* send type-3 message */
	to64frombits(buf, (unsigned char *) &response, SmbLength(&response));
	snprintf((char *) buf2, sizeof(buf2), "NTLM %s", (char *) buf);

	FREE_AND_NULL(connection->proxy_auth);
	connection->proxy_auth = xstrdup((char *) buf2);

	/* the final request goes out with the caller's original method, body
	   length and keep-alive preference */
	connection->keepalive_server = oldkeep;
	xfree(connection->header->method);
	connection->header->method = (char *) oldmethod;
	connection->header->flags = oldflags;
	connection->header->content_length = oldcontentlength;


	header_send(connection->header, connection, SERVER, HEADER_FORWARD);

	FREE_AND_NULL(connection->proxy_auth);

	return TRUE;

      error:
	/* don't leave a half-finished handshake's credential attached to the
	   connection, where later header_send() calls would pick it up */
	FREE_AND_NULL(connection->proxy_auth);

	connection->keepalive_server = oldkeep;
	xfree(connection->header->method);
	connection->header->method = (char *) oldmethod;
	connection->header->flags = oldflags;
	connection->header->content_length = oldcontentlength;

	return FALSE;
}

#endif				/* ENABLE_NTLM */

/* basic proxy authentication */
/*
 * Answer a "Basic" Proxy-Authenticate challenge.
 *
 * Discards the body of the 407 reply, then resends connection->header with a
 * "Basic <base64(username:password)>" credential built from
 * connection->proxy_username / proxy_password (a NULL value is sent as an
 * empty string).  connection->proxy_auth holds the credential only while the
 * header is being sent.  Always returns TRUE.
 */
int send_basic_response(CONNECTION * connection)
{
	char cred[4096];
	/* base64 emits 4 characters for every 3 input bytes (rounded up) plus a
	   terminating NUL, so size the output for the largest possible cred
	   string instead of assuming it fits in sizeof(cred) */
	char encoded[((sizeof(cred) + 2) / 3) * 4 + 1];
	char auth[sizeof("Basic ") + sizeof(encoded)];

	/* discard body of 407 authentication required message */
	http_transfer_discard(connection, SERVER);

	snprintf(cred, sizeof(cred), "%s:%s", (connection->proxy_username != NULL) ? connection->proxy_username : "", (connection->proxy_password != NULL) ? connection->proxy_password : "");
	to64frombits((unsigned char *) encoded, (unsigned char *) cred, strlen(cred));

	snprintf(auth, sizeof(auth), "Basic %s", encoded);

	FREE_AND_NULL(connection->proxy_auth);
	connection->proxy_auth = xstrdup(auth);

	header_send(connection->header, connection, SERVER, HEADER_FORWARD);

	FREE_AND_NULL(connection->proxy_auth);

	return TRUE;
}

/*
 * Decide, from the upstream response header, whether the server connection
 * may be kept alive.
 *
 * If the response has neither a Content-Length (HEADER_CL) nor chunked
 * encoding, both the client and server connections are marked non-keepalive,
 * because the end of the body cannot be found.  Otherwise keepalive_server
 * becomes TRUE when either keepalive flag is TRUE, or when the response is
 * HTTP/1.1 and neither flag is FALSE; it is FALSE in every other case.
 */
void keepalive_check(CONNECTION * connection)
{
	HEADER *rh = connection->rheader;
	int length_known = (rh->flags & HEADER_CL) || rh->chunked;

	if (!length_known) {
		connection->keepalive_client = FALSE;
		connection->keepalive_server = FALSE;
	} else {
		int explicitly_on = (rh->keepalive == TRUE || rh->proxy_keepalive == TRUE);
		int http11_default = (rh->version == HTTP_HTTP11 && rh->keepalive != FALSE && rh->proxy_keepalive != FALSE);

		connection->keepalive_server = (explicitly_on || http11_default) ? TRUE : FALSE;
	}

	putlog(MMLOG_DEBUG, "keepalive_server = %d", connection->keepalive_server);
}

/*
 * Decide whether the response body must be buffered in full before it is
 * forwarded (TRUE) or can be streamed straight through (FALSE).
 *
 * The checks run in order and the first match wins, so order matters.  Note
 * the side effect midway: once none of the filtering sections (rewrite,
 * external, antivirus, keyword) or the "score" url command apply,
 * CONNECTION_PROCESS is cleared on the connection before the remaining checks.
 */
int buffer_check(CONNECTION * connection)
{
	int encode = TRUE;

	/* prefetch requests: buffer only when an HTML stream parser is attached
	   and the response carries a Content-Encoding */
	if (connection->flags & CONNECTION_PREFETCH) {
		if (connection->htmlstream == NULL || connection->rheader->content_encoding == NULL)
			return FALSE;
		else
			return TRUE;
	}

	/* encoded body but the client can't take it encoded: it has to be
	   buffered (presumably so it can be decoded before sending — confirm) */
	if (connection->rheader->content_encoding != NULL && accepts_encoded_content(connection) == FALSE)
		return TRUE;

	/* buffer chunked responses for HTTP1.0 clients */
	if (connection->rheader->chunked == TRUE && connection->header->version == HTTP_HTTP10)
		return TRUE;

	/* check whether or not this transfer should be buffered first */
	if (rewrite_section->rewrite_do(connection, NULL, REWRITE_BODY, FALSE))
		return TRUE;

	if (external_section->check_match(connection))
		return TRUE;

	if (antivirus_section->check_match(connection))
		return TRUE;

	if (keyword_section->check(connection, NULL, FALSE))
		return TRUE;

	if (url_command_find(connection->url_command, "score"))
		return TRUE;

	/* this will avoid the overhead of checking the lists again, since we've already determined none will match. */
	connection->flags &= ~CONNECTION_PROCESS;

	if (url_command_find(connection->url_command, "raw"))
		return TRUE;

	if (url_command_find(connection->url_command, "htmltree"))
		return TRUE;

	/* stream parser can't understand any encodings for now */
	if (connection->rheader->content_encoding != NULL && connection->htmlstream != NULL)
		return TRUE;

	/* output compression enabled, client accepts gzip/deflate and the body is
	   unencoded: evaluate general_section->emp against the Content-Type under
	   the section's read lock.  When reg_exec() returns 0 (presumably "no
	   match" — confirm against reg_exec) the response is buffered. */
	if (general_section->compressout_get() == TRUE && connection->header->accept_encoding != NULL && connection->rheader->content_type != NULL && accepts_encoded_content(connection) && connection->rheader->content_encoding == NULL) {
		general_section->read_lock();
		if (general_section->emp != NULL)
			encode = reg_exec(general_section->emp, connection->rheader->content_type);
		general_section->unlock();

		if (!encode)
			return TRUE;
	}

	/* TODO(review): placeholder only — this branch is empty and has no effect */
	if (connection->cachemap == NULL) {
		/* ICAP: check if any respmod entries match, and return TRUE if they do.
		   we don't need to if cachemap != NULL, see respmod postcache processing point. */
	}

	return FALSE;
}

/*
 * Enforce connection->transferlimit against the response's declared length.
 *
 * The limit applies only when the response has a Content-Length (HEADER_CL)
 * and either connection->cachemap is NULL or CONNECTION_LIMITCACHE is set.
 * When the length exceeds the limit, the "maxtransfer" template is sent with
 * status 503 and FALSE is returned; otherwise TRUE.
 */
int transfer_limit_check(CONNECTION * connection)
{
	HEADER *rh = connection->rheader;

	putlog(MMLOG_DEBUG, "transfer limit: %u cl: %u", connection->transferlimit, rh->content_length);

	/* length unknown: nothing to compare against */
	if (!(rh->flags & HEADER_CL))
		return TRUE;

	/* cached transfers are exempt unless explicitly limited */
	if (connection->cachemap != NULL && !(connection->flags & CONNECTION_LIMITCACHE))
		return TRUE;

	if (rh->content_length <= connection->transferlimit)
		return TRUE;

	putlog(MMLOG_LIMITS, "transfer would exceed transfer limit");

	template_section->send("maxtransfer", connection, 503);

	return FALSE;
}

/*
 * Check whether the upstream proxy will connect back to us.
 *
 * Listens on an ephemeral port bound to the local IPv4 address used by
 * connection->server, rewrites connection->header to request
 * http://<that address>:<port>/ and sends it through the proxy, then waits up
 * to general_section->timeout_get() seconds for an incoming connection.  If
 * the proxy connects and sends a request header, that header is parsed into
 * connection->rheader (the previous rheader is freed first), and
 * keepalive_server is cleared.  connection->server is left as it was on
 * entry.  Any failure along the way just returns early.
 */
void proxy_test(CONNECTION * connection)
{
	int fd, listenfd, x;
	char *headbuf, ip[16];
	struct sockaddr_in saddr;
	struct pollfd pfd;
	socklen_t sl = sizeof(saddr);
	Socket *server;

	/* bind on same interface we're connecting to the forwarding proxy on;
	   only IPv4 fits in saddr, and an unchecked failure would leave it
	   uninitialised */
	if (getsockname(connection->server->fd, (struct sockaddr *) &saddr, &sl) == -1 || sl > sizeof(saddr) || saddr.sin_family != AF_INET)
		return;
	saddr.sin_port = 0;

	listenfd = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP);
	if (listenfd == -1)
		return;

	x = bind(listenfd, (struct sockaddr *) &saddr, sizeof(saddr));
	if (x == -1) {
		close(listenfd);
		return;
	}

	x = listen(listenfd, 1);
	if (x == -1) {
		close(listenfd);
		return;
	}

	/* get the port number the kernel chose for us */
	sl = sizeof(saddr);
	if (getsockname(listenfd, (struct sockaddr *) &saddr, &sl) == -1 || inet_ntop(AF_INET, &saddr.sin_addr, ip, sizeof(ip)) == NULL) {
		close(listenfd);
		return;
	}

	FREE_AND_STRDUP(connection->header->host, ip);
	FREE_AND_STRDUP(connection->header->file, "/");
	connection->header->port = ntohs(saddr.sin_port);

	header_send(connection->header, connection, SERVER, HEADER_FORWARD);

	/* ok.. if all went well, the proxy should be connecting back to us */
	pfd.fd = listenfd;
	pfd.events = POLLIN;

	x = p_poll(&pfd, 1, general_section->timeout_get() * 1000);
	if (x <= 0) {
		close(listenfd);

		return;
	}

	sl = sizeof(saddr);
	fd = accept(listenfd, (struct sockaddr *) &saddr, &sl);
	close(listenfd);

	if (fd == -1)
		return;

	/* read the proxy's request on the accepted socket, then put the
	   original server socket back */
	server = connection->server;
	connection->server = xnew Socket(fd);

	headbuf = header_get(connection, SERVER, general_section->timeout_get());
	if (headbuf != NULL) {
		/* the connection owns rheader; release the old one instead of leaking it */
		if (connection->rheader != NULL)
			http_header_free(connection->rheader);
		connection->rheader = http_header_parse_request(headbuf);
		xfree(headbuf);
	}

	xdelete connection->server;
	connection->server = server;

	connection->keepalive_server = FALSE;
}

/*
 * Report whether the client may be sent a gzip- or deflate-encoded body.
 *
 * TRUE when CONNECTION_ENCODED is not set on the connection and the request's
 * Accept-Encoding header mentions "gzip" or "deflate" (searched with
 * s_strcasestr).  Otherwise a debug message is logged and FALSE returned.
 */
int accepts_encoded_content(CONNECTION * connection)
{
	char *accepted = connection->header->accept_encoding;
	int eligible = !(connection->flags & CONNECTION_ENCODED) && accepted != NULL;

	if (eligible && (s_strcasestr(accepted, "gzip") != NULL || s_strcasestr(accepted, "deflate") != NULL))
		return TRUE;

	putlog(MMLOG_DEBUG, "client doesn't support encoded content");

	return FALSE;
}


/*
 * Read the upstream response header, retrying on a fresh connection when the
 * first read fails.
 *
 * If nothing arrives and the server connection was neither keep-alive nor
 * CONNECTION_AUTHENTICATED, the "noconnect" template (503) is sent to the
 * client and NULL returned.  Otherwise the connection is re-established with
 * protocol_reconnect(), the request header (and any buffered POST body) is
 * resent and the header read again.  This repeats while reads keep failing and
 * keepalive_server is TRUE; protocol_start() only sets that for a socket taken
 * from the pool.  A reconnect failure sends the template from
 * error_to_template().
 *
 * Returns the raw header text, presumably owned by the caller and released
 * with xfree() as elsewhere in this file, or NULL after an error page has
 * been sent.
 */
char *header_get_reconnect(CONNECTION * connection, int timeout)
{
	char *reply;
	int rc;

	reply = header_get(connection, SERVER, timeout);
	if (reply != NULL)
		return reply;

	if (!connection->keepalive_server && !(connection->flags & CONNECTION_AUTHENTICATED)) {
		template_section->send("noconnect", connection, 503);

		return NULL;
	}

	do {
		rc = protocol_reconnect(connection);
		if (rc < 0) {
			template_section->send(error_to_template(rc), connection, 503);

			return NULL;
		}

		header_send(connection->header, connection, SERVER, (connection->proxy_type == PROXY_NORMAL) ? HEADER_FORWARD : HEADER_DIRECT);
		if (connection->postbody != NULL)
			net_filebuf_send(connection->postbody, connection, SERVER);

		reply = header_get(connection, SERVER, timeout);
	} while (reply == NULL && connection->keepalive_server == TRUE);

	if (reply == NULL)
		template_section->send("noconnect", connection, 503);

	return reply;
}
