Przeglądaj źródła

Fixed seed- and label-bug

Denis Waßmann 7 lat temu
rodzic
commit
dcc9a35c92
4 zmienionych plików z 33 dodań i 21 usunięć
  1. 2 1
      code/CLI.py
  2. 1 1
      code/ID2TLib/Controller.py
  3. 3 2
      code/ID2TLib/LabelManager.py
  4. 27 17
      code/ID2TLib/Ports.py

+ 2 - 1
code/CLI.py

@@ -92,7 +92,8 @@ class CLI(object):
     def seed_rng(self, seed):
         try: # try to convert the seed to int
             seed = int(seed)
-        except: pass # otherwise use the strings hash, random.seed does this automatically
+        except:
+            seed = hash(seed) # otherwise use the strings hash
         
         random.seed(seed)
         numpy.random.seed(seed)

+ 1 - 1
code/ID2TLib/Controller.py

@@ -127,4 +127,4 @@ class Controller:
             params_dict = dict([z.split("=") for z in params])
             self.statistics.plot_statistics(format=params_dict['format'])
         else:
-            self.statistics.plot_statistics()
+            self.statistics.plot_statistics()

+ 3 - 2
code/ID2TLib/LabelManager.py

@@ -28,7 +28,8 @@ class LabelManager:
         self.labels = list()
 
         if filepath_pcap is not None:
-            self.label_file_path = filepath_pcap.strip('.pcap') + '_labels.xml'
+            # splitext gives us the filename without extension
+            self.label_file_path = os.path.splitext(filepath_pcap)[0] + '_labels.xml'
             # only load labels if label file is existing
             if os.path.exists(self.label_file_path):
                 self.load_labels()
@@ -83,7 +84,7 @@ class LabelManager:
             return timestamp_root
 
         if filepath is not None:
-            self.label_file_path = filepath.strip('.pcap') + '_labels.xml'
+            self.label_file_path = os.path.splitext(filepath)[0] + '_labels.xml' # splitext removes the file extension
 
         # Generate XML
         doc = Document()

+ 27 - 17
code/ID2TLib/Ports.py

@@ -1,4 +1,4 @@
-import random
+import random, copy
 
 # information taken from https://www.cymru.com/jtk/misc/ephemeralports.html
 class PortRanges:
@@ -16,31 +16,28 @@ class PortRanges:
 	WINDOWS_VISTA = DYNAMIC_PORTS
 	WINDOWS_XP = range(1024, 5001)
 
+# This class uses classes instead of functions so deepcloning works
 class PortSelectionStrategy:
-	@staticmethod
-	def sequential():
-		counter = -1
+	class sequential:
+		def __init__(self):
+			self.counter = -1
 		
 		# that function will always return a one higher counter than before,
 		# restarting from the start once it reached the highest value
-		def select_port(port_range):
-			global counter
-			if counter == -1:
-				counter = port_range.start
+		def __call__(port_range):
+			if self.counter == -1:
+				self.counter = port_range.start
 			
 			port = counter
 			
-			counter += 1
-			if counter == port_range.stop:
-				counter = port_range.start
+			self.counter += 1
+			if self.counter == port_range.stop:
+				self.counter = port_range.start
 			
 			return port
-		
-		return select_port
-	
-	@staticmethod
-	def random(port_range):
-		return random.randrange(port_range.start, port_range.stop)
+	class random:
+		def __call__(port_range):
+			return random.randrange(port_range.start, port_range.stop)
 
 class PortSelector:
 	def __init__(self, port_range, select_function):
@@ -75,6 +72,9 @@ class PortSelector:
 	
 	def clear(self):
 		self.generated = []
+	
+	def clone(self):
+		return copy.deepcopy(self)
 
 class ProtocolPortSelector:
 	def __init__(self, port_range, select_tcp, select_udp = None):
@@ -99,6 +99,16 @@ class ProtocolPortSelector:
 	def is_port_in_use_udp(self, port):
 		return self.udp.is_port_in_use(port)
 	
+	def clone(self):
+		class Tmp: pass
+		clone = Tmp()
+		clone.__class__ = type(self)
+		
+		clone.udp = self.udp.clone()
+		clone.tcp = self.tcp.clone()
+		
+		return clone
+	
 	def __getattr__(self, attr):
 		val = getattr(self.tcp, attr)