/* TitanFirewall.cpp
 *
 * This must be linked with -L/usr/lib/mysql -lmysqlclient, or the linker won't
 * find the MySQL client library.
 *
 * It must also be linked against pthreads (-lpthread, or -pthread with GCC/Clang),
 * or the pthread_* calls used to spawn rule-making threads won't link.
 */

#include <iostream>
#include <cstdio>
#include <stdio.h>
// std::vector reference: https://en.cppreference.com/w/cpp/container/vector
#include <vector>
#include "Connection.cpp"

#include "Attributes/Attribute.cpp"
#include "Attributes/Attr_DestinationIP.cpp"
#include "Attributes/Attr_SourceIP.cpp"
#include "Attributes/Attr_DestinationPort.cpp"
#include "Attributes/Attr_SourcePort.cpp"
#include "Attributes/Attr_LengthOfConnection.cpp"
#include "Attributes/Attr_PacketLength.cpp"
#include "Attributes/Attr_PacketsPerSecond.cpp"
#include "Attributes/Attr_Protocol.cpp"

#include "Classifier.cpp"
#include "Db_Query/Db_Query.cpp"
#include <string.h>
#include <pthread.h>
#include <stdlib.h>
//#include <unistd.h>

// Connections whose score (after classification) is below THRESHOLD get a
// rule-making thread spawned for them in main()'s polling loop.
#define THRESHOLD .5
// When no rule covers a connection yet, rule_thread() counts the rules already
// inside the source's /24; at or above this count it blocks the whole /24
// instead of just the single source IP.
#define SUBNET_THRESHOLD 5

// Per-attribute weights passed to the Attr_* constructors in main(). Their sum
// is also passed to the Classifier constructor. SOURCE_PORT_WEIGHT is 0, and
// the source-port attribute is not pushed onto the attribute vector.
#define SOURCE_IP_WEIGHT .75
#define SOURCE_PORT_WEIGHT 0
#define DESTINATION_IP_WEIGHT .5
#define DESTINATION_PORT_WEIGHT 1
#define LENGTH_OF_CONNECTION_WEIGHT .25
#define PACKET_LENGTH_WEIGHT .25
#define PACKETS_PER_SECOND_WEIGHT .25
#define PROTOCOL_WEIGHT 1


//#ifndef
//#define

//#define ANOMALY_DISTANCE 500
//#define ANOMALY_SCORE .8

using namespace std;

void *rule_thread( void * );

int main()
{
	// And it was said "Let there be an attribute vector!"
	vector<Attribute*> attributes;
	
	// We are creating our attributes
	Attr_DestinationIP a1(DESTINATION_IP_WEIGHT);
	Attr_SourceIP a2(SOURCE_IP_WEIGHT);
	Attr_DestinationPort a3(DESTINATION_PORT_WEIGHT);
	Attr_SourcePort a4(SOURCE_PORT_WEIGHT);
	Attr_LengthOfConnection a5(LENGTH_OF_CONNECTION_WEIGHT);
	Attr_PacketLength a6(PACKET_LENGTH_WEIGHT);
	Attr_PacketsPerSecond a7(PACKETS_PER_SECOND_WEIGHT);
	Attr_Protocol a8(PROTOCOL_WEIGHT);
	
	// We are setting up our attribute vector
	attributes.push_back(&a1);
	attributes.push_back(&a2);
	attributes.push_back(&a3);
	//attributes.push_back(&a4);
	attributes.push_back(&a5);
	attributes.push_back(&a6);
	attributes.push_back(&a7);
	attributes.push_back(&a8);

	// Our vector of connections, called... "connections", believe it or not.
	vector<Connection> connections;
	
	// Database handle called dbh.  Genius.
	Db_Query *dbh = new Db_Query;
	
	int rows; // To be used for parsing the database query
	int fields; // To be used for parsing the database query
	char **row; // Stores the results of the DB query
	char query[300]; // Size of the query string; if it's bigger than this, we're in trouble

	sprintf(query,"SELECT ip_saddr, ip_daddr, sport, dport, score, open, protocol, packet_length, UNIX_TIMESTAMP(time_established), UNIX_TIMESTAMP(time_last_packet), num_packets FROM connections WHERE open='NO' AND currently_scored='YES'");
	sql_data_t query_result;

	// Run the query and check for errors all at the same time!
	// Grab all open and currently scored connections from the database.
	checkError(query_result = dbh->query(query, strlen(query), &rows, &fields), dbh);


	/* ******* CREATE OUR TITAN CLASSIFIER HERE!!!  SWEET!!!! ******* */
	// In order to do clustering correctly on the fly, we read in the connections
	// data and cluster it at startup; this allows us to keep a record of every
	// connection in the database, not just the clusters.  It also allows us to
	// keep just the clusters in the vector in memory giving us much faster searches.
	Classifier titanClassifier(attributes, connections, DESTINATION_IP_WEIGHT + DESTINATION_PORT_WEIGHT + SOURCE_IP_WEIGHT + SOURCE_PORT_WEIGHT + LENGTH_OF_CONNECTION_WEIGHT + PACKET_LENGTH_WEIGHT + PACKETS_PER_SECOND_WEIGHT + PROTOCOL_WEIGHT);
	
	// Generate connection list by looking at results from database query
	for(int rowcnt = 0; rowcnt < rows; rowcnt++)
	{
		row = query_result + rowcnt * fields;
		
		Connection con( (unsigned int) strtoul(row[0], (char**)NULL, 10), // Source IP
				(unsigned int) strtoul(row[1], (char**)NULL, 10), // Destination IP
				(int) strtol(row[2], (char**)NULL, 10), // Source Port
				(int) strtol(row[3], (char**)NULL, 10), // Destination Port
				strtod(row[4], (char**)NULL), // Score
				strcmp(row[5], "YES") == 0, // Open?
				(int) strtol(row[6], (char**)NULL, 10), // Protocol
				strtod(row[9], (char**)NULL) - strtod(row[8], (char**)NULL), // Connection Length
				strtod(row[7], (char**)NULL), // Packet Length
				strtod(row[10], (char**)NULL) // Number of Packets
				);
		// Give our newly created classifier a new connection to classify, to determine
		// if it should be added to the connection list.
		titanClassifier.classify(&con,true,false);
		if(rowcnt%100==0)
		{
			cout<<"We are at row " << rowcnt << endl;
			cout<<"We have "<<titanClassifier.getConnectionsSize()<<" connections in memory."<<endl;
		}
	}
	// Clean up database query memory.
	dbh->release(query_result);
	
	cout<<"The number of connections in the vector in memory is "<<titanClassifier.getConnectionsSize()<<endl;

	while(true)  // Forever and ever, keep looking for new connections to classify
	{
		sprintf(query, "SELECT ip_saddr, ip_daddr, sport, dport, score, open, protocol, packet_length, UNIX_TIMESTAMP(time_established), UNIX_TIMESTAMP(time_last_packet), num_packets, connection_id FROM connections WHERE currently_scored='NO'");
		checkError(query_result = dbh->query(query, strlen(query), &rows, &fields), dbh);
		
		// We are selecting all connections that are not currently scored
		// (as indicated in the database).  We pass these to our classifier
		// (telling it whether the connection is open or closed, so it can
		// decide whether to use the connection for future classification),
		// and then create a thread to create a rule to block the connection
		// if the score is less than our assigned THRESHOLD.
		for(int rowcnt = 0; rowcnt < rows; rowcnt++)
		{
			row = query_result + rowcnt * fields;
			
			Connection con( (unsigned int) strtoul(row[0], (char**)NULL, 10), // Source IP
				(unsigned int) strtoul(row[1], (char**)NULL, 10), // Destination IP
				(int) strtol(row[2], (char**)NULL, 10), // Source Port
				(int) strtol(row[3], (char**)NULL, 10), // Destination Port
				strtod(row[4], (char**)NULL), // Score
				strcmp(row[5], "YES") == 0, // Open?
				(int) strtol(row[6], (char**)NULL, 10), // Protocol
				strtod(row[9], (char**)NULL) - strtod(row[8], (char**)NULL), // Connection Length
				strtod(row[7], (char**)NULL), // Packet Length
				strtod(row[10], (char**)NULL) // Number of Packets
				);
				
			// score da connection!
			if(con.open)
			{
				titanClassifier.classify(&con,false,true);
			}
			else
			{
				titanClassifier.classify(&con,true,true);
			}

			//set the currently scored flag to true for the recently scored connection
			sprintf(query, "UPDATE connections SET currently_scored='YES', score='%f' WHERE connection_id='%s'", con.score, row[11]);
			checkError(dbh->query(query, strlen(query)), dbh);
			
			if(con.score < THRESHOLD)
			{
				// Make a separate thread to make a rule.
				pthread_t ignoreMe;
				if( pthread_create( &ignoreMe, NULL, rule_thread, (void *) &con) != 0)
				{
					perror( "Can't create rule-making thread!");
				}
			}
		}
		dbh->release(query_result);
	}	
}

// This is the code that gets run when a thread to make a rule is spawned
void * rule_thread( void * con )
{
	Connection *currentCon2;
	currentCon2 = (Connection*) con;
	Connection currentCon = *currentCon2;
	cout<< "Making a rule to block the current IP: "<< currentCon.source_ip_address << "." << endl;
	Db_Query *rule_dbh = new Db_Query;
	int rule_rows;
	int rule_fields;
	char system_string[200];
	char rule_query[300];
	sprintf(rule_query, "SELECT * FROM rules WHERE ip='%u' OR (ip='%u' AND mask='24')",
		currentCon.source_ip_address, currentCon.source_ip_address & 4294967040);
	sql_data_t rule_results;
	
	
	checkError(rule_results = rule_dbh->query(rule_query, strlen(rule_query), &rule_rows, &rule_fields), rule_dbh);
	rule_dbh->release(rule_results);
	// If there's a rule already blocking this connection, update the expiration time on that rule.
	if(rule_rows >= 1)
	{
		sprintf(rule_query, "UPDATE rules SET expiration=DATE_ADD(NOW(), INTERVAL 15 MINUTE) WHERE ip='%u' OR (ip='%u' AND mask='24')",
			currentCon.source_ip_address, currentCon.source_ip_address & 4294967040);
		//run query
		checkError(rule_dbh->query(rule_query, strlen(rule_query)), rule_dbh);
	}
	// Else, there is no rule is already blocking this connection; make a new rule to do so.
	// This rule will either block the IP if we have less rules than SUBNET_THRESHOLD blocking
	// IPs in the current connection's subnet, otherwise we'll make a rule to block the subnet.
	else
	{
		sprintf(rule_query, "SELECT * FROM rules WHERE ip>='%u' AND ip<='%u'",
			currentCon.source_ip_address & 4294967040, (currentCon.source_ip_address & 4294967040) + 256);
		checkError(rule_results = rule_dbh->query(rule_query, strlen(rule_query), &rule_rows, &rule_fields), rule_dbh);
		rule_dbh->release(rule_results);
		
		if(rule_rows >= SUBNET_THRESHOLD)
		{
			//block the subnet
			sprintf(system_string, "iptables -A FWBLOCK --source %u/24 -j DROP", currentCon.source_ip_address & 4294967040);
			system(system_string);
			sprintf(rule_query, "INSERT INTO rules (ip, mask, expiration) values ('%u', 24, DATE_ADD(NOW(), INTERVAL 1 MINUTE)) ON DUPLICATE KEY UPDATE expiration=VALUES(expiration)",
				currentCon.source_ip_address & 4294967040);
			checkError(rule_dbh->query(rule_query, strlen(rule_query)), rule_dbh);
		}
		else
		{
			//just block the ip
			sprintf(system_string, "iptables -A FWBLOCK --source %u/32 -j DROP", currentCon.source_ip_address);
			system(system_string);
			sprintf(rule_query, "INSERT INTO rules (ip, mask, expiration) values ('%u', 32, DATE_ADD(NOW(), INTERVAL 1 MINUTE)) ON DUPLICATE KEY UPDATE expiration=VALUES(expiration)",
				currentCon.source_ip_address);
			checkError(rule_dbh->query(rule_query, strlen(rule_query)), rule_dbh);
		}
	}
	return (void *) NULL;
}
