Browse Source

Changed check_platform to raise an exception instead of exiting

Stefan Schmidt 6 years ago
parent
commit
f5c9c883d5
3 changed files with 8 additions and 16 deletions
  1. 3 4
      code/ID2TLib/Utility.py
  2. 2 6
      code/Test/test_SMBLib.py
  3. 3 6
      code/Test/test_Utility.py

+ 3 - 4
code/ID2TLib/Utility.py

@@ -110,7 +110,7 @@ def get_rnd_os():
     return os_dist.random()
 
 
-def check_platform(platform: str):
+def check_platform(platform: str) -> None:
     """
     Checks if the given platform is currently supported
     if not exits with error
@@ -118,9 +118,8 @@ def check_platform(platform: str):
     :param platform: the platform, which should be validated
     """
     if platform not in platforms:
-        print("\nERROR: Invalid platform: " + platform + "." +
-              "\n Please select one of the following platforms: ", platforms)
-        exit(1)
+        raise ValueError("ERROR: Invalid platform: " + platform + "." +
+                         "\n Please select one of the following platforms: " + ",".join(platforms))
 
 
 def get_ip_range(start_ip: str, end_ip: str):

+ 2 - 6
code/Test/test_SMBLib.py

@@ -6,7 +6,6 @@ import ID2TLib.Utility as Utility
 
 class TestSMBLib(unittest.TestCase):
     def test_get_smb_version_all(self):
-
         for platform in Utility.platforms:
             with self.subTest(platform):
                 result = SMBLib.get_smb_version(platform)
@@ -14,15 +13,13 @@ class TestSMBLib(unittest.TestCase):
                                  result in SMBLib.smb_versions_per_samba.values()))
 
     def test_get_smb_version_invalid(self):
-
-        with self.assertRaises(SystemExit):
+        with self.assertRaises(ValueError):
             SMBLib.get_smb_version("abc")
 
     def test_get_smb_version_mac(self):
         self.assertEqual(SMBLib.get_smb_version("macos"), "2.1")
 
     def test_get_smb_version_win(self):
-
         win_platforms = {'win7', 'win10', 'winxp', 'win8.1', 'win8', 'winvista', 'winnt', "win2000"}
 
         for platform in win_platforms:
@@ -33,8 +30,7 @@ class TestSMBLib(unittest.TestCase):
         self.assertIn(SMBLib.get_smb_version("linux"), SMBLib.smb_versions_per_samba.values())
 
     def test_get_smb_platform_data_invalid(self):
-
-        with self.assertRaises(SystemExit):
+        with self.assertRaises(ValueError):
             SMBLib.get_smb_platform_data("abc", 0)
 
     def test_get_smb_platform_data_linux(self):

+ 3 - 6
code/Test/test_Utility.py

@@ -53,13 +53,10 @@ class TestUtility(unittest.TestCase):
         self.assertIn(Utility.get_rnd_os(), Utility.platforms)
 
     def test_check_platform_valid(self):
-        try:
             Utility.check_platform("linux")
-        except SystemExit:
-            self.fail()
 
     def test_check_platform_invalid(self):
-        with self.assertRaises(SystemExit):
+        with self.assertRaises(ValueError):
             Utility.check_platform("abc")
 
     def test_get_ip_range_forwards(self):
@@ -81,7 +78,7 @@ class TestUtility(unittest.TestCase):
         self.assertEqual(Utility.get_ip_range(start, end), result)
 
     def test_generate_source_port_from_platform_invalid(self):
-        with self.assertRaises(SystemExit):
+        with self.assertRaises(ValueError):
             Utility.generate_source_port_from_platform("abc")
 
     def test_generate_source_port_from_platform_oldwin_firstport(self):
@@ -108,7 +105,7 @@ class TestUtility(unittest.TestCase):
     # TODO: get_filetime_format Test
 
     def test_get_rnd_boot_time_invalid(self):
-        with self.assertRaises(SystemExit):
+        with self.assertRaises(ValueError):
             Utility.get_rnd_boot_time(10, "abc")
 
     def test_get_rnd_boot_time_linux(self):