Browse Source

Merge branch 'label_improvements' of stefan.schmidt/ID2T-toolkit into master

Carlos Garcia 6 years ago
parent
commit
ee04907762
4 changed files with 136 additions and 25 deletions
  1. 30 16
      code/Attack/BaseAttack.py
  2. 17 3
      code/Core/AttackController.py
  3. 85 5
      code/Core/LabelManager.py
  4. 4 1
      code/ID2TLib/Label.py

+ 30 - 16
code/Attack/BaseAttack.py

@@ -9,6 +9,7 @@ import socket
 import sys
 import tempfile
 import time
+import collections
 
 # TODO: double check this import
 # does it complain because libpcapreader is not a .py?
@@ -27,6 +28,8 @@ class BaseAttack(metaclass=abc.ABCMeta):
     Abstract base class for all attack classes. Provides basic functionalities, like parameter validation.
     """
 
+    ValuePair = collections.namedtuple('ValuePair', ['value', 'user_specified'])
+
     def __init__(self, name, description, attack_type):
         """
         To be called within the individual attack class to initialize the required parameters.
@@ -250,7 +253,7 @@ class BaseAttack(metaclass=abc.ABCMeta):
         try:
             import distutils.core
             import distutils.util
-            value = distutils.util.strtobool(value.lower())
+            value = bool(distutils.util.strtobool(value.lower()))
             is_bool = True
         except ValueError:
             is_bool = False
@@ -311,19 +314,16 @@ class BaseAttack(metaclass=abc.ABCMeta):
         """
         return self.finish_time - self.start_time
 
-    def add_param_value(self, param, value):
+    def add_param_value(self, param, value, user_specified: bool = True):
         """
         Adds the pair param : value to the dictionary of attack parameters. Prints and error message and skips the
         parameter if the validation fails.
 
         :param param: Name of the parameter that we wish to modify.
         :param value: The value we wish to assign to the specified parameter.
+        :param user_specified: Whether the value was specified by the user (or left default)
         :return: None.
         """
-        # This function call is valid only if there is a statistics object available.
-        if self.statistics is None:
-            print('Error: Attack parameter added without setting a statistics object first.')
-            exit(1)
 
         # by default no param is valid
         is_valid = False
@@ -344,14 +344,6 @@ class BaseAttack(metaclass=abc.ABCMeta):
         if param_type is None:
             print('Parameter ' + str(param_name) + ' not available for chosen attack. Skipping parameter.')
 
-        # If value is query -> get value from database
-        elif self.statistics.is_query(value):
-            value = self.statistics.process_db_query(value, False)
-            if value is not None and value is not "":
-                is_valid = True
-            else:
-                print('Error in given parameter value: ' + str(value) + '. Data could not be retrieved.')
-
         # Validate parameter depending on parameter's type
         elif param_type == atkParam.ParameterTypes.TYPE_IP_ADDRESS:
             is_valid, value = self._is_ip_address(value)
@@ -381,6 +373,11 @@ class BaseAttack(metaclass=abc.ABCMeta):
         elif param_type == atkParam.ParameterTypes.TYPE_BOOLEAN:
             is_valid, value = self._is_boolean(value)
         elif param_type == atkParam.ParameterTypes.TYPE_PACKET_POSITION:
+            # This function call is valid only if there is a statistics object available.
+            if self.statistics is None:
+                print('Error: Statistics-dependent attack parameter added without setting a statistics object first.')
+                exit(1)
+
             ts = pr.pcap_processor(self.statistics.pcap_filepath, "False").get_timestamp_mu_sec(int(value))
             if 0 <= int(value) <= self.statistics.get_packet_count() and ts >= 0:
                 is_valid = True
@@ -391,7 +388,7 @@ class BaseAttack(metaclass=abc.ABCMeta):
 
         # add value iff validation was successful
         if is_valid:
-            self.params[param_name] = value
+            self.params[param_name] = self.ValuePair(value, user_specified)
         else:
             print("ERROR: Parameter " + str(param) + " or parameter value " + str(value) +
                   " not valid. Skipping parameter.")
@@ -403,7 +400,24 @@ class BaseAttack(metaclass=abc.ABCMeta):
         :param param: The parameter whose value is wanted.
         :return: The parameter's value.
         """
-        return self.params.get(param)
+        parameter = self.params.get(param)
+        if parameter is not None:
+            return parameter.value
+        else:
+            return None
+
+    def get_param_user_specified(self, param: atkParam.Parameter) -> bool:
+        """
+        Returns whether the parameter value was specified by the user for a given parameter.
+
+        :param param: The parameter whose user-specified flag is wanted.
+        :return: The parameter's user-specified flag.
+        """
+        parameter = self.params.get(param)
+        if parameter is not None:
+            return parameter.user_specified
+        else:
+            return False
 
     def check_parameters(self):
         """

+ 17 - 3
code/Core/AttackController.py

@@ -2,6 +2,7 @@ import importlib
 import sys
 import difflib
 import pkgutil
+import typing
 
 import Attack.AttackParameters as atkParam
 import Core.LabelManager as LabelManager
@@ -29,7 +30,7 @@ class AttackController:
         self.seed = None
         self.total_packets = 0
 
-    def set_seed(self, seed: int):
+    def set_seed(self, seed: int) -> None:
         """
         Sets rng seed.
 
@@ -37,6 +38,13 @@ class AttackController:
         """
         self.seed = seed
 
+    def get_seed(self) -> typing.Union[int, None]:
+        """
+        Gets rng seed.
+        :return: The current rng seed
+        """
+        return self.seed
+
     @staticmethod
     def choose_attack(input_name):
         """"
@@ -108,7 +116,13 @@ class AttackController:
         self.current_attack.set_statistics(self.statistics)
         if seed is not None:
             self.current_attack.set_seed(seed=seed)
+
         self.current_attack.init_params()
+
+        # Unset the user-specified-flag for all parameters set in init_params
+        for k, v in self.current_attack.params.items():
+            self.current_attack.params[k] = self.current_attack.ValuePair(v.value, False)
+
         # Record the attack
         self.added_attacks.append(self.current_attack)
 
@@ -169,8 +183,8 @@ class AttackController:
         print(".)")
 
         # Store label into LabelManager
-        label = Label.Label(attack, self.get_attack_start_utime(),
-                            self.get_attack_end_utime(), attack_note)
+        label = Label.Label(attack, self.get_attack_start_utime(), self.get_attack_end_utime(),
+                            self.seed, self.current_attack.params, attack_note)
         self.label_mgr.add_labels(label)
 
         return temp_attack_pcap_path, duration

+ 85 - 5
code/Core/LabelManager.py

@@ -1,23 +1,32 @@
+import importlib
 import datetime as dt
 import os.path
 import xml.dom.minidom as minidom
 
 import ID2TLib.Label as Label
+import ID2TLib.TestLibrary as Lib
 
 
 class LabelManager:
-    TAG_ROOT = 'LABELS'
+    TAG_ROOT = 'labels'
+    TAG_INPUT = 'input'
+    TAG_OUTPUT = 'output'
+    TAG_FILE_NAME = 'filename'
+    TAG_FILE_HASH = 'sha256'
     TAG_ATTACK = 'attack'
-    TAG_ATTACK_NAME = 'attack_name'
-    TAG_ATTACK_NOTE = 'attack_note'
+    TAG_ATTACK_NAME = 'name'
+    TAG_ATTACK_NOTE = 'note'
+    TAG_ATTACK_SEED = 'seed'
     TAG_TIMESTAMP_START = 'timestamp_start'
     TAG_TIMESTAMP_END = 'timestamp_end'
     TAG_TIMESTAMP = 'timestamp'
     TAG_TIMESTAMP_HR = 'timestamp_hr'
+    TAG_PARAMETERS = 'parameters'
     ATTR_VERSION = 'version_parser'
+    ATTR_PARAM_USERSPECIFIED = 'user_specified'
 
     # update this attribute if XML scheme was modified
-    ATTR_VERSION_VALUE = '0.2'
+    ATTR_VERSION_VALUE = '0.3'
 
     def __init__(self, filepath_pcap=None):
         """
@@ -26,6 +35,7 @@ class LabelManager:
         :param filepath_pcap: The path to the PCAP file associated to the labels.
         """
         self.labels = list()
+        self.filepath_input_pcap = filepath_pcap
 
         if filepath_pcap is not None:
             self.label_file_path = os.path.splitext(filepath_pcap)[0] + '_labels.xml'
@@ -58,6 +68,25 @@ class LabelManager:
         :param filepath: The path where the label file should be written to.
         """
 
+        def get_subtree_fileinfo(xml_tag_root, filename) -> minidom.Element:
+            """
+            Creates the subtree for pcap file information (filename and hash).
+
+            :return: The root node of the XML subtree
+            """
+
+            input_root = doc.createElement(xml_tag_root)
+
+            file = doc.createElement(self.TAG_FILE_NAME)
+            file.appendChild(doc.createTextNode(os.path.split(filename)[-1]))
+            input_root.appendChild(file)
+
+            hash_node = doc.createElement(self.TAG_FILE_HASH)
+            hash_node.appendChild(doc.createTextNode(Lib.get_sha256(filename)))
+            input_root.appendChild(hash_node)
+
+            return input_root
+
         def get_subtree_timestamp(xml_tag_root, timestamp_entry):
             """
             Creates the subtree for a given timestamp, consisting of the unix time format (seconds) and a human-readable
@@ -82,6 +111,23 @@ class LabelManager:
 
             return timestamp_root
 
+        def get_subtree_parameters(parameters):
+            """
+            Creates a subtree containing all parameters used to construct the attack
+
+            :param parameters: The list of parameters used to run the attack
+            :return: The root node of the XML subtree
+            """
+            parameters_root = doc.createElement(self.TAG_PARAMETERS)
+
+            for param_key, param_value in parameters.items():
+                param = doc.createElement(param_key.value)
+                param.appendChild(doc.createTextNode(str(param_value.value)))
+                param.setAttribute(self.ATTR_PARAM_USERSPECIFIED, str(param_value.user_specified))
+                parameters_root.appendChild(param)
+
+            return parameters_root
+
         if filepath is not None:
             self.label_file_path = os.path.splitext(filepath)[0] + '_labels.xml'
 
@@ -89,6 +135,9 @@ class LabelManager:
         doc = minidom.Document()
         node = doc.createElement(self.TAG_ROOT)
         node.setAttribute(self.ATTR_VERSION, self.ATTR_VERSION_VALUE)
+        node.appendChild(get_subtree_fileinfo(self.TAG_INPUT, self.filepath_input_pcap))
+        node.appendChild(get_subtree_fileinfo(self.TAG_OUTPUT, filepath))
+
         for label in self.labels:
             xml_tree = doc.createElement(self.TAG_ATTACK)
 
@@ -99,6 +148,9 @@ class LabelManager:
             attack_note = doc.createElement(self.TAG_ATTACK_NOTE)
             attack_note.appendChild(doc.createTextNode(str(label.attack_note)))
             xml_tree.appendChild(attack_note)
+            attack_seed = doc.createElement(self.TAG_ATTACK_SEED)
+            attack_seed.appendChild(doc.createTextNode(str(label.seed)))
+            xml_tree.appendChild(attack_seed)
 
             # add timestamp_start to XML tree
             xml_tree.appendChild(get_subtree_timestamp(self.TAG_TIMESTAMP_START, label.timestamp_start))
@@ -106,6 +158,9 @@ class LabelManager:
             # add timestamp_end to XML tree
             xml_tree.appendChild(get_subtree_timestamp(self.TAG_TIMESTAMP_END, label.timestamp_end))
 
+            # add parameters to XML tree
+            xml_tree.appendChild(get_subtree_parameters(label.parameters))
+
             node.appendChild(xml_tree)
 
         doc.appendChild(node)
@@ -155,6 +210,11 @@ class LabelManager:
                     "The file " + self.label_file_path + " was created by another version of ID2TLib.LabelManager. "
                                                          "Ignoring label file.")
 
+        self.input_filename = get_value_from_node(dom, self.TAG_INPUT, 1, 0)
+        self.input_hash = get_value_from_node(dom, self.TAG_INPUT, 3, 0)
+        self.output_filename = get_value_from_node(dom, self.TAG_OUTPUT, 1, 0)
+        self.output_hash = get_value_from_node(dom, self.TAG_OUTPUT, 3, 0)
+
         # Parse attacks from XML file
         attacks = dom.getElementsByTagName(self.TAG_ATTACK)
         count_labels = 0
@@ -163,7 +223,27 @@ class LabelManager:
             attack_note = get_value_from_node(a, self.TAG_ATTACK_NOTE, 0)
             timestamp_start = get_value_from_node(a, self.TAG_TIMESTAMP_START, 1, 0)
             timestamp_end = get_value_from_node(a, self.TAG_TIMESTAMP_END, 1, 0)
-            label = Label.Label(attack_name, float(timestamp_start), float(timestamp_end), attack_note)
+            attack_seed = get_value_from_node(a, self.TAG_ATTACK_SEED, 0)
+
+            # Instantiate this attack to create a parameter list with the correct types
+            attack_module = importlib.import_module("Attack." + attack_name)
+            attack_class = getattr(attack_module, attack_name)
+            attack = attack_class()
+
+            # Loop through all parameters listed in the XML file
+            param = a.getElementsByTagName(self.TAG_PARAMETERS)[0]
+            for param in param.childNodes:
+                # Skip empty text nodes returned by minidom
+                if not isinstance(param, minidom.Text):
+                    import distutils.util
+                    param_name = param.tagName
+                    param_value = param.childNodes[0].nodeValue
+                    param_userspecified = bool(distutils.util.strtobool(param.getAttribute(self.ATTR_PARAM_USERSPECIFIED)))
+                    attack.add_param_value(param_name, param_value, param_userspecified)
+
+            # Create the label from the data read
+            label = Label.Label(attack_name, float(timestamp_start), float(timestamp_end), attack_seed, attack.params,
+                                attack_note)
             self.labels.append(label)
             count_labels += 1
 

+ 4 - 1
code/ID2TLib/Label.py

@@ -3,19 +3,22 @@ import functools
 
 @functools.total_ordering
 class Label:
-    def __init__(self, attack_name, timestamp_start, timestamp_end, attack_note=""):
+    def __init__(self, attack_name, timestamp_start, timestamp_end, seed, parameters, attack_note=""):
         """
         Creates a new attack label
 
         :param attack_name: The name of the associated attack
         :param timestamp_start: The timestamp as unix time of the first attack packet
         :param timestamp_end: The timestamp as unix time of the last attack packet
+        :param parameters: The list of parameters used to run the attack
         :param attack_note: A note associated to the attack (optional)
         """
         self.attack_name = attack_name
         self.timestamp_start = timestamp_start
         self.timestamp_end = timestamp_end
+        self.seed = seed
         self.attack_note = attack_note
+        self.parameters = parameters
 
     def __eq__(self, other):
         return self.timestamp == other.timestamp