Browse Source

add first ibrary unittest

add SMBLib unit tests
add Utility unit tests
fix get smb version random choice for linux
fix utility get random nops. copy nop/pseudonop lists instead of using original lists
fix utility get bytes from file. close files before exiting if error
Stefano Acquaviti 6 years ago
parent
commit
3d3b9f7916

+ 2 - 2
code/ID2TLib/SMBLib.py

@@ -1,6 +1,6 @@
 from os import urandom
 from binascii import b2a_hex
-from random import random
+from random import choice
 
 from ID2TLib.Utility import check_platform, get_filetime_format, get_rnd_boot_time
 
@@ -60,7 +60,7 @@ def get_smb_version(platform: str):
     """
     check_platform(platform)
     if platform is "linux":
-        return random.choice(list(smb_versions_per_samba.values()))
+        return choice(list(smb_versions_per_samba.values()))
     elif platform is "macos":
         return "2.1"
     else:

+ 6 - 3
code/ID2TLib/Utility.py

@@ -188,9 +188,9 @@ def get_rnd_x86_nop(count=1, side_effect_free=False, char_filter=set()):
     :return: Random x86 NOP bytestring
     """
     result = b''
-    nops = x86_nops
+    nops = x86_nops.copy()
     if not side_effect_free:
-        nops |= x86_pseudo_nops
+        nops |= x86_pseudo_nops.copy()
 
     if not isinstance(char_filter, set):
         char_filter = set(char_filter)
@@ -249,16 +249,19 @@ def get_bytes_from_file(filepath):
                 result_bytes = bytes.fromhex(content)
             except ValueError:
                 print("\nERROR: Content of file is not all hexadecimal.")
+                file.close()
                 exit(1)
         elif header == "str":
-            result_bytes = content.encode()
+            result_bytes = content.strip().encode()
         else:
             print("\nERROR: Invalid header found: " + header + ". Try 'hex' or 'str' followed by endline instead.")
+            file.close()
             exit(1)
 
         for forbidden_char in forbidden_chars:
             if forbidden_char in result_bytes:
                 print("\nERROR: Forbidden character found in payload: ", forbidden_char)
+                file.close()
                 exit(1)
 
         file.close()

+ 2 - 0
code/Test/resources/HexTestFile.txt

@@ -0,0 +1,2 @@
+hex
+"abcd ef \xff10\ff 'xaa' x \ ab"

+ 2 - 0
code/Test/resources/InvalidHeader.txt

@@ -0,0 +1,2 @@
+InvalidHeader
+The header above is invalid because it is not 'hex' or 'str'

+ 2 - 0
code/Test/resources/InvalidHexFile.txt

@@ -0,0 +1,2 @@
+hex
+This is not a valid hexdump

+ 2 - 0
code/Test/resources/StringTestFile.txt

@@ -0,0 +1,2 @@
+str
+This is a string-test

+ 57 - 0
code/Test/test_SMBLib.py

@@ -0,0 +1,57 @@
+from unittest import TestCase
+from ID2TLib.SMBLib import *
+from ID2TLib.Utility import platforms, get_filetime_format
+
+
+class TestSMBLib(TestCase):
+
+    def test_get_smb_version_all(self):
+
+        for platform in platforms:
+            with self.subTest(platform):
+                result = get_smb_version(platform)
+                self.assertTrue((result in smb_versions_per_win.values() or result in smb_versions_per_samba.values()))
+
+    def test_get_smb_version_invalid(self):
+
+        with self.assertRaises(SystemExit):
+            get_smb_version("abc")
+
+    def test_get_smb_version_mac(self):
+        self.assertEqual(get_smb_version("macos"), "2.1")
+
+    def test_get_smb_version_win(self):
+
+        win_platforms = {'win7', 'win10', 'winxp', 'win8.1', 'win8', 'winvista', 'winnt', "win2000"}
+
+        for platform in win_platforms:
+            with self.subTest(platform):
+                self.assertIn(get_smb_version(platform), smb_versions_per_win.values())
+
+    def test_get_smb_version_linux(self):
+        self.assertIn(get_smb_version("linux"), smb_versions_per_samba.values())
+
+    def test_get_smb_platform_data_invalid(self):
+
+        with self.assertRaises(SystemExit):
+            get_smb_platform_data("abc", 0)
+
+    def test_get_smb_platform_data_linux(self):
+        self.assertEqual((get_smb_platform_data("linux", 0)), ("ubuntu", security_blob_ubuntu, 0x5, 0x800000, 0))
+
+    def test_get_smb_platform_data_mac(self):
+        guid, blob, cap, d_size, time = get_smb_platform_data("macos", 0)
+        self.assertEqual((blob, cap, d_size, time), (security_blob_macos, 0x6, 0x400000, 0))
+        self.assertTrue(isinstance(guid, str) and len(guid) > 0)
+
+    def test_get_smb_platform_data_win(self):
+        guid, blob, cap, d_size, time = get_smb_platform_data("win7", 100)
+        self.assertEqual((blob, cap, d_size), (security_blob_windows, 0x7, 0x100000))
+        self.assertTrue(isinstance(guid, str) and len(guid) > 0)
+        self.assertTrue(time <= get_filetime_format(100))
+
+    def test_invalid_smb_version(self):
+        with self.assertRaises(SystemExit):
+            invalid_smb_version("abc")
+
+

+ 173 - 0
code/Test/test_Utility.py

@@ -0,0 +1,173 @@
+from unittest import TestCase
+from ID2TLib.Utility import *
+
+
+class TestUtility(TestCase):
+
+    def test_update_timestamp_no_delay(self):
+        self.assertTrue(100+10/5 >= update_timestamp(100, 5) >= 100+1/5)
+
+    def test_update_timestamp_with_delay(self):
+        self.assertTrue(100+1/5+10*100 >= update_timestamp(100, 5, 10) >= 100+1/5+10)
+
+    def test_update_timestamp_comparison(self):
+        self.assertTrue(update_timestamp(100, 5) <= update_timestamp(100, 5, 10))
+
+    def test_get_interval_pps_below_max(self):
+        cipps = [(5, 1), (10, 2), (15, 3)]
+        self.assertEqual(get_interval_pps(cipps, 3), 1)
+        self.assertEqual(get_interval_pps(cipps, 7), 2)
+        self.assertEqual(get_interval_pps(cipps, 12), 3)
+
+    def test_get_interval_pps_above_max(self):
+        cipps = [(5, 1), (10, 2), (15, 3)]
+        self.assertEqual(get_interval_pps(cipps, 30), 3)
+
+    # Errors if empty list and result bad if only one list
+    def test_get_nth_random_element_equal_no(self):
+        letters = ["A", "B", "C"]
+        numbers = [1, 2, 3]
+        results = [("A", 1), ("B", 2), ("C", 3)]
+        self.assertIn(get_nth_random_element(letters, numbers), results)
+
+    def test_get_nth_random_element_unequal_no(self):
+        letters = ["A", "B", "C"]
+        numbers = [1, 2]
+        results = [("A", 1), ("B", 2)]
+        self.assertIn(get_nth_random_element(letters, numbers), results)
+
+    #def test_get_nth_random_element_single_list(self):
+        #letters = ["A", "B", "C"]
+        #self.assertIn(get_nth_random_element(letters), letters)
+
+    def test_index_increment_not_max(self):
+        self.assertEqual(index_increment(5, 10), 6)
+
+    def test_index_increment_max(self):
+        self.assertEqual(index_increment(10, 10), 0)
+
+    # Correct?
+    def test_index_increment_max2(self):
+        self.assertEqual(index_increment(9, 10), 0)
+
+    def test_get_rnd_os(self):
+        self.assertIn(get_rnd_os(), platforms)
+
+    def test_check_platform_valid(self):
+        check_platform("linux")
+
+    def test_check_platform_invalid(self):
+        with self.assertRaises(SystemExit):
+            check_platform("abc")
+
+    def test_get_ip_range_forwards(self):
+        start = "192.168.178.254"
+        end = "192.168.179.1"
+        result = ["192.168.178.254", "192.168.178.255", "192.168.179.0", "192.168.179.1"]
+        self.assertEqual(get_ip_range(start, end), result)
+
+    def test_get_ip_range_backwards(self):
+        end = "192.168.178.254"
+        start = "192.168.179.1"
+        result = ["192.168.179.1", "192.168.179.0", "192.168.178.255", "192.168.178.254"]
+        self.assertEqual(get_ip_range(start, end), result)
+
+    def test_generate_source_port_from_platform_invalid(self):
+        with self.assertRaises(SystemExit):
+            generate_source_port_from_platform("abc")
+
+    def test_generate_source_port_from_platform_oldwin_firstport(self):
+        self.assertTrue(1024 <= generate_source_port_from_platform("winxp") <= 5000)
+
+    def test_generate_source_port_from_platform_oldwin_nextport(self):
+        self.assertEqual(generate_source_port_from_platform("winxp", 2000), 2001)
+
+    def test_generate_source_port_from_platform_oldwin_maxport(self):
+        self.assertTrue(1024 <= generate_source_port_from_platform("winxp", 5000) <= 5000)
+
+    def test_generate_source_port_from_platform_linux(self):
+        self.assertTrue(32768 <= generate_source_port_from_platform("linux") <= 61000)
+
+    def test_generate_source_port_from_platform_newwinmac_firstport(self):
+        self.assertTrue(49152 <= generate_source_port_from_platform("win7") <= 65535)
+
+    def test_generate_source_port_from_platform_newwinmac_nextport(self):
+        self.assertEqual(generate_source_port_from_platform("win7", 50000), 50001)
+
+    def test_generate_source_port_from_platform_newwinmac_maxport(self):
+        self.assertTrue(49152 <= generate_source_port_from_platform("win7", 65535) <= 65535)
+
+    # Test get_filetime_format????
+
+    def test_get_rnd_boot_time_invalid(self):
+        with self.assertRaises(SystemExit):
+            get_rnd_boot_time(10, "abc")
+
+    def test_get_rnd_boot_time_linux(self):
+        self.assertTrue(get_rnd_boot_time(100, "linux") < 100)
+
+    def test_get_rnd_boot_time_macos(self):
+        self.assertTrue(get_rnd_boot_time(100, "macos") < 100)
+
+    def test_get_rnd_boot_time_win(self):
+        self.assertTrue(get_rnd_boot_time(100, "win7") < 100)
+
+    def test_get_rnd_x86_nop_len(self):
+        result = get_rnd_x86_nop(1000)
+        self.assertEqual(len(result), 1000)
+
+    def test_get_rnd_x86_nop_with_sideeffects(self):
+        result = get_rnd_x86_nop(1000, False)
+        correct = True
+        for byte in result:
+            if byte.to_bytes(1, "little") not in x86_nops and byte.to_bytes(1, "little") not in x86_pseudo_nops:
+                correct = False
+        self.assertTrue(correct)
+
+    def test_get_rnd_x86_nop_without_sideeffects(self):
+        result = get_rnd_x86_nop(1000, True)
+        correct = True
+        for byte in result:
+            if byte.to_bytes(1, "little") in x86_pseudo_nops:
+                correct = False
+        self.assertTrue(correct)
+
+    def test_get_rnd_x86_nop_filter(self):
+        result = get_rnd_x86_nop(1000, False, x86_nops.copy())
+        correct = True
+        for byte in result:
+            if byte.to_bytes(1, "little") in x86_nops:
+                correct = False
+        self.assertTrue(correct)
+
+    def test_get_rnd_bytes_number(self):
+        result = get_rnd_bytes(1000)
+        self.assertEqual(len(result), 1000)
+
+    def test_get_rnd_bytes_filter(self):
+        result = get_rnd_bytes(1000, x86_pseudo_nops.copy())
+        correct = True
+        for byte in result:
+            if byte.to_bytes(1, "little") in x86_pseudo_nops:
+                correct = False
+        self.assertTrue(correct)
+
+    def test_get_bytes_from_file_invalid_path(self):
+        with self.assertRaises(SystemExit):
+            get_bytes_from_file("resources/NonExistingFile.txt")
+
+    def test_get_bytes_from_file_invalid_header(self):
+        with self.assertRaises(SystemExit):
+            get_bytes_from_file("resources/InvalidHeader.txt")
+
+    def test_get_bytes_from_file_invalid_hexfile(self):
+        with self.assertRaises(SystemExit):
+            get_bytes_from_file("resources/InvalidHexFile.txt")
+
+    def test_get_bytes_from_file_str(self):
+        result = get_bytes_from_file("resources/StringTestFile.txt")
+        self.assertEqual(result, b'This is a string-test')
+
+    def test_get_bytes_from_file_hex(self):
+        result = get_bytes_from_file("resources/HexTestFile.txt")
+        self.assertEqual(result, b'\xab\xcd\xef\xff\x10\xff\xaa\xab')