Browse Source

Store a "user_specified"-attribute in the label file for each parameter

Stefan Schmidt 6 years ago
parent
commit
a00555a540
3 changed files with 36 additions and 5 deletions
  1. 24 3
      code/Attack/BaseAttack.py
  2. 6 0
      code/Core/AttackController.py
  3. 6 2
      code/Core/LabelManager.py

+ 24 - 3
code/Attack/BaseAttack.py

@@ -9,6 +9,7 @@ import socket
 import sys
 import tempfile
 import time
+import collections
 
 # TODO: double check this import
 # does it complain because libpcapreader is not a .py?
@@ -27,6 +28,8 @@ class BaseAttack(metaclass=abc.ABCMeta):
     Abstract base class for all attack classes. Provides basic functionalities, like parameter validation.
     """
 
+    ValuePair = collections.namedtuple('ValuePair', ['value', 'user_specified'])
+
     def __init__(self, name, description, attack_type):
         """
         To be called within the individual attack class to initialize the required parameters.
@@ -311,13 +314,14 @@ class BaseAttack(metaclass=abc.ABCMeta):
         """
         return self.finish_time - self.start_time
 
-    def add_param_value(self, param, value):
+    def add_param_value(self, param, value, user_specified: bool = True):
         """
         Adds the pair param : value to the dictionary of attack parameters. Prints and error message and skips the
         parameter if the validation fails.
 
         :param param: Name of the parameter that we wish to modify.
         :param value: The value we wish to assign to the specified parameter.
+        :param user_specified: Whether the value was specified by the user (or left default)
         :return: None.
         """
 
@@ -384,7 +388,7 @@ class BaseAttack(metaclass=abc.ABCMeta):
 
         # add value iff validation was successful
         if is_valid:
-            self.params[param_name] = value
+            self.params[param_name] = self.ValuePair(value, user_specified)
         else:
             print("ERROR: Parameter " + str(param) + " or parameter value " + str(value) +
                   " not valid. Skipping parameter.")
@@ -396,7 +400,24 @@ class BaseAttack(metaclass=abc.ABCMeta):
         :param param: The parameter whose value is wanted.
         :return: The parameter's value.
         """
-        return self.params.get(param)
+        parameter = self.params.get(param)
+        if parameter is not None:
+            return parameter.value
+        else:
+            return None
+
+    def get_param_user_specified(self, param: atkParam.Parameter) -> bool:
+        """
+        Returns whether the parameter value was specified by the user for a given parameter.
+
+        :param param: The parameter whose user-specified flag is wanted.
+        :return: The parameter's user-specified flag.
+        """
+        parameter = self.params.get(param)
+        if parameter is not None:
+            return parameter.user_specified
+        else:
+            return False
 
     def check_parameters(self):
         """

+ 6 - 0
code/Core/AttackController.py

@@ -108,7 +108,13 @@ class AttackController:
         self.current_attack.set_statistics(self.statistics)
         if seed is not None:
             self.current_attack.set_seed(seed=seed)
+
         self.current_attack.init_params()
+
+        # Unset the user-specified-flag for all parameters set in init_params
+        for k, v in self.current_attack.params.items():
+            self.current_attack.params[k] = self.current_attack.ValuePair(v.value, False)
+
         # Record the attack
         self.added_attacks.append(self.current_attack)
 

+ 6 - 2
code/Core/LabelManager.py

@@ -17,6 +17,7 @@ class LabelManager:
     TAG_TIMESTAMP_HR = 'timestamp_hr'
     TAG_PARAMETERS = 'parameters'
     ATTR_VERSION = 'version_parser'
+    ATTR_PARAM_USERSPECIFIED = 'user_specified'
 
     # update this attribute if XML scheme was modified
     ATTR_VERSION_VALUE = '0.3'
@@ -95,7 +96,8 @@ class LabelManager:
 
             for param_key, param_value in parameters.items():
                 param = doc.createElement(param_key.value)
-                param.appendChild(doc.createTextNode(str(param_value)))
+                param.appendChild(doc.createTextNode(str(param_value.value)))
+                param.setAttribute(self.ATTR_PARAM_USERSPECIFIED, str(param_value.user_specified))
                 parameters_root.appendChild(param)
 
             return parameters_root
@@ -195,9 +197,11 @@ class LabelManager:
             for param in param.childNodes:
                 # Skip empty text nodes returned by minidom
                 if not isinstance(param, minidom.Text):
+                    import distutils.util
                     param_name = param.tagName
                     param_value = param.childNodes[0].nodeValue
-                    attack.add_param_value(param_name, param_value)
+                    param_userspecified = bool(distutils.util.strtobool(param.getAttribute(self.ATTR_PARAM_USERSPECIFIED)))
+                    attack.add_param_value(param_name, param_value, param_userspecified)
 
             # Create the label from the data read
             label = Label.Label(attack_name, float(timestamp_start), float(timestamp_end), attack.params, attack_note)