/* utils.c General utility functions 
   Copyright Beau Kuiper 1999.

   This program is free software; you can redistribute it and/or modify
   it under the terms of the GNU General Public License as published by
   the Free Software Foundation; either version 2, or (at your option)
   any later version.

   This program is distributed in the hope that it will be useful,
   but WITHOUT ANY WARRANTY; without even the implied warranty of
   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
   GNU General Public License for more details.

   You should have received a copy of the GNU General Public License
   along with this program; if not, write to the Free Software
   Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.  */

#include "ftpd.h"

#include <limits.h>

//#define SHOWMALLOC

/* Wrapper for malloc that logs an error and exits if malloc fails,
   before any damage is done. */ 
   
/* Running count of live allocations; only updated and printed when
   SHOWMALLOC is defined. */
int nummalloc = 0;
/* Whether printf/scanf handle "%lld": -1 until test_libc() has run,
   then TRUE or FALSE as set there. */
int doesdint = -1;
/* 			memory routines 			       *
 ***********************************************************************/
   
/* Allocate size bytes; never returns NULL.
 * On allocation failure the error is reported through ERRORMSGFATAL,
 * which terminates the program. */
void *mallocwrapper(int size)
{
	void *block = (void *)malloc(size);

#ifdef SHOWMALLOC
	/* debug accounting of outstanding allocations */
	printf("malloc: %d\n", ++nummalloc);
#endif

	if (block == NULL)
		ERRORMSGFATAL("malloc error, out of memory");

	return(block);
}

/* Wrapper for realloc that logs an error and exits if the reallocation
   fails, before any damage is done. */

/* Resize the allocation *inarea to size bytes, updating *inarea in place.
 * A NULL *inarea is treated as a fresh allocation. On failure the error is
 * logged and the process exits, so *inarea is always valid on return.
 * In DEBUG builds the block is always moved to a new address so that any
 * stale pointer into the old area shows up quickly. */
void reallocwrapper(int size, void **inarea)
{
	void *resized;

	assert(size > 0);

	resized = (*inarea != NULL) ? (void *)realloc(*inarea, size)
	                            : mallocwrapper(size);

	if (resized == NULL)
	{
		ERRORMSG("realloc error, out of memory");
		exit(1);
	}

#if DEBUG
	if (resized == *inarea)
	{
		void *moved = mallocwrapper(size);

		memmove(moved, *inarea, size);
		freewrapper(*inarea);
		resized = moved;
	}
#endif
	*inarea = resized;
}
   
/* Return a heap copy of the NUL-terminated string s (s must not be NULL).
 * Dies via ERRORMSGFATAL if memory cannot be obtained. One spare byte past
 * the terminator is allocated. */
char *strdupwrapper(char *s)
{
	char *copy;

	assert(s != NULL);

	copy = malloc(strlen(s) + 2);
	if (copy == NULL)
		ERRORMSGFATAL("strdup error, out of memory");

#ifdef SHOWMALLOC
	nummalloc++;
	printf("malloc: %d\n", nummalloc);
#endif

	return(strcpy(copy, s));
}

/* Release memory obtained from the wrappers above. Freeing NULL is
 * harmless, but it is reported as it usually points at a logic error. */
void freewrapper(void *tofree)
{
	if (!tofree)
	{
		ERRORMSG("Trying to free null pointer!");
	}

	free(tofree);

#ifdef SHOWMALLOC
	printf("free: %d\n", nummalloc);
	nummalloc--;
#endif
}

/* Free tofree when it is non-NULL; a NULL argument is silently ignored. */
void freeifnotnull(void *tofree)
{
	if (tofree == NULL)
		return;

	free(tofree);

#ifdef SHOWMALLOC
	printf("free: %d\n", nummalloc);
	nummalloc--;
#endif
}

#ifndef HAVE_MEMMOVE

/* Fallback memmove for platforms whose libc lacks one.
 * Copies n bytes from src to dest; the two regions may overlap.
 * Returns dest. */
void *memmove(void *dest, void *src, int n)
{
	unsigned char *to = dest;
	const unsigned char *from = src;
	int p;

	assert(dest != NULL);
	assert(src != NULL);
	assert(n >= 0);

	if ((to == from) || (n <= 0))
		return(dest);

	/* If dest starts above src, a forward copy could overwrite source bytes
	   before they are read, so copy from the end instead. (Ordering
	   unrelated pointers is not strictly portable C, but every memmove
	   implementation relies on it.) */
	if (to > from)
		for (p = n - 1; p >= 0; p--)
			to[p] = from[p];
	else
		for (p = 0; p < n; p++)
			to[p] = from[p];

	return(dest);
}

#endif

/* 			string routines 			       *
 ***********************************************************************/

/* Remove, in place, every space and control character (byte value <= 32)
 * from string, wherever it occurs, not just at the ends.
 * Bytes are compared as unsigned char so that non-ASCII bytes (>= 0x80,
 * e.g. UTF-8 sequences) are kept whether or not plain char is signed. */
void strtrimspace(char *string)
{
	int pos, pos2 = 0;

	for (pos = 0; string[pos] != 0; pos++)
		if ((unsigned char)string[pos] > 32)
			string[pos2++] = string[pos];

	string[pos2] = 0;
}

/* Count how many times the character tok occurs in string.
 * The terminating NUL is not counted, so tok == '\0' yields 0. (Passing
 * '\0' to strchr would match the terminator and step past the string's
 * end.) */
int strchrcount(char *string, char tok)
{
	char *lastoccur = string;
	int ret = 0;

	if (tok == 0)
		return(0);

	while ((lastoccur = strchr(lastoccur, tok)) != NULL)
	{
		lastoccur++;
		ret++;
	}
	return(ret);
}

/* Format into a newly allocated buffer of size+1 bytes, truncating the
 * output if it does not fit. The caller owns (and must free) the result.
 * If vsnprintf reports an error the buffer contents would be unspecified,
 * so an empty string is returned instead. */
char *safe_vsnprintf(int size, char *format, va_list ap)
{
	char *buffer = mallocwrapper(size + 1);

	if (vsnprintf(buffer, size + 1, format, ap) < 0)
		buffer[0] = 0;

	return(buffer);
}

/* printf-style formatting into a newly allocated string of whatever size
 * is needed; the caller must free the result. A first attempt uses a
 * BUFFSMALL buffer; if vsnprintf reports the output was longer, the buffer
 * is grown to the exact length and the formatting is repeated. */
char *safe_snprintf(char *format, ...)
{
	va_list args;
	int capacity = BUFFSMALL;
	char *out = mallocwrapper(capacity);
	int needed;

	va_start(args, format);
	needed = vsnprintf(out, capacity, format, args);
	va_end(args);

	if (needed < capacity)
		return(out);

	reallocwrapper(needed + 1, (void *)&out);
	va_start(args, format);
	vsnprintf(out, needed + 1, format, args);
	va_end(args);

	return(out);
}

/* Return the current working directory in a newly allocated buffer (the
 * caller frees it), or NULL if getcwd fails for any reason other than the
 * buffer being too small. The buffer doubles while small and then grows in
 * 64k steps. */
char *getcwd2(void)
{
	int size = BUFFSMALL;
	char *buffer = mallocwrapper(size);

	while (getcwd(buffer, size - 1) == NULL)
	{
		if (errno != ERANGE)
		{
			freewrapper(buffer);
			return(NULL);
		}

		size += (size > (64 * 1024)) ? (64 * 1024) : size;
		reallocwrapper(size, (void *)&buffer);
	}

	return(buffer);
}

/* Canonicalise a path in place:
 *   - repeated slashes are collapsed;
 *   - "." components are dropped;
 *   - ".." removes the preceding component, and ".." at the root stays at
 *     the root.
 * An empty result becomes "/". The string only ever shrinks, so it is
 * rewritten in place.
 *
 * NOTE(review): intended for absolute paths; ".." applied to the first
 * component of a relative path (e.g. "a/../b") is not resolved correctly.
 */
void pathname_simplify(char *pathname)
{
	int pos = 0, pos2 = 0;	/* read and write cursors; pos2 <= pos */

	while (pathname[pos] != 0)
	{
		if (pathname[pos] != '/')
		{
			pathname[pos2++] = pathname[pos];
		}
		else if (pathname[pos + 1] == '/')
		{
			/* "//": drop this slash; the next one is handled normally */
		}
		else if ((strncmp(pathname + pos, "/./", 3) == 0) ||
			 (strcmp(pathname + pos, "/.") == 0))
		{
			pos++;		/* skip the '.' too */
		}
		else if ((strncmp(pathname + pos, "/../", 4) == 0) ||
			 (strcmp(pathname + pos, "/..") == 0))
		{
			/* back up to the '/' that introduced the last component
			   written, so the next '/' overwrites it */
			if (pos2 > 1)
				pos2--;
			while ((pos2 > 1) && (pathname[pos2] != '/'))
				pos2--;
			/* that component was directly under the root: back up onto
			   the root slash as well, otherwise the next '/' would be
			   appended after it and "/a/../b" would come out as "//b" */
			if ((pos2 == 1) && (pathname[0] == '/'))
				pos2 = 0;
			pos += 2;	/* skip the ".." */
		}
		else
		{
			pathname[pos2++] = '/';
		}
		pos++;
	}

	if (pos2 == 0)
		pathname[pos2++] = '/';

	pathname[pos2] = 0;
}

/* This is meant to be quick */

/* Convert an off_t to its decimal representation.
 * Returns a pointer into a static buffer. The result is overwritten by the
 * next call, and the function is not reentrant.
 * Negative values are printed with a leading '-'. Each digit's magnitude
 * is taken separately, so even the most negative value converts correctly
 * without negating it. */
char *offt_tostr(off_t size)
{
	/* 4 chars per byte comfortably holds every digit, a sign and the NUL */
	static char sizestr[sizeof(off_t) * 4];
	char *l = sizestr + sizeof(sizestr) - 1;
	off_t esize = size;

	*l = 0;

	do
	{
		/* C99 '%' takes the sign of the dividend: fold it to a digit */
		int digit = (int)(esize % 10);

		if (digit < 0)
			digit = -digit;
		*--l = (char)('0' + digit);
		esize /= 10;
	}
	while (esize != 0);

	if (size < 0)
		*--l = '-';

	return(l);
}

/* Parse a decimal integer, optionally preceded by '-', into *ret.
 * The whole string must be consumed and must contain at least one digit.
 * Returns 0 on success. Returns -1 on malformed input (including "" and
 * "-") or when the magnitude does not fit in off_t. *ret is written only
 * on success. */
int strto_offt(char *str, off_t *ret)
{
	/* largest off_t, 2^(bits-1) - 1, computed without signed overflow
	   (assumes off_t is a signed two's-complement type) */
	const off_t maxval =
		(off_t)((((off_t)1 << (sizeof(off_t) * CHAR_BIT - 2)) - 1) * 2 + 1);
	off_t r = 0;
	char *l = str;
	int neg = (*l == '-');
	int digits = 0;

	if (neg)
		l++;

	while ((*l >= '0') && (*l <= '9'))
	{
		int d = *l - '0';

		if (r > (maxval - d) / 10)
			return(-1);	/* would overflow off_t */
		r = r * 10 + d;
		l++;
		digits++;
	}

	/* error if there were no digits or anything follows them */
	if ((digits == 0) || (*l != 0))
		return(-1);

	*ret = (neg ? -r : r);
	return(0);
}
					
/* Sanity-check the C library at startup.
 * 1. snprintf must honour its size limit; the program aborts if it does
 *    not.
 * 2. "%lld" must round-trip through sprintf/sscanf. The result is stored
 *    in doesdint. Lack of support is fatal only when off_t is as wide as
 *    long long, since file lengths could not be printed otherwise.
 * With verbose set, progress is printed to stdout. */
void test_libc(int verbose)
{
	char probe[15];
	long long big;

	if (verbose)
		printf("Testing libraries!!\n");

	if (verbose)
		printf("Now, does snprintf work  . . . . . . . . ");

	/* Zero the buffer, then ask for at most 10 bytes of a longer string;
	   a non-zero byte at index 10 means the limit was ignored. */
	memset(probe, 0, sizeof(probe));
	snprintf(probe, 10, "test %s", "testing");
	if (probe[10] != 0)
	{
		if (verbose)
			printf("No\n");
		ERRORMSGFATAL("snprintf is broken! Cannot continue!");
	}
	if (verbose)
		printf("Yes\n");

	if (verbose)
		printf("Now, can I use double ints . . . . . . . ");

	big = 1000000000000LL;
	sprintf(probe, "%lld", big);
	sscanf(probe, "%lld", &big);

	if ((strcmp(probe, "1000000000000") != 0) || (big != 1000000000000LL))
	{
		if (verbose)
			printf("No, ratio support disabled\n");
		doesdint = FALSE;
	}
	else
	{
		if (verbose)
			printf("Yes\n");
		doesdint = TRUE;
	}

	if ((!doesdint) && (sizeof(long long) == sizeof(off_t)))
		ERRORMSGFATAL("snprintf doesn't support double int but needed to do file lengths");
}
                                                
#ifndef HAVE_USLEEP

/* Fallback usleep for systems without one: sleep for usecs microseconds
 * using select() with no descriptors. Returns select()'s result (0 when the
 * timeout expires). */
int usleep(int usecs)
{
	struct timeval delay;

	delay.tv_usec = usecs % 1000000;
	delay.tv_sec = usecs / 1000000;

	return(select(0, NULL, NULL, NULL, &delay));
}
 
#endif

/*			hash functions				       *
 ***********************************************************************/
 


/* 			string cache routines 			       *
 ***********************************************************************/

/* Allocate an empty string cache (all fields zeroed, size 0). */
STRCACHE *strcache_new(void)
{
	STRCACHE *fresh = mallocwrapper(sizeof(*fresh));

	memset(fresh, 0, sizeof(*fresh));
	return(fresh);
}

/* Look up the string cached under num; returns it (still owned by the
 * cache) or NULL if no entry matches. Linear scan of the filled slots. */
char *strcache_check(STRCACHE *cache, int num)
{
	int idx;

	for (idx = 0; idx < cache->size; idx++)
		if (cache->data[idx].num == num)
			return(cache->data[idx].str);

	return(NULL);
}

/* Store a private copy of str under num. When all STRCACHESIZE slots are
 * in use the entry is silently not added. */
void strcache_add(STRCACHE *cache, int num, char *str)
{
	int slot = cache->size;

	if (slot >= STRCACHESIZE)
		return;

	cache->data[slot].num = num;
	cache->data[slot].str = strdupwrapper(str);
	cache->size = slot + 1;
}

/* Free every cached string and then the cache itself. */
void strcache_free(STRCACHE *cache)
{
	int idx;

	for (idx = 0; idx < cache->size; idx++)
		freewrapper(cache->data[idx].str);

	freewrapper(cache);
}


/* 			token set routines 			       *
 ***********************************************************************/

/* Allocate a token (cookie) set with every field zeroed, i.e. no cookie
 * values set. */
TOKENSET *tokenset_new(void)
{
	TOKENSET *fresh = mallocwrapper(sizeof(*fresh));

	memset(fresh, 0, sizeof(*fresh));
	return(fresh);
}

/* Set the value of cookie tok (must be < 128) to data. The set takes
 * ownership of data and frees it later. Any previous value is freed,
 * except when the caller re-stores the very same pointer: freeing it then
 * would leave the slot pointing at released memory. */
void tokenset_settoken(TOKENSET *tset, unsigned char tok, char *data)
{
	int token = (int)tok;

	assert(token < 128);
	assert(tset != NULL);

	if ((tset->params[token] != NULL) && (tset->params[token] != data))
		freewrapper(tset->params[token]);

	tset->params[token] = data;
}

/* Remove (and free) the value of cookie tok, which must be < 128. */
void tokenset_deltoken(TOKENSET *tset, unsigned char tok)
{
	assert((int)tok < 128);
	assert(tset != NULL);

	if (tset->params[tok] != NULL)
		freewrapper(tset->params[tok]);

	tset->params[tok] = NULL;
}

/* Expand cookies in inputstr using the values stored in tok.
 *
 * Recognised sequences:
 *   %%        a literal '%'
 *   %X        the whole value of cookie X
 *   %(a,b)X   characters a..b (0-based, inclusive) of cookie X, clipped to
 *             the cookie's length
 *   %(a,*)X   characters from a to the end of cookie X
 * Cookie letters must be 7-bit characters whose value is set in tok.
 * If escapecookies is non-zero, every substituted character is preceded by
 * a backslash.
 *
 * Ownership: on success inputstr is freed and a newly allocated string is
 * returned. On a parse failure inputstr itself is returned unchanged and
 * unfreed. NULL is returned only for a NULL inputstr.
 */
char *tokenset_apply(TOKENSET *tok, char *inputstr, int escapecookies)
{
	int instrlen;
	char *outstring;
	int inpos = 0;
	int outpos = 0;
	int outstrlen;

	if (inputstr == NULL)
		return(NULL);

	/* literal text never grows, so the input length (+NUL) is enough until
	   a cookie is substituted */
	instrlen = strlen(inputstr);
	outstrlen = instrlen + 1;
	outstring = mallocwrapper(outstrlen);

	while (inpos < instrlen)
	{
		char *src;	/* first cookie character to copy */
		int strl;	/* number of cookie characters to copy */
		int pc, a;

		if (inputstr[inpos] != '%')
		{
			outstring[outpos++] = inputstr[inpos++];
			continue;
		}

		if (inputstr[inpos + 1] == '%')
		{
			outstring[outpos++] = '%';
			inpos += 2;
			continue;
		}

		if (inputstr[inpos + 1] == '(')
		{
			int b = 0, cookielen, consumed = -1;
			int allcookie = FALSE;

			/* sscanf's return value only counts conversions, so whether
			   the trailing ",*)" literal matched is confirmed with %n */
			if ((sscanf(inputstr + inpos, "%%(%d,*)%n", &a, &consumed) == 1)
			    && (consumed > 0))
				allcookie = TRUE;
			else if (sscanf(inputstr + inpos, "%%(%d,%d)", &a, &b) < 2)
				goto fail;

			/* advance to just past the closing ')' */
			while (inputstr[inpos++] != ')')
				if (inputstr[inpos] == 0)
					goto fail;

			/* unsigned, so a byte >= 0x80 cannot index params[]
			   with a negative value */
			pc = (unsigned char)inputstr[inpos];
			if ((pc == 0) || (pc >= 128) || (tok->params[pc] == NULL))
				goto fail;

			cookielen = strlen(tok->params[pc]);
			if (allcookie)
				b = cookielen;
			if (a < 0)
				a = 0;
			if (a > b)
				goto fail;

			/* clip a..b to the cookie without computing b - a + 1 when it
			   could overflow, and never start past the cookie's end */
			if (a >= cookielen)
				strl = 0;
			else
			{
				strl = cookielen - a;
				if (b - a < strl)
					strl = b - a + 1;
			}
			src = tok->params[pc] + ((strl > 0) ? a : 0);
			inpos++;
		}
		else
		{
			pc = (unsigned char)inputstr[++inpos];
			if ((pc == 0) || (pc >= 128) || (tok->params[pc] == NULL))
				goto fail;
			src = tok->params[pc];
			strl = strlen(src);
			inpos++;
		}

		/* room for the substituted text; doubled when every character
		   gets a backslash in front of it */
		outstrlen += escapecookies ? (2 * strl) : strl;
		reallocwrapper(outstrlen + 1, (void *)&outstring);

		if (escapecookies)
		{
			for (a = 0; a < strl; a++)
			{
				outstring[outpos++] = '\\';
				outstring[outpos++] = src[a];
			}
		}
		else
		{
			memcpy(outstring + outpos, src, strl);
			outpos += strl;
		}
	}

	freewrapper(inputstr);
	outstring[outpos] = 0;
	return(outstring);

fail:
	freewrapper(outstring);

	/* Parse failures may hide security problems, but logging them was
	   found too noisy; the unexpanded input is handed back instead. */
	return(inputstr);
}

/* Free every cookie value held by tset (all 128 slots) and then tset. */
void tokenset_finish(TOKENSET *tset)
{
	int slot = 128;

	while (slot-- > 0)
		freeifnotnull(tset->params[slot]);

	freewrapper(tset);
}

/* 			uid and gid routines 			       *
 ***********************************************************************/

/* these look up a username given a uid, or a groupname given a gid */

/* these hold open FILE handles on the passwd and group files */
/* FIXME: the file-based lookups do not work with NIS */

/* Handles opened by init_pwgrfiles() and closed by kill_uidgidfiles();
   NULL when a file is not open. */
FILE *uidfile = NULL;
FILE *gidfile = NULL;

/* Last strings returned by get_passwdname()/get_groupname(). Each call
   frees the previous result, so callers must copy a result they need to
   keep beyond the next call. */
char *uid_retstr = NULL;
char *gid_retstr = NULL;

#define GROUPFILE "/etc/group"
#define USERFILE "/etc/passwd"

/* (Re)open the passwd and group files for the file-based lookups below,
 * closing any handles left from a previous call. A file that cannot be
 * opened leaves its handle NULL. */
void init_pwgrfiles(void)
{
	if (uidfile != NULL)
		fclose(uidfile);
	if (gidfile != NULL)
		fclose(gidfile);

	uidfile = fopen(USERFILE, "r");
	gidfile = fopen(GROUPFILE, "r"); 
}

/* Return the user name for inuid, or NULL if it is unknown.
 * With usefile set, the already-open passwd file is scanned; this also
 * returns NULL if that file is not open. Otherwise getpwuid() is used.
 * The result is stored in uid_retstr and freed on the next call. */
char *get_passwdname(uid_t inuid, int usefile)
{
	struct passwd *entry;

	freeifnotnull(uid_retstr);
	uid_retstr = NULL;

	if (usefile)
	{
		if (!uidfile)
			return(NULL);
		rewind(uidfile);
		while ((entry = fgetpwent(uidfile)) != NULL)
			if (entry->pw_uid == inuid)
				break;
	}
	else
		entry = getpwuid(inuid);

	if (entry != NULL)
		uid_retstr = strdupwrapper(entry->pw_name);

	return(uid_retstr);
}

/* Return the group name for ingid, or NULL if it is unknown.
 * With usefile set, the already-open group file is scanned; this also
 * returns NULL if that file is not open. Otherwise getgrgid() is used.
 * The result is stored in gid_retstr and freed on the next call. */
char *get_groupname(gid_t ingid, int usefile)
{
	struct group *entry;

	freeifnotnull(gid_retstr);
	gid_retstr = NULL;

	if (usefile)
	{
		if (!gidfile)
			return(NULL);
		rewind(gidfile);
		while ((entry = fgetgrent(gidfile)) != NULL)
			if (entry->gr_gid == ingid)
				break;
	}
	else
		entry = getgrgid(ingid);

	if (entry != NULL)
		gid_retstr = strdupwrapper(entry->gr_name);

	return(gid_retstr);
}

/* this must not be called in chroot jail. I shall use plain old getgrent
   so nis works */

/* Build the list of supplementary groups that list username as a member.
 * The list format is the one used by the gidlist helpers below: element 0
 * holds the count, and the gids follow from index 1. Each gid appears at
 * most once.
 * Uses getgrent() so NIS groups are included. Returns an empty list when
 * the group file could not be opened. The group enumeration is closed
 * again before returning. */
gid_t *getusergrouplist(char *username)
{
	gid_t *retlist = mallocwrapper(sizeof(gid_t));
	struct group *gr;
	int count = 0;

	/* make sure info isn't stale */
	endgrent();
	setgrent();

	retlist[0] = 0;
	if (!gidfile)
	{
		endgrent();
		return(retlist);
	}
	rewind(gidfile);

	while ((gr = getgrent()) != NULL)
	{
		int member, seen;

		for (member = 0; gr->gr_mem[member] != NULL; member++)
			if (strcmp(gr->gr_mem[member], username) == 0)
				break;
		if (gr->gr_mem[member] == NULL)
			continue;	/* user is not a member of this group */

		/* skip gids already collected (group entries can repeat) */
		for (seen = 1; seen <= count; seen++)
			if (retlist[seen] == gr->gr_gid)
				break;
		if (seen <= count)
			continue;

		count++;
		reallocwrapper(sizeof(gid_t) * (count + 1), (void *)&retlist);
		retlist[count] = gr->gr_gid;
		retlist[0] = count;
	}

	endgrent();
	return(retlist);
}

/* Create an empty gid list. Element 0 holds the number of entries, and
 * the entries follow from index 1. */
gid_t *newgidlist(void)
{
	gid_t *fresh = mallocwrapper(sizeof(gid_t));

	*fresh = 0;
	return(fresh);
}

/* Append new to the gid list. The list may move, so callers must use the
 * returned pointer. No duplicate check is made. */
gid_t *addgidlist(gid_t *list, gid_t new)
{
	gid_t slot = ++list[0];

	reallocwrapper(sizeof(gid_t) * (slot + 1), (void *)&list);
	list[slot] = new;
	return(list);
}

/* Remove every occurrence of old from the gid list in place. The order of
 * the remaining entries is kept and the count in element 0 is decremented
 * once per removal. */
void delgidlist(gid_t *list, gid_t old)
{
	int total = list[0];
	int src, dst = 1;

	for (src = 1; src <= total; src++)
	{
		if (list[src] == old)
		{
			list[0]--;
			continue;
		}
		list[dst++] = list[src];
	}
}

/* Parse a comma separated gid specification such as "10,20,!15".
 * Whitespace is removed from str first, which modifies str.
 * A '!' prefix removes that gid from the list built so far; otherwise the
 * gid is appended. "*" yields an empty list. Any unparsable entry also
 * yields an empty list (count 0). The returned list uses the gidlist
 * format (element 0 = count). */
gid_t *parsegidlist(char *str)
{
	gid_t *list = newgidlist();
	char *pos = str;
	int newval;
	int do_remove;

	strtrimspace(str);

	if (strcmp(str, "*") == 0)
		return(list);

	while (*pos != 0)
	{
		do_remove = (*pos == '!');
		if (do_remove)
			pos++;

		if (sscanf(pos, "%d", &newval) != 1)
		{
			list[0] = 0;
			return(list);
		}

		if (do_remove)
			delgidlist(list, (gid_t)newval);
		else
			list = addgidlist(list, (gid_t)newval);

		/* stop at the last entry instead of advancing a NULL pointer */
		pos = strchr(pos, ',');
		if (pos == NULL)
			break;
		pos++;
	}

	return(list);
}
	
/* Render a gid list as a comma separated string, e.g. "10,20". The
 * caller frees the result. An empty list gives "".
 * Sizing assumes each gid fits in 32 bits: up to 10 digits plus a comma,
 * so 12 bytes per entry. */
char *makegidliststr(gid_t *list)
{
	gid_t idx;
	char *ret = mallocwrapper(12 * list[0] + 1);
	char *pos = ret;

	*ret = 0;
	for (idx = 1; idx <= list[0]; idx++)
		pos += sprintf(pos, "%u,", (unsigned int)list[idx]);

	/* drop the trailing comma; with no entries there is nothing before
	   ret to inspect */
	if (pos > ret)
		*(pos - 1) = 0;

	return(ret);
}

/* Close the passwd/group handles and free the cached lookup results.
 * All four globals are reset to NULL so that later calls
 * (init_pwgrfiles, get_passwdname, get_groupname, or this function again)
 * do not fclose or free them a second time. */
void kill_uidgidfiles(void)
{
	if (uidfile)
		fclose(uidfile);
	if (gidfile)
		fclose(gidfile);
	uidfile = NULL;
	gidfile = NULL;

	freeifnotnull(gid_retstr);
	freeifnotnull(uid_retstr);
	gid_retstr = NULL;
	uid_retstr = NULL;
}

/* Return TRUE if fd refers to a file owned by the effective uid and gid
 * that is not group- or world-writable. Returns FALSE otherwise, including
 * when fstat fails (the stat buffer would then be uninitialised). */
int isfilesafe(int fd)
{
	struct stat buf;

	if (fstat(fd, &buf) != 0)
		return(FALSE);

	if ((buf.st_uid == geteuid()) && (buf.st_gid == getegid()) &&
	        (!(buf.st_mode & 0022)))
		return(TRUE);

	return(FALSE);
}

/*			limiter functions				*
 *		This is for the bandwidth limiter			*
 ************************************************************************/
 
/* Create a bandwidth limiter for maxspeed bytes per second. The byte
 * counter starts at zero and the reference time is now. */
LIMITER *limiter_new(int maxspeed)
{
	struct timezone tz;
	LIMITER *lim = mallocwrapper(sizeof(LIMITER));

	lim->bytes_transfered = 0;
	lim->maxspeed = maxspeed;
	gettimeofday(&(lim->current_time), &tz); 

	return(lim);
}

/* Account for byte_count transferred bytes and sleep as needed to keep
 * the average rate at or below l->maxspeed bytes per second.
 * The situation is assessed once at least one second's worth of data has
 * accumulated, or immediately when force is set. A non-positive maxspeed
 * is treated as "no limit" (it would otherwise divide by zero). */
void limiter_add(LIMITER *l, int byte_count, int force)
{
	struct timeval tv;
	struct timezone tz;
	long long elapsed_ms, expected_ms, wait_us;

	l->bytes_transfered += byte_count;

	if (l->maxspeed <= 0)
	{
		l->bytes_transfered = 0;
		return;
	}

	if ((l->bytes_transfered >= l->maxspeed) || force)
	{
		gettimeofday(&tv, &tz); 

		/* 64-bit arithmetic: 1000 * bytes overflows int above ~2MB */
		elapsed_ms = (long long)(tv.tv_sec - l->current_time.tv_sec) * 1000
			+ (tv.tv_usec - l->current_time.tv_usec) / 1000;
		expected_ms = (1000LL * l->bytes_transfered) / l->maxspeed;
		wait_us = (expected_ms - elapsed_ms) * 1000;

		/* If usleep oversleeps, this compensates: the expected time after
		   the sleep (not the real time) becomes the new reference, which
		   lets the transfer catch up. */
		l->current_time = tv;
		l->current_time.tv_sec += wait_us / 1000000;
		l->current_time.tv_usec += wait_us % 1000000;
		if (l->current_time.tv_usec >= 1000000)
		{
			l->current_time.tv_sec++;
			l->current_time.tv_usec -= 1000000;
		}
		else if (l->current_time.tv_usec < 0)
		{
			l->current_time.tv_sec--;
			l->current_time.tv_usec += 1000000;
		}

		/* POSIX permits usleep to reject values >= 1000000, so sleep in
		   sub-second chunks */
		while (wait_us > 0)
		{
			long long chunk = (wait_us > 999999) ? 999999 : wait_us;

			usleep((int)chunk);
			wait_us -= chunk;
		}
		l->bytes_transfered = 0; 
	}
}

/*	giving up root code. this will use whatever capibilites         *
 *	possible to give up root access but still be able to bind to    *
 *	a low port (if possible)					*
 ************************************************************************/
 
int giveuproot(uid_t uid, gid_t gid)
{
	int error = FALSE;
#ifdef HAVE_CAP_INIT
	cap_t currentset;

#ifdef IRIX
	cap_value_t flags[] = { CAP_PRIV_PORT };
#else
	cap_value_t flags[] = { CAP_NET_BIND_SERVICE };
#endif

#endif
	setregid(gid, gid);
	setreuid(uid, 0);

#ifdef HAVE_CAP_INIT
	/* permit us to set the CAP_NET_BIND_SERVICE capibility
	   after setuid */
	currentset = cap_init();
	cap_clear(currentset);
	cap_set_flag(currentset, CAP_PERMITTED, 1, flags, CAP_SET);
 	if (cap_set_proc(currentset) == -1)
		error = TRUE;
#endif
	/* now finish the switch */
	setuid(uid);

#ifdef HAVE_CAP_INIT
	/* set the CAP_NET_BIND_SERVICE (bind port < 1024) to the 
	   effective set */
	cap_set_flag(currentset, CAP_EFFECTIVE, 1, flags, CAP_SET);
	if (cap_set_proc(currentset) == -1)
		error = TRUE;
#endif
	return(error);
}

/*			blocking signals				*
 ************************************************************************/
 
/* Block delivery of every signal to the calling process. */
void blockallsignals()
{
	sigset_t everything;

	sigfillset(&everything);
	sigprocmask(SIG_BLOCK, &everything, NULL);
}

/* Unblock every signal for the calling process. */
void unblockallsignals()
{
	sigset_t everything;

	sigfillset(&everything);
	sigprocmask(SIG_UNBLOCK, &everything, NULL);
}

/* 			error routines					*
 ************************************************************************/

/* Sends a message to the screen and then returns to the program */

/* Report an error along with its source location, then return.
 * In inetd mode the message goes to syslog; otherwise it goes to stderr. */
void errormsg( char *errmessage, char *file, int line )
{
	if (!inetd)
	{
		fprintf(stderr, PROGNAME" error in file %s line %d: %s\n", file, line, errmessage);
		return;
	}

	syslog(LOG_ERR, PROGNAME" error in file %s line %d: %s", file, line, errmessage);
}

/* Sends a message to the screen and exits gracefully */

/* Report an error (see errormsg), announce the shutdown on the same
 * channel, and exit with status 1. Does not return. */
void errormsgfatal( char *errmessage, char *file, int line )
{
	errormsg(errmessage, file, line);

	if (!inetd)
		fprintf( stderr, "CANNOT RESUME. Goodbye\n");
	else
		syslog(LOG_ERR, "muddleftpd is exiting\n");

	exit(1);
}

/* 			close file descriptors 				*
 ************************************************************************/

// try to close all non-terminal file descriptors.

/* Close every file descriptor above stderr (3 and up).
 * The upper bound comes from the RLIMIT_NOFILE hard limit. When that limit
 * is unlimited, the soft limit is used instead. If neither is finite and
 * within int range, or getrlimit fails, a default bound of 1024 is used.
 * (An unlimited value cast to int could otherwise become -1 and close
 * nothing.) */
void fd_closeall_nonterminal(void)
{
	int count, maxfilefd = 1024;
#ifdef RLIMIT_NOFILE
	struct rlimit lim;

	if (getrlimit(RLIMIT_NOFILE, &lim) == 0)
	{
		rlim_t limit = lim.rlim_max;

		if (limit == RLIM_INFINITY)
			limit = lim.rlim_cur;
		if ((limit != RLIM_INFINITY) && (limit <= (rlim_t)INT_MAX))
			maxfilefd = (int)limit;
	}
#endif
	for (count = 3; count < maxfilefd; count++)
		close(count);
}


/* syntax highlighted by Code2HTML, v. 0.9.1 */