Quellcode durchsuchen

Merge branch 'various_improvements' of stefan.schmidt/ID2T-toolkit into master

Carlos Garcia vor 6 Jahren
Ursprung
Commit
3912b75273

+ 3 - 5
code/Core/Controller.py

@@ -59,7 +59,6 @@ class Controller:
         :param flag_write_file: Writes the statistics to a file.
         :param flag_recalculate_stats: Forces the recalculation of statistics.
         :param flag_print_statistics: Prints the statistics on the terminal.
-        :param flag_non_verbose: Reduces terminal clutter.
         :return: None
         """
         self.statistics.load_pcap_statistics(flag_write_file, flag_recalculate_stats, flag_print_statistics,
@@ -188,11 +187,10 @@ class Controller:
             self.statisticsDB.process_db_query(query, print_results)
 
     @staticmethod
-    def process_help(params):
+    def process_help(params) -> None:
         """
-        TODO: FILL ME
-        :param params:
-        :return:
+        Prints either general help messages, or information about specific commands.
+        :param params: A list of parameters for the help command (can be empty).
         """
         if not params:
             print("Query mode allows you to enter SQL-queries as well as named queries.")

+ 6 - 7
code/ID2TLib/Utility.py

@@ -110,7 +110,7 @@ def get_rnd_os():
     return os_dist.random()
 
 
-def check_platform(platform: str):
+def check_platform(platform: str) -> None:
     """
     Checks if the given platform is currently supported
     if not exits with error
@@ -118,9 +118,8 @@ def check_platform(platform: str):
     :param platform: the platform, which should be validated
     """
     if platform not in platforms:
-        print("\nERROR: Invalid platform: " + platform + "." +
-              "\n Please select one of the following platforms: ", platforms)
-        exit(1)
+        raise ValueError("ERROR: Invalid platform: " + platform + "." +
+                         "\n Please select one of the following platforms: " + ",".join(platforms))
 
 
 def get_ip_range(start_ip: str, end_ip: str):
@@ -248,7 +247,7 @@ def get_rnd_bytes(count=1, ignore=None):
     return result
 
 
-def check_payload_len(payload_len: int, limit: int):
+def check_payload_len(payload_len: int, limit: int) -> None:
     """
     Checks if the len of the payload exceeds a given limit
 
@@ -257,8 +256,8 @@ def check_payload_len(payload_len: int, limit: int):
     """
 
     if payload_len > limit:
-        print("\nCustom payload too long: ", payload_len, " bytes. Should be a maximum of ", limit, " bytes.")
-        exit(1)
+        raise ValueError("Custom payload too long: " + str(payload_len) +
+                         " bytes. Should be a maximum of " + str(limit) + " bytes.")
 
 
 def get_bytes_from_file(filepath):

+ 60 - 0
code/Test/test_Controller.py

@@ -0,0 +1,60 @@
+import unittest
+import unittest.mock as mock
+import Core.Controller as Ctrl
+
+
+class TestController(unittest.TestCase):
+    @mock.patch("builtins.print")
+    def test_process_help(self, mock_print):
+        Ctrl.Controller.process_help(None)
+        self.assertTrue(mock_print.called)
+
+    @mock.patch("builtins.print")
+    def test_process_help_most_used(self, mock_print):
+        Ctrl.Controller.process_help(["most_used"])
+        self.assertTrue(mock_print.called)
+
+    @mock.patch("builtins.print")
+    def test_process_help_least_used(self, mock_print):
+        Ctrl.Controller.process_help(["least_used"])
+        self.assertTrue(mock_print.called)
+
+    @mock.patch("builtins.print")
+    def test_process_help_avg(self, mock_print):
+        Ctrl.Controller.process_help(["avg"])
+        self.assertTrue(mock_print.called)
+
+    @mock.patch("builtins.print")
+    def test_process_help_all(self, mock_print):
+        Ctrl.Controller.process_help(["all"])
+        self.assertTrue(mock_print.called)
+
+    @mock.patch("builtins.print")
+    def test_process_help_random(self, mock_print):
+        Ctrl.Controller.process_help(["random"])
+        self.assertTrue(mock_print.called)
+
+    @mock.patch("builtins.print")
+    def test_process_help_first(self, mock_print):
+        Ctrl.Controller.process_help(["first"])
+        self.assertTrue(mock_print.called)
+
+    @mock.patch("builtins.print")
+    def test_process_help_last(self, mock_print):
+        Ctrl.Controller.process_help(["last"])
+        self.assertTrue(mock_print.called)
+
+    @mock.patch("builtins.print")
+    def test_process_help_ipaddress(self, mock_print):
+        Ctrl.Controller.process_help(["ipaddress"])
+        self.assertTrue(mock_print.called)
+
+    @mock.patch("builtins.print")
+    def test_process_help_macaddress(self, mock_print):
+        Ctrl.Controller.process_help(["macaddress"])
+        self.assertTrue(mock_print.called)
+
+    @mock.patch("builtins.print")
+    def test_process_help_examples(self, mock_print):
+        Ctrl.Controller.process_help(["examples"])
+        self.assertTrue(mock_print.called)

+ 1 - 1
code/Test/test_NamedQueries.py

@@ -78,7 +78,7 @@ class UnitTestNamedQueries(unittest.TestCase):
     def test_least_used_mssvalue(self):
         self.assertEqual(controller.statistics.process_db_query('least_used(mssvalue)'), 1460)
 
-    def least_used_winsize(self):
+    def test_least_used_winsize(self):
         self.assertEqual(controller.statistics.process_db_query('least_used(winsize)'), leastUsedWinASize)
 
     def test_least_used_ipclass(self):

+ 2 - 6
code/Test/test_SMBLib.py

@@ -6,7 +6,6 @@ import ID2TLib.Utility as Utility
 
 class TestSMBLib(unittest.TestCase):
     def test_get_smb_version_all(self):
-
         for platform in Utility.platforms:
             with self.subTest(platform):
                 result = SMBLib.get_smb_version(platform)
@@ -14,15 +13,13 @@ class TestSMBLib(unittest.TestCase):
                                  result in SMBLib.smb_versions_per_samba.values()))
 
     def test_get_smb_version_invalid(self):
-
-        with self.assertRaises(SystemExit):
+        with self.assertRaises(ValueError):
             SMBLib.get_smb_version("abc")
 
     def test_get_smb_version_mac(self):
         self.assertEqual(SMBLib.get_smb_version("macos"), "2.1")
 
     def test_get_smb_version_win(self):
-
         win_platforms = {'win7', 'win10', 'winxp', 'win8.1', 'win8', 'winvista', 'winnt', "win2000"}
 
         for platform in win_platforms:
@@ -33,8 +30,7 @@ class TestSMBLib(unittest.TestCase):
         self.assertIn(SMBLib.get_smb_version("linux"), SMBLib.smb_versions_per_samba.values())
 
     def test_get_smb_platform_data_invalid(self):
-
-        with self.assertRaises(SystemExit):
+        with self.assertRaises(ValueError):
             SMBLib.get_smb_platform_data("abc", 0)
 
     def test_get_smb_platform_data_linux(self):

+ 22 - 41
code/Test/test_Utility.py

@@ -53,13 +53,10 @@ class TestUtility(unittest.TestCase):
         self.assertIn(Utility.get_rnd_os(), Utility.platforms)
 
     def test_check_platform_valid(self):
-        try:
             Utility.check_platform("linux")
-        except SystemExit:
-            self.fail()
 
     def test_check_platform_invalid(self):
-        with self.assertRaises(SystemExit):
+        with self.assertRaises(ValueError):
             Utility.check_platform("abc")
 
     def test_get_ip_range_forwards(self):
@@ -81,7 +78,7 @@ class TestUtility(unittest.TestCase):
         self.assertEqual(Utility.get_ip_range(start, end), result)
 
     def test_generate_source_port_from_platform_invalid(self):
-        with self.assertRaises(SystemExit):
+        with self.assertRaises(ValueError):
             Utility.generate_source_port_from_platform("abc")
 
     def test_generate_source_port_from_platform_oldwin_firstport(self):
@@ -108,7 +105,7 @@ class TestUtility(unittest.TestCase):
     # TODO: get_filetime_format Test
 
     def test_get_rnd_boot_time_invalid(self):
-        with self.assertRaises(SystemExit):
+        with self.assertRaises(ValueError):
             Utility.get_rnd_boot_time(10, "abc")
 
     def test_get_rnd_boot_time_linux(self):
@@ -126,40 +123,29 @@ class TestUtility(unittest.TestCase):
 
     def test_get_rnd_x86_nop_with_sideeffects(self):
         result = Utility.get_rnd_x86_nop(1000, False)
-        correct = True
-        for byte in result:
-            if byte.to_bytes(1, "little") not in Utility.x86_nops \
-                    and byte.to_bytes(1, "little") not in Utility.x86_pseudo_nops:
-                correct = False
-                break
-        self.assertTrue(correct)
+        for i in range(0, len(result)):
+            with self.subTest(i=i):
+                self.assertTrue(result[i].to_bytes(1, "little") in Utility.x86_nops or
+                                result[i].to_bytes(1, "little") in Utility.x86_pseudo_nops)
 
     def test_get_rnd_x86_nop_without_sideeffects(self):
         result = Utility.get_rnd_x86_nop(1000, True)
-        correct = True
-        for byte in result:
-            if byte.to_bytes(1, "little") in Utility.x86_pseudo_nops:
-                correct = False
-                break
-        self.assertTrue(correct)
+        for i in range(0, len(result)):
+            with self.subTest(i=i):
+                self.assertIn(result[i].to_bytes(1, "little"), Utility.x86_nops)
+                self.assertNotIn(result[i].to_bytes(1, "little"), Utility.x86_pseudo_nops)
 
     def test_get_rnd_x86_nop_filter(self):
         result = Utility.get_rnd_x86_nop(1000, False, Utility.x86_nops.copy())
-        correct = True
-        for byte in result:
-            if byte.to_bytes(1, "little") in Utility.x86_nops:
-                correct = False
-                break
-        self.assertTrue(correct)
+        for i in range(0, len(result)):
+            with self.subTest(i=i):
+                self.assertNotIn(result[i].to_bytes(1, "little"), Utility.x86_nops)
 
     def test_get_rnd_x86_nop_single_filter(self):
         result = Utility.get_rnd_x86_nop(1000, False, b'\x20')
-        correct = True
-        for byte in result:
-            if byte.to_bytes(1, "little") == b'\x20':
-                correct = False
-                break
-        self.assertTrue(correct)
+        for i in range(0, len(result)):
+            with self.subTest(i=i):
+                self.assertNotEqual(result[i].to_bytes(1, "little"), b'\x20')
 
     def test_get_rnd_bytes_number(self):
         result = Utility.get_rnd_bytes(1000)
@@ -167,11 +153,9 @@ class TestUtility(unittest.TestCase):
 
     def test_get_rnd_bytes_filter(self):
         result = Utility.get_rnd_bytes(1000, Utility.x86_pseudo_nops.copy())
-        correct = True
-        for byte in result:
-            if byte.to_bytes(1, "little") in Utility.x86_pseudo_nops:
-                correct = False
-        self.assertTrue(correct)
+        for i in range(0, len(result)):
+            with self.subTest(i=i):
+                self.assertNotIn(result[i].to_bytes(1, "little"), Utility.x86_pseudo_nops)
 
     def test_get_bytes_from_file_invalid_path(self):
         with self.assertRaises(SystemExit):
@@ -217,14 +201,11 @@ class TestUtility(unittest.TestCase):
         self.assertEqual(Utility.handle_most_used_outputs(test_input), 0)
 
     def test_check_payload_len_exceeded(self):
-        with self.assertRaises(SystemExit):
+        with self.assertRaises(ValueError):
             Utility.check_payload_len(10, 5)
 
     def test_check_payload_len_valid(self):
-        try:
-            Utility.check_payload_len(5, 10)
-        except SystemExit:
-            self.fail()
+        Utility.check_payload_len(5, 10)
 
     def test_remove_generic_ending_attack(self):
         self.assertEqual(Utility.remove_generic_ending("someattack"), "some")