Browse Source

small changes

aidmar.wainakh 6 years ago
parent
commit
11147310af
3 changed files with 59 additions and 35 deletions
  1. 51 28
      code/Attack/SQLiAttack.py
  2. 1 2
      code_boost/src/cxx/statistics.cpp
  3. 7 5
      code_boost/src/cxx/statistics_db.cpp

+ 51 - 28
code/Attack/SQLiAttack.py

@@ -131,33 +131,49 @@ class SQLiAttack(BaseAttack.BaseAttack):
             sys.exit(0)
 
         path_attack_pcap = None
-        replyDelay = self.get_reply_delay(ip_destination)
+        minDelay, maxDelay = self.get_reply_delay(ip_destination)
 
         # Inject SQLi Attack
         # Read SQLi Attack pcap file
         orig_ip_dst = None
-        exploit_raw_packets = RawPcapReader("ATutorSQLi.pcap")
+        exploit_raw_packets = RawPcapReader("resources/ATutorSQLi.pcap")
 
         port_source = randint(self.minDefaultPort,self.maxDefaultPort) # experiments show this range of ports
 
+        # Random TCP sequence numbers
+        global attacker_seq
+        attacker_seq = randint(1000, 50000)
+        global victim_seq
+        victim_seq = randint(1000, 50000)
+
         for pkt_num, pkt in enumerate(exploit_raw_packets):
             eth_frame = Ether(pkt[0])
             ip_pkt = eth_frame.payload
             tcp_pkt = ip_pkt.payload
-            str_http_pkt = str(tcp_pkt.payload)
+            str_tcp_seg = str(tcp_pkt.payload)
+
+            # Clean payloads
+            eth_frame.payload = b''
+            ip_pkt.payload = b''
+            tcp_pkt.payload = b''
 
             if pkt_num == 0:
                 prev_orig_port_source = tcp_pkt.getfieldval("sport")
                 if tcp_pkt.getfieldval("dport") == self.http_port:
                     orig_ip_dst = ip_pkt.getfieldval("dst") # victim IP
 
-            # Request
+            # Request: Attacker --> vicitm
             if ip_pkt.getfieldval("dst") == orig_ip_dst: # victim IP
 
                 # There are 363 TCP connections with different source ports, for each of them we generate random port
                 if tcp_pkt.getfieldval("sport") != prev_orig_port_source:
                     port_source = randint(self.minDefaultPort, self.maxDefaultPort)
                     prev_orig_port_source = tcp_pkt.getfieldval("sport")
+                    # New connection, new random TCP sequence numbers
+                    attacker_seq = randint(1000, 50000)
+                    victim_seq = randint(1000, 50000)
+                    # First packet in a connection has ACK = 0
+                    tcp_pkt.setfieldval("ack", 0)
 
                 # Ether
                 eth_frame.setfieldval("src", mac_source)
@@ -168,28 +184,32 @@ class SQLiAttack(BaseAttack.BaseAttack):
                 # TCP
                 tcp_pkt.setfieldval("sport",port_source)
 
-                eth_frame.payload = b''
-                ip_pkt.payload = b''
-                tcp_pkt.payload = b''
-
-                if len(str_http_pkt) > 0:
+                if len(str_tcp_seg) > 0:
                     # convert payload bytes to str => str = "b'..\\r\\n..'"
-                    str_http_pkt = str_http_pkt[2:-1]
-                    str_http_pkt = str_http_pkt.replace('/ATutor', target_uri)
-                    str_http_pkt = str_http_pkt.replace(orig_ip_dst, target_host)
-                    str_http_pkt = str_http_pkt.replace("\\n", "\n")
-                    str_http_pkt = str_http_pkt.replace("\\r", "\r")
+                    str_tcp_seg = str_tcp_seg[2:-1]
+                    str_tcp_seg = str_tcp_seg.replace('/ATutor', target_uri)
+                    str_tcp_seg = str_tcp_seg.replace(orig_ip_dst, target_host)
+                    str_tcp_seg = str_tcp_seg.replace("\\n", "\n")
+                    str_tcp_seg = str_tcp_seg.replace("\\r", "\r")
                     str_tcp_seg = str_tcp_seg.replace("\\t", "\t")
                     str_tcp_seg = str_tcp_seg.replace("\\\'", "\'")
 
-                new_pkt = (eth_frame / ip_pkt/ tcp_pkt / str_http_pkt)
+                # TCP Seq, Ack
+                if tcp_pkt.getfieldval("ack") != 0:
+                    tcp_pkt.setfieldval("ack", victim_seq)
+                tcp_pkt.setfieldval("seq", attacker_seq)
+                if not (tcp_pkt.getfieldval("flags") == 16 and len(str_tcp_seg) == 0):  # flags=A:
+                    attacker_seq += max(len(str_tcp_seg), 1)
+
+                new_pkt = (eth_frame / ip_pkt/ tcp_pkt / str_tcp_seg)
                 new_pkt.time = timestamp_next_pkt
 
                 maxdelay = randomdelay.random()
                 pps = self.minDefaultPPS if getIntervalPPS(complement_interval_pps, timestamp_next_pkt) is None else max(
                     getIntervalPPS(complement_interval_pps, timestamp_next_pkt), self.minDefaultPPS)
                 timestamp_next_pkt = update_timestamp(timestamp_next_pkt, pps, maxdelay)
-            # Reply
+
+            # Reply: Victim --> attacker
             else:
                 # Ether
                 eth_frame.setfieldval("src", mac_destination)
@@ -200,22 +220,25 @@ class SQLiAttack(BaseAttack.BaseAttack):
                 # TCP
                 tcp_pkt.setfieldval("dport", port_source)
 
-                eth_frame.payload = b''
-                ip_pkt.payload = b''
-                tcp_pkt.payload = b''
-
-                if len(str_http_pkt) > 0:
+                if len(str_tcp_seg) > 0:
                     # convert payload bytes to str => str = "b'..\\r\\n..'"
-                    str_http_pkt = str_http_pkt[2:-1]
-                    str_http_pkt = str_http_pkt.replace('/ATutor', target_uri)
-                    str_http_pkt = str_http_pkt.replace(orig_ip_dst, target_host)
-                    str_http_pkt = str_http_pkt.replace("\\n", "\n")
-                    str_http_pkt = str_http_pkt.replace("\\r", "\r")
+                    str_tcp_seg = str_tcp_seg[2:-1]
+                    str_tcp_seg = str_tcp_seg.replace('/ATutor', target_uri)
+                    str_tcp_seg = str_tcp_seg.replace(orig_ip_dst, target_host)
+                    str_tcp_seg = str_tcp_seg.replace("\\n", "\n")
+                    str_tcp_seg = str_tcp_seg.replace("\\r", "\r")
                     str_tcp_seg = str_tcp_seg.replace("\\t", "\t")
                     str_tcp_seg = str_tcp_seg.replace("\\\'", "\'")
 
-                new_pkt = (eth_frame / ip_pkt / tcp_pkt / str_http_pkt)
-                timestamp_next_pkt = timestamp_next_pkt + uniform(replyDelay, 2 * replyDelay)
+                # TCP Seq, ACK
+                tcp_pkt.setfieldval("ack", attacker_seq)
+                tcp_pkt.setfieldval("seq", victim_seq)
+                strLen = len(str_tcp_seg)
+                if not (tcp_pkt.getfieldval("flags") == 16 and strLen == 0):  # flags=A:
+                    victim_seq += max(strLen, 1)
+
+                new_pkt = (eth_frame / ip_pkt / tcp_pkt / str_tcp_seg)
+                timestamp_next_pkt = timestamp_next_pkt + uniform(minDelay, 2 * maxDelay)
                 new_pkt.time = timestamp_next_pkt
 
             packets.append(new_pkt)

+ 1 - 2
code_boost/src/cxx/statistics.cpp

@@ -25,8 +25,7 @@ std::vector<float> statistics::calculateLastIntervalIPsEntropy(std::chrono::micr
         int pktsSent = 0, pktsReceived = 0;
         
         for (auto i = ip_statistics.begin(); i != ip_statistics.end(); i++) {
-            // TO-DO: should add this condition to avoid Segmentation Fault    if(i->second.pktsSentTimestamp.size()>0) realy?
-            int indexStartSent = getClosestIndex(i->second.pktsSentTimestamp, intervalStartTimestamp);                         
+            int indexStartSent = getClosestIndex(i->second.pktsSentTimestamp, intervalStartTimestamp);
             int IPsSrcPktsCount = i->second.pktsSentTimestamp.size() - indexStartSent;
             IPsSrcPktsCounts.push_back(IPsSrcPktsCount);
             pktsSent += IPsSrcPktsCount;                        

+ 7 - 5
code_boost/src/cxx/statistics_db.cpp

@@ -386,23 +386,25 @@ void statistics_db::writeStatisticsInterval(std::unordered_map<std::string, entr
         const char *createTable = "CREATE TABLE interval_statistics ("
                 "lastPktTimestamp TEXT,"
                 "pktsCount INTEGER,"
+                "kBytes REAL,"
                 "ipSrcEntropy REAL,"      
                 "ipDstEntropy REAL,"  
                 "ipSrcCumEntropy REAL,"      
                 "ipDstCumEntropy REAL," 
                 "PRIMARY KEY(lastPktTimestamp));";
         db->exec(createTable);
-        SQLite::Statement query(*db, "INSERT INTO interval_statistics VALUES (?, ?, ?, ?, ?, ?)");
+        SQLite::Statement query(*db, "INSERT INTO interval_statistics VALUES (?, ?, ?, ?, ?, ?. ?)");
         for (auto it = intervalStatistics.begin(); it != intervalStatistics.end(); ++it) {
             std::string t = it->first;
             entry_intervalStat e = it->second;        
             
             query.bind(1, t);
             query.bind(2, (int)e.pkts_count);
-            query.bind(3, e.ip_src_entropy);
-            query.bind(4, e.ip_dst_entropy);
-            query.bind(5, e.ip_src_cum_entropy);
-            query.bind(6, e.ip_dst_cum_entropy);
+            query.bind(3, e.kbytes);
+            query.bind(4, e.ip_src_entropy);
+            query.bind(5, e.ip_dst_entropy);
+            query.bind(6, e.ip_src_cum_entropy);
+            query.bind(7, e.ip_dst_cum_entropy);
 
             query.exec();
             query.reset();