#include "tra.h"

/* 
 * The "banner" code is a hack to ensure that a connection is 8-bit clean.
 * We make a particular effort to trip up SSH, since SSH hell is a common
 * source of problems for new users: the B->A lines below are a newline
 * followed by "~?" and "~.", the usual ssh escape sequences.  The final
 * 0x00..0xFF test pattern is sent as a single fixed-length write to make
 * sure we avoid buffering problems.
 * 
 * The protocol is:
 *
 *	A->B:  banner
 *	B->A:  \n~?\n~.\n
 *	A->B:  0x00, 0x01, ..., 0xFF, \n
 *
 * We run the protocol in both directions just in case the channel 
 * is 8-bit safe in one direction but not the other.  The trasrv must
 * speak first -- it might be executed over an SSH connection that
 * can't stop printing things (like /etc/motd) during login.  We use
 * the initial banner to synchronize.
 */

/*
 * Synchronization line sent by side A at the start of each pass.
 * It ends in '\n' because side B compares it against whole lines
 * returned by Brdln.
 */
static char banner[] = "TRACMD TRA XCHG\n";	/* CACM 11(6) 419-422 */

enum { BUFSZ = 2048 };	/* longest line Brdln will return */

/*
 * Read one line from b into a static buffer and return it, or nil
 * if no bytes could be read before EOF or error.  The line keeps its
 * terminating '\n' if one was read; a line longer than BUFSZ bytes
 * comes back in BUFSZ-byte pieces on successive calls.  The returned
 * buffer is overwritten by the next call.
 *
 * The TTY might be turning \n into \r, so \r is accepted as a line
 * terminator and rewritten to \n.  Beyond that we just pretend the
 * channel works: if it really is mangling newlines, the 8-bit test
 * will fail with a good error message.
 */
static char*
Brdln(Biobuf *b)
{
	static char line[BUFSZ+1];
	int ch, n;

	for(n = 0; n < BUFSZ; ){
		ch = Bgetc(b);
		if(ch < 0)
			break;
		if(ch == '\r')
			ch = '\n';
		line[n++] = ch;
		if(ch == '\n')
			break;
	}

	if(n == 0)
		return nil;
	line[n] = '\0';
	return line;
}

/*
 * Read one line from b and compare it with expect, which should
 * include the trailing '\n'.  Returns 0 on an exact match and -1
 * on a mismatch, EOF, or read error.
 */
static int
readln(Biobuf *b, char *expect)
{
	char *line;

	line = Brdln(b);
	if(line != nil && strcmp(line, expect) == 0)
		return 0;
	return -1;
}

/*
 * Run one direction of the banner protocol over r->rfd / r->wfd.
 *
 * If isA, this side sends the banner, checks the anti-ssh lines sent
 * back, and then sends the 8-bit test pattern.  Otherwise it skips
 * lines until it sees the banner (echoing each skipped line to fd 2,
 * prefixed with name, when name is non-nil), sends the anti-ssh lines,
 * and verifies the test pattern byte by byte.
 *
 * Returns 0 on success, or -1 with errstr set on failure.
 */
int
bannerpass(char *name, Replica *r, int isA)
{
	int i;
	long n;
	char *p;
	/*
	 * unsigned so bytes 0x80-0xFF show up as 0x80-0xff in the
	 * error messages below instead of as sign-extended values
	 */
	unsigned char c, buf[257];
	Biobuf b;

	Binit(&b, r->rfd, OREAD);

	/* A->B: banner */
	if(isA){
		if(fprint(r->wfd, "%s", banner) < 0){
			werrstr("sending initial banner: %r");
			return -1;
		}
	}else{
		while((p = Brdln(&b)) != nil){
			if(strcmp(banner, p) == 0)
				break;
			if(name)
				fprint(2, "%s# %s", name, p);
		}
		if(p == nil){
			werrstr("did not receive initial banner");
			return -1;
		}
	}

	/* B->A: \n~?\n~.\n */
	if(isA){
		if(readln(&b, "\n")<0 || readln(&b, "~?\n")<0 || readln(&b, "~.\n") < 0){
			werrstr("corrupt anti-ssh banner");
			return -1;
		}
	}else if(fprint(r->wfd, "\n~?\n~.\n") < 0){
		werrstr("sending anti-ssh banner: %r");
		return -1;
	}

	/* A->B: 0x00, 0x01, ..., 0xFF, then \n */
	for(i=0; i < 256; i++)
		buf[i] = i;
	buf[256] = '\n';		/* something we know works in case 0xFF drops */

	if(isA){
		if(write(r->wfd, buf, sizeof buf) != sizeof buf){
			werrstr("short channel test write: %r");
			return -1;
		}
	}else{
		/*
		 * Unbuffered reads, so that b (a local discarded on return)
		 * never swallows bytes the peer sends after the pattern.
		 * Could be more efficient but doesn't matter.
		 */
		for(i=0; i < sizeof buf; i++){
			n = read(r->rfd, &c, 1);
			if(n < 0){
				werrstr("8-bit test: expected 0x%x, got read error: %r", buf[i]);
				return -1;
			}
			if(n == 0){
				werrstr("8-bit test: expected 0x%x, got eof", buf[i]);
				return -1;
			}
			if(c != buf[i]){
				werrstr("8-bit test: expected 0x%x, got 0x%x", buf[i], c);
				return -1;
			}
		}
	}
	return 0;
}

/*
 * Client side of the banner exchange.  The server speaks first, so
 * the client first listens (as B), then speaks (as A).  Lines received
 * before the server's banner are echoed to fd 2, prefixed with name,
 * when name is non-nil.
 * Returns 0 on success, -1 with errstr set on failure.
 */
int
clientbanner(Replica *r, char *name)
{
	if(bannerpass(name, r, 0) < 0)
		return -1;
	if(bannerpass(name, r, 1) < 0)
		return -1;
	return 0;
}

/*
 * Server side of the banner exchange.  The server speaks first (as A),
 * then listens (as B).  Lines received before the client's banner are
 * skipped silently, since no name is passed.
 * Returns 0 on success, -1 with errstr set on failure.
 */
int
serverbanner(Replica *r)
{
	int rv;

	rv = bannerpass(nil, r, 1);
	if(rv >= 0)
		rv = bannerpass(nil, r, 0);
	return rv < 0 ? -1 : 0;
}

