/*
    BFilter - a smart ad-filtering web proxy
    Copyright (C) 2002-2006  Joseph Artsimovich <joseph_a@mail.ru>

    This program is free software; you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation; either version 2 of the License, or
    (at your option) any later version.

    This program is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with this program; if not, write to the Free Software
    Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
*/

#include "pch.h"

#ifdef HAVE_CONFIG_H
#include <config.h>
#endif

#include "Socks5Authenticator.h"
#include "Reactor.h"
#include <algorithm>
#include <cstdlib>
#include <cassert>
#include <stddef.h>

using namespace std;

// Byte values of the SOCKS5 wire protocol used by this file
// (method negotiation and the username/password sub-negotiation).
namespace
{

// Version byte that starts the method-selection request and its reply.
unsigned char const SOCKS_VERSION_5 = 0x05;

// Method id: no authentication.
unsigned char const SOCKS_AUTH_NONE = 0x00;

// Method id: username/password authentication.
unsigned char const SOCKS_AUTH_UNAME_PASS = 0x02;

// Method id in a reply meaning none of the offered methods was accepted.
unsigned char const SOCKS_AUTH_NO_ACCEPTABLE_METHODS = 0xff;

// Version byte of the username/password message and of its status reply.
unsigned char const SOCKS_AUTH_VERSION_1 = 0x01;

// Status byte that indicates the credentials were accepted.
unsigned char const SOCKS_AUTH_SUCCESS = 0x00;

} // anonymous namespace

// Method-selection request written by startAuthentication() when both
// the username and the password are empty: only "no authentication"
// is offered.
unsigned char const Socks5Authenticator::m_sMsgMethods1[] = {
	SOCKS_VERSION_5, // protocol version
	1,               // number of method ids that follow
	SOCKS_AUTH_NONE  // the single offered method
};

// Method-selection request written by startAuthentication() when a
// username and/or password was supplied: both "no authentication" and
// "username/password" are offered.
unsigned char const Socks5Authenticator::m_sMsgMethods2[] = {
	SOCKS_VERSION_5,      // protocol version
	2,                    // number of method ids that follow
	SOCKS_AUTH_NONE,      // first offered method
	SOCKS_AUTH_UNAME_PASS // second offered method
};

// Creates an authenticator in the inactive state. Nothing happens until
// startAuthentication() is called.
Socks5Authenticator::Socks5Authenticator()
:	m_state(ST_INACTIVE)
{
	// Give the flag a defined value instead of leaving it indeterminate
	// until the first startAuthentication() call. It is assigned here
	// rather than in the initializer list so that this does not depend
	// on the order in which the members are declared.
	m_isAuthProvided = 0;
}

// No explicit cleanup is done here; the members are destroyed as usual.
Socks5Authenticator::~Socks5Authenticator() {}

// Starts a SOCKS5 authentication exchange.
//
// Any exchange already in progress is aborted first. The
// username/password message is prepared up front (it is only sent if
// the server later selects that method), the listener is attached, and
// the method-selection request is written. The rest of the exchange is
// driven by the onWriteDone() / onReadDone() callbacks.
//
// listener - notified through onAuthSuccess() or onAuthFailure().
// reactor  - handed to the reader/writer on activation.
// handle   - handed to the reader/writer on activation.
// username - may be empty; only the first 255 bytes are used.
// password - may be empty; only the first 255 bytes are used.
void
Socks5Authenticator::startAuthentication(
	Listener& listener, Reactor& reactor, ACE_HANDLE handle,
	std::string const& username, std::string const& password)
{
	// Inside a member function this is Socks5Authenticator::abort(),
	// not the abort() from <cstdlib>.
	abort();

	std::vector<unsigned char> auth_msg(createAuthMsg(username, password));
	m_msgAuth.swap(auth_msg);

	m_observerLink.setObserver(&listener);

	// The state must be set before startWriting() is called
	// (presumably because the write may complete immediately).
	m_state = ST_SENDING_METHODS;
	m_readerWriter.activate(*this, reactor, handle);

	bool const have_credentials = !(username.empty() && password.empty());
	if (have_credentials) {
		// Offer both "no authentication" and "username/password".
		m_isAuthProvided = 1;
		m_readerWriter.startWriting(m_sMsgMethods2, sizeof(m_sMsgMethods2));
	} else {
		// Nothing to authenticate with: offer "no authentication" only.
		m_isAuthProvided = 0;
		m_readerWriter.startWriting(m_sMsgMethods1, sizeof(m_sMsgMethods1));
	}
}

// Cancels whatever is in progress and returns to the inactive state:
// the reader/writer is deactivated, the listener is detached (so it
// will not be notified), and the prepared credentials message is
// dropped.
void
Socks5Authenticator::abort()
{
	m_readerWriter.deactivate();
	m_observerLink.setObserver(0);
	{
		// Swapping with an empty vector also gives up the storage,
		// which clear() alone would not do.
		std::vector<unsigned char> released;
		m_msgAuth.swap(released);
	}
	m_state = ST_INACTIVE;
}

// Read-completion callback: the two bytes requested by onWriteDone()
// are now in m_recvBuf. The current state tells which reply they are.
void
Socks5Authenticator::onReadDone()
{
	switch (m_state) {
		case ST_RECEIVING_METHOD:
			onMethodReceived();
			break;
		case ST_RECEIVING_STATUS:
			onStatusReceived();
			break;
		default:
			// A read is only ever started in one of the two states above.
			assert(0 && "should not happen");
			break;
	}
}

void
Socks5Authenticator::onReadError()
{
	handleAuthFailure(SocksError::CONNECTION_CLOSED);
}

// Write-completion callback: a request has been fully sent, so move to
// the matching "receiving" state and read the two-byte reply.
void
Socks5Authenticator::onWriteDone()
{
	// The new state must be in place before startReading() is called.
	switch (m_state) {
		case ST_SENDING_METHODS:
			m_state = ST_RECEIVING_METHOD;
			break;
		case ST_SENDING_AUTH:
			m_state = ST_RECEIVING_STATUS;
			break;
		default:
			// A write is only ever started in one of the two states above.
			assert(0 && "should not happen");
			return;
	}

	// Both replies (method selection and auth status) are two bytes long.
	m_readerWriter.startReading(m_recvBuf, 2);
}

void
Socks5Authenticator::onWriteError()
{
	handleAuthFailure(SocksError::CONNECTION_CLOSED);
}

void
Socks5Authenticator::onGenericError()
{
	handleAuthFailure(SocksError::GENERIC_ERROR);
}

// Handles the server's two-byte method-selection reply held in
// m_recvBuf: byte 0 is the protocol version, byte 1 the chosen method.
//
// - "no authentication": authentication is complete.
// - "username/password": the prepared credentials message is sent, but
//   only if that method was actually offered.
// - "no acceptable methods": failure; the error code depends on whether
//   credentials had been supplied.
// - anything else: protocol violation.
void
Socks5Authenticator::onMethodReceived()
{
	if (m_recvBuf[0] != SOCKS_VERSION_5) {
		handleAuthFailure(SocksError::PROTOCOL_VIOLATION);
		return;
	}
	
	if (m_recvBuf[1] == SOCKS_AUTH_NONE) {
		handleAuthSuccess();
		return;
	}
	
	if (m_recvBuf[1] == SOCKS_AUTH_UNAME_PASS) {
		if (!m_isAuthProvided) {
			// Only "no authentication" was offered, so the server picked
			// a method that was not on the list. Treat it like any other
			// un-offered method rather than sending a credentials
			// message with zero-length username and password.
			handleAuthFailure(SocksError::PROTOCOL_VIOLATION);
			return;
		}
		m_state = ST_SENDING_AUTH; // must be before startWriting()
		m_readerWriter.startWriting(
			&m_msgAuth[0], m_msgAuth.size()
		);
		return;
	}
	
	if (m_recvBuf[1] == SOCKS_AUTH_NO_ACCEPTABLE_METHODS) {
		if (m_isAuthProvided) {
			// Username/password was offered and still rejected.
			handleAuthFailure(SocksError::UNSUPPORTED_AUTH_METHOD);
		} else {
			// Only "no authentication" was offered and it was rejected.
			handleAuthFailure(SocksError::AUTH_REQUIRED);
		}
		return;
	}
	
	// The server chose a method that is never offered by this class.
	handleAuthFailure(SocksError::PROTOCOL_VIOLATION);
}

// Handles the server's two-byte reply to the username/password message
// held in m_recvBuf: byte 0 is the sub-negotiation version, byte 1 the
// status.
void
Socks5Authenticator::onStatusReceived()
{
	if (m_recvBuf[0] == SOCKS_AUTH_VERSION_1) {
		if (m_recvBuf[1] == SOCKS_AUTH_SUCCESS) {
			handleAuthSuccess();
		} else {
			// Any other status value means the credentials were rejected.
			handleAuthFailure(SocksError::AUTH_FAILURE);
		}
	} else {
		handleAuthFailure(SocksError::PROTOCOL_VIOLATION);
	}
}

// Ends the exchange unsuccessfully and reports `code` to the listener,
// if one is attached.
void
Socks5Authenticator::handleAuthFailure(SocksError::Code code)
{
	// abort() detaches the listener, so it has to be fetched first.
	// By the time the callback runs, this object is already inactive.
	Listener* const observer = m_observerLink.getObserver();
	abort();
	if (!observer) {
		return;
	}
	observer->onAuthFailure(SocksError(code));
}

void
Socks5Authenticator::handleAuthSuccess()
{
	Listener* listener = m_observerLink.getObserver();
	abort(); // this will detach the listener
	if (listener) {
		listener->onAuthSuccess();
	}
}

// Builds the username/password authentication message:
//
//   [version][username length][username bytes][password length][password bytes]
//
// Each length is a single byte, so a username or password longer than
// 255 bytes is truncated to its first 255 bytes.
//
// username - user name; may be empty.
// password - password; may be empty.
// Returns the encoded message (always at least 3 bytes long).
std::vector<unsigned char>
Socks5Authenticator::createAuthMsg(
	std::string const& username, std::string const& password)
{
	size_t const max_field_size = 255;
	size_t const username_size = std::min<size_t>(username.size(), max_field_size);
	size_t const password_size = std::min<size_t>(password.size(), max_field_size);
	size_t const total_size = 3 + username_size + password_size;

	// The message is assembled with vector operations instead of
	// memcpy() through a raw pointer: no <cstring> dependency and no
	// manual pointer bookkeeping.
	std::vector<unsigned char> msg;
	msg.reserve(total_size);

	msg.push_back(SOCKS_AUTH_VERSION_1);

	msg.push_back(static_cast<unsigned char>(username_size));
	msg.insert(msg.end(), username.data(), username.data() + username_size);

	msg.push_back(static_cast<unsigned char>(password_size));
	msg.insert(msg.end(), password.data(), password.data() + password_size);

	assert(msg.size() == total_size);
	return msg;
}
