Sfoglia il codice sorgente

add first ibrary unittest

add SMBLib unit tests
add Utility unit tests
fix get smb version random choice for linux
fix utility get random nops. copy nop/pseudonop lists instead of using original lists
fix utility get bytes from file. close files before exiting if error
Stefano Acquaviti 6 anni fa
parent
commit
3d3b9f7916

+ 2 - 2
code/ID2TLib/SMBLib.py

@@ -1,6 +1,6 @@
 from os import urandom
 from binascii import b2a_hex
-from random import random
+from random import choice
 
 from ID2TLib.Utility import check_platform, get_filetime_format, get_rnd_boot_time
 
@@ -60,7 +60,7 @@ def get_smb_version(platform: str):
     """
     check_platform(platform)
     if platform is "linux":
-        return random.choice(list(smb_versions_per_samba.values()))
+        return choice(list(smb_versions_per_samba.values()))
     elif platform is "macos":
         return "2.1"
     else:

+ 6 - 3
code/ID2TLib/Utility.py

@@ -188,9 +188,9 @@ def get_rnd_x86_nop(count=1, side_effect_free=False, char_filter=set()):
     :return: Random x86 NOP bytestring
     """
     result = b''
-    nops = x86_nops
+    nops = x86_nops.copy()
     if not side_effect_free:
-        nops |= x86_pseudo_nops
+        nops |= x86_pseudo_nops.copy()
 
     if not isinstance(char_filter, set):
         char_filter = set(char_filter)
@@ -249,16 +249,19 @@ def get_bytes_from_file(filepath):
                 result_bytes = bytes.fromhex(content)
             except ValueError:
                 print("\nERROR: Content of file is not all hexadecimal.")
+                file.close()
                 exit(1)
         elif header == "str":
-            result_bytes = content.encode()
+            result_bytes = content.strip().encode()
         else:
             print("\nERROR: Invalid header found: " + header + ". Try 'hex' or 'str' followed by endline instead.")
+            file.close()
             exit(1)
 
         for forbidden_char in forbidden_chars:
             if forbidden_char in result_bytes:
                 print("\nERROR: Forbidden character found in payload: ", forbidden_char)
+                file.close()
                 exit(1)
 
         file.close()

+ 2 - 0
code/Test/resources/HexTestFile.txt

@@ -0,0 +1,2 @@
+hex
+"abcd ef \xff10\ff 'xaa' x \ ab"

+ 2 - 0
code/Test/resources/InvalidHeader.txt

@@ -0,0 +1,2 @@
+InvalidHeader
+The header above is invalid because it is not 'hex' or 'str'

+ 2 - 0
code/Test/resources/InvalidHexFile.txt

@@ -0,0 +1,2 @@
+hex
+This is not a valid hexdump

+ 2 - 0
code/Test/resources/StringTestFile.txt

@@ -0,0 +1,2 @@
+str
+This is a string-test

+ 57 - 0
code/Test/test_SMBLib.py

@@ -0,0 +1,57 @@
+from unittest import TestCase
+from ID2TLib.SMBLib import *
+from ID2TLib.Utility import platforms, get_filetime_format
+
+
+class TestSMBLib(TestCase):
+
+    def test_get_smb_version_all(self):
+
+        for platform in platforms:
+            with self.subTest(platform):
+                result = get_smb_version(platform)
+                self.assertTrue((result in smb_versions_per_win.values() or result in smb_versions_per_samba.values()))
+
+    def test_get_smb_version_invalid(self):
+
+        with self.assertRaises(SystemExit):
+            get_smb_version("abc")
+
+    def test_get_smb_version_mac(self):
+        self.assertEqual(get_smb_version("macos"), "2.1")
+
+    def test_get_smb_version_win(self):
+
+        win_platforms = {'win7', 'win10', 'winxp', 'win8.1', 'win8', 'winvista', 'winnt', "win2000"}
+
+        for platform in win_platforms:
+            with self.subTest(platform):
+                self.assertIn(get_smb_version(platform), smb_versions_per_win.values())
+
+    def test_get_smb_version_linux(self):
+        self.assertIn(get_smb_version("linux"), smb_versions_per_samba.values())
+
+    def test_get_smb_platform_data_invalid(self):
+
+        with self.assertRaises(SystemExit):
+            get_smb_platform_data("abc", 0)
+
+    def test_get_smb_platform_data_linux(self):
+        self.assertEqual((get_smb_platform_data("linux", 0)), ("ubuntu", security_blob_ubuntu, 0x5, 0x800000, 0))
+
+    def test_get_smb_platform_data_mac(self):
+        guid, blob, cap, d_size, time = get_smb_platform_data("macos", 0)
+        self.assertEqual((blob, cap, d_size, time), (security_blob_macos, 0x6, 0x400000, 0))
+        self.assertTrue(isinstance(guid, str) and len(guid) > 0)
+
+    def test_get_smb_platform_data_win(self):
+        guid, blob, cap, d_size, time = get_smb_platform_data("win7", 100)
+        self.assertEqual((blob, cap, d_size), (security_blob_windows, 0x7, 0x100000))
+        self.assertTrue(isinstance(guid, str) and len(guid) > 0)
+        self.assertTrue(time <= get_filetime_format(100))
+
+    def test_invalid_smb_version(self):
+        with self.assertRaises(SystemExit):
+            invalid_smb_version("abc")
+
+

+ 173 - 0
code/Test/test_Utility.py

@@ -0,0 +1,173 @@
+from unittest import TestCase
+from ID2TLib.Utility import *
+
+
+class TestUtility(TestCase):
+
+    def test_update_timestamp_no_delay(self):
+        self.assertTrue(100+10/5 >= update_timestamp(100, 5) >= 100+1/5)
+
+    def test_update_timestamp_with_delay(self):
+        self.assertTrue(100+1/5+10*100 >= update_timestamp(100, 5, 10) >= 100+1/5+10)
+
+    def test_update_timestamp_comparison(self):
+        self.assertTrue(update_timestamp(100, 5) <= update_timestamp(100, 5, 10))
+
+    def test_get_interval_pps_below_max(self):
+        cipps = [(5, 1), (10, 2), (15, 3)]
+        self.assertEqual(get_interval_pps(cipps, 3), 1)
+        self.assertEqual(get_interval_pps(cipps, 7), 2)
+        self.assertEqual(get_interval_pps(cipps, 12), 3)
+
+    def test_get_interval_pps_above_max(self):
+        cipps = [(5, 1), (10, 2), (15, 3)]
+        self.assertEqual(get_interval_pps(cipps, 30), 3)
+
+    # Errors if empty list and result bad if only one list
+    def test_get_nth_random_element_equal_no(self):
+        letters = ["A", "B", "C"]
+        numbers = [1, 2, 3]
+        results = [("A", 1), ("B", 2), ("C", 3)]
+        self.assertIn(get_nth_random_element(letters, numbers), results)
+
+    def test_get_nth_random_element_unequal_no(self):
+        letters = ["A", "B", "C"]
+        numbers = [1, 2]
+        results = [("A", 1), ("B", 2)]
+        self.assertIn(get_nth_random_element(letters, numbers), results)
+
+    #def test_get_nth_random_element_single_list(self):
+        #letters = ["A", "B", "C"]
+        #self.assertIn(get_nth_random_element(letters), letters)
+
+    def test_index_increment_not_max(self):
+        self.assertEqual(index_increment(5, 10), 6)
+
+    def test_index_increment_max(self):
+        self.assertEqual(index_increment(10, 10), 0)
+
+    # Correct?
+    def test_index_increment_max2(self):
+        self.assertEqual(index_increment(9, 10), 0)
+
+    def test_get_rnd_os(self):
+        self.assertIn(get_rnd_os(), platforms)
+
+    def test_check_platform_valid(self):
+        check_platform("linux")
+
+    def test_check_platform_invalid(self):
+        with self.assertRaises(SystemExit):
+            check_platform("abc")
+
+    def test_get_ip_range_forwards(self):
+        start = "192.168.178.254"
+        end = "192.168.179.1"
+        result = ["192.168.178.254", "192.168.178.255", "192.168.179.0", "192.168.179.1"]
+        self.assertEqual(get_ip_range(start, end), result)
+
+    def test_get_ip_range_backwards(self):
+        end = "192.168.178.254"
+        start = "192.168.179.1"
+        result = ["192.168.179.1", "192.168.179.0", "192.168.178.255", "192.168.178.254"]
+        self.assertEqual(get_ip_range(start, end), result)
+
+    def test_generate_source_port_from_platform_invalid(self):
+        with self.assertRaises(SystemExit):
+            generate_source_port_from_platform("abc")
+
+    def test_generate_source_port_from_platform_oldwin_firstport(self):
+        self.assertTrue(1024 <= generate_source_port_from_platform("winxp") <= 5000)
+
+    def test_generate_source_port_from_platform_oldwin_nextport(self):
+        self.assertEqual(generate_source_port_from_platform("winxp", 2000), 2001)
+
+    def test_generate_source_port_from_platform_oldwin_maxport(self):
+        self.assertTrue(1024 <= generate_source_port_from_platform("winxp", 5000) <= 5000)
+
+    def test_generate_source_port_from_platform_linux(self):
+        self.assertTrue(32768 <= generate_source_port_from_platform("linux") <= 61000)
+
+    def test_generate_source_port_from_platform_newwinmac_firstport(self):
+        self.assertTrue(49152 <= generate_source_port_from_platform("win7") <= 65535)
+
+    def test_generate_source_port_from_platform_newwinmac_nextport(self):
+        self.assertEqual(generate_source_port_from_platform("win7", 50000), 50001)
+
+    def test_generate_source_port_from_platform_newwinmac_maxport(self):
+        self.assertTrue(49152 <= generate_source_port_from_platform("win7", 65535) <= 65535)
+
+    # Test get_filetime_format????
+
+    def test_get_rnd_boot_time_invalid(self):
+        with self.assertRaises(SystemExit):
+            get_rnd_boot_time(10, "abc")
+
+    def test_get_rnd_boot_time_linux(self):
+        self.assertTrue(get_rnd_boot_time(100, "linux") < 100)
+
+    def test_get_rnd_boot_time_macos(self):
+        self.assertTrue(get_rnd_boot_time(100, "macos") < 100)
+
+    def test_get_rnd_boot_time_win(self):
+        self.assertTrue(get_rnd_boot_time(100, "win7") < 100)
+
+    def test_get_rnd_x86_nop_len(self):
+        result = get_rnd_x86_nop(1000)
+        self.assertEqual(len(result), 1000)
+
+    def test_get_rnd_x86_nop_with_sideeffects(self):
+        result = get_rnd_x86_nop(1000, False)
+        correct = True
+        for byte in result:
+            if byte.to_bytes(1, "little") not in x86_nops and byte.to_bytes(1, "little") not in x86_pseudo_nops:
+                correct = False
+        self.assertTrue(correct)
+
+    def test_get_rnd_x86_nop_without_sideeffects(self):
+        result = get_rnd_x86_nop(1000, True)
+        correct = True
+        for byte in result:
+            if byte.to_bytes(1, "little") in x86_pseudo_nops:
+                correct = False
+        self.assertTrue(correct)
+
+    def test_get_rnd_x86_nop_filter(self):
+        result = get_rnd_x86_nop(1000, False, x86_nops.copy())
+        correct = True
+        for byte in result:
+            if byte.to_bytes(1, "little") in x86_nops:
+                correct = False
+        self.assertTrue(correct)
+
+    def test_get_rnd_bytes_number(self):
+        result = get_rnd_bytes(1000)
+        self.assertEqual(len(result), 1000)
+
+    def test_get_rnd_bytes_filter(self):
+        result = get_rnd_bytes(1000, x86_pseudo_nops.copy())
+        correct = True
+        for byte in result:
+            if byte.to_bytes(1, "little") in x86_pseudo_nops:
+                correct = False
+        self.assertTrue(correct)
+
+    def test_get_bytes_from_file_invalid_path(self):
+        with self.assertRaises(SystemExit):
+            get_bytes_from_file("resources/NonExistingFile.txt")
+
+    def test_get_bytes_from_file_invalid_header(self):
+        with self.assertRaises(SystemExit):
+            get_bytes_from_file("resources/InvalidHeader.txt")
+
+    def test_get_bytes_from_file_invalid_hexfile(self):
+        with self.assertRaises(SystemExit):
+            get_bytes_from_file("resources/InvalidHexFile.txt")
+
+    def test_get_bytes_from_file_str(self):
+        result = get_bytes_from_file("resources/StringTestFile.txt")
+        self.assertEqual(result, b'This is a string-test')
+
+    def test_get_bytes_from_file_hex(self):
+        result = get_bytes_from_file("resources/HexTestFile.txt")
+        self.assertEqual(result, b'\xab\xcd\xef\xff\x10\xff\xaa\xab')