import random
from .. import IPv4 as ip


class IPChooser:
	"""Pick random addresses out of the complete 32-bit address space.

	The subclasses in this module narrow the candidate pool by overriding
	``random_ip`` and ``size``.
	"""

	def random_ip(self) -> ip.IPAddress:
		"""Return a random address built from an integer in ``[0, 2**32)``."""
		address_number = random.randrange(2 ** 32)
		return ip.IPAddress.from_int(address_number)

	def size(self) -> int:
		"""Return how many integers ``random_ip`` draws from (``2**32``)."""
		return 2 ** 32

	def __len__(self) -> int:
		"""Return ``self.size()`` so that ``len(chooser)`` works."""
		return self.size()

class IPChooserByRange(IPChooser):
	"""Chooser whose candidates are the addresses of one address block."""

	def __init__(self, ip_range: ip.IPAddressBlock) -> None:
		"""Store the block to draw addresses from.

		Args:
			ip_range: block whose ``first_address()`` and ``block_size()``
				define the candidate interval.
		"""
		self.range = ip_range

	def random_ip(self) -> ip.IPAddress:
		"""Return a random address from the stored block.

		The block is treated as the integer interval that starts at
		``int(first_address())`` and spans ``block_size()`` values.
		"""
		start = int(self.range.first_address())
		# Exclusive upper bound, matching random.randrange semantics.
		end = start + self.range.block_size()
		return ip.IPAddress.from_int(random.randrange(start, end))

	def size(self) -> int:
		"""Return the stored block's ``block_size()``."""
		return self.range.block_size()

class IPChooserByList(IPChooser):
	"""Chooser whose candidates are an explicit collection of addresses."""

	def __init__(self, ips: "list[ip.IPAddress]") -> None:
		"""Copy the given addresses into a private list.

		Args:
			ips: the candidate addresses; copied with ``list(ips)`` so later
				changes to the caller's collection are not observed.

		Raises:
			ValueError: if ``ips`` contains no elements.
		"""
		# Duplicates are kept as-is, so they are counted by size() and are
		# proportionally more likely to be picked by random_ip().
		self.ips = list(ips)
		if not self.ips:
			raise ValueError("list of ips must not be empty")

	def random_ip(self) -> ip.IPAddress:
		"""Return one element of the stored list, picked with ``random.choice``."""
		return random.choice(self.ips)

	def size(self) -> int:
		"""Return the length of the stored list (duplicates included)."""
		return len(self.ips)

class IPGenerator:
	"""Generate unique random ip-addresses that avoid a blacklist.

	Candidates are drawn from a chooser object (anything offering
	``random_ip()`` and ``size()``). Every address handed out is remembered in
	``generated_ips`` so that it is never returned a second time.
	"""

	def __init__(self, ip_chooser = None, # None: include all ip-addresses by default (before the blacklist)
			include_private_ips = False, include_localhost = False,
			include_multicast = False, include_reserved = False,
			include_link_local = False, blacklist = None) -> None:
		"""Set up the blacklist and the chooser.

		Args:
			ip_chooser: source of candidate addresses; ``None`` creates a
				fresh ``IPChooser()`` for this instance.
			include_private_ips: when false, every block in
				``ip.ReservedIPBlocks.PRIVATE_IP_SEGMENTS`` is blacklisted.
			include_localhost: when false,
				``ip.ReservedIPBlocks.LOCALHOST_SEGMENT`` is blacklisted.
			include_multicast: when false,
				``ip.ReservedIPBlocks.MULTICAST_SEGMENT`` is blacklisted.
			include_reserved: when false,
				``ip.ReservedIPBlocks.RESERVED_SEGMENT`` is blacklisted.
			include_link_local: when false,
				``ip.ReservedIPBlocks.ZERO_CONF_SEGMENT`` is blacklisted.
			blacklist: optional iterable of further segments; each one is
				passed to ``add_to_blacklist``.
		"""
		self.blacklist = []
		self.generated_ips = set()

		if not include_private_ips:
			for segment in ip.ReservedIPBlocks.PRIVATE_IP_SEGMENTS:
				self.add_to_blacklist(segment)
		if not include_localhost:
			self.add_to_blacklist(ip.ReservedIPBlocks.LOCALHOST_SEGMENT)
		if not include_multicast:
			self.add_to_blacklist(ip.ReservedIPBlocks.MULTICAST_SEGMENT)
		if not include_reserved:
			self.add_to_blacklist(ip.ReservedIPBlocks.RESERVED_SEGMENT)
		if not include_link_local:
			self.add_to_blacklist(ip.ReservedIPBlocks.ZERO_CONF_SEGMENT)
		if blacklist:
			for segment in blacklist:
				self.add_to_blacklist(segment)

		# The default chooser is created here, per instance, instead of once
		# in the signature when the class body is executed.
		self.chooser = IPChooser() if ip_chooser is None else ip_chooser

	@staticmethod
	def from_range(range: ip.IPAddressBlock, *args, **kwargs) -> "IPGenerator":
		"""Build a generator whose candidates come from ``range``.

		The remaining arguments are forwarded to the ``IPGenerator``
		constructor after the chooser. The result is always a plain
		``IPGenerator``, also when this is called through a subclass.
		"""
		return IPGenerator(IPChooserByRange(range), *args, **kwargs)

	def add_to_blacklist(self, ip_segment: "Union[ip.IPAddressBlock, str]") -> None:
		"""Append a segment to the blacklist.

		``ip.IPAddressBlock`` instances are stored directly; any other value
		is first converted with ``ip.IPAddressBlock.parse``.
		"""
		if isinstance(ip_segment, ip.IPAddressBlock):
			self.blacklist.append(ip_segment)
		else:
			self.blacklist.append(ip.IPAddressBlock.parse(ip_segment))

	def random_ip(self) -> str:
		"""Return a not yet generated, non-blacklisted address as a string.

		The chosen address object is recorded in ``generated_ips``; its
		``str()`` form is what the caller receives.

		Raises:
			ValueError: if at least ``chooser.size()`` addresses have already
				been generated.
		"""
		if len(self.generated_ips) >= self.chooser.size():
			raise ValueError("Exhausted the space of possible ip-addresses, no new unique ip-address can be generated")

		# NOTE(review): the check above compares against the chooser's full
		# size, which also counts addresses rejected below as blacklisted (and
		# any duplicates the chooser counts). Once every admissible address
		# has been handed out while that count is not reached, this loop
		# never terminates.
		while True:
			candidate = self.chooser.random_ip()

			# Cheap set lookup first, then the linear blacklist scan.
			if candidate in self.generated_ips or self._is_in_blacklist(candidate):
				continue

			self.generated_ips.add(candidate)
			return str(candidate)

	def clear(self, clear_blacklist = True, clear_generated_ips = True) -> None:
		"""Empty the blacklist and/or the record of generated addresses.

		Args:
			clear_blacklist: when true, remove every blacklist entry,
				including the ones added by the constructor.
			clear_generated_ips: when true, forget all generated addresses.
		"""
		if clear_blacklist:
			self.blacklist.clear()
		if clear_generated_ips:
			self.generated_ips.clear()

	def _is_in_blacklist(self, ip: ip.IPAddress) -> bool:
		"""Return whether ``ip`` is contained in any blacklisted block.

		Inside this method the parameter ``ip`` hides the module alias of the
		same name; only the annotation refers to the module.
		"""
		return any(ip in block for block in self.blacklist)

class MappingIPGenerator(IPGenerator):
	"""IPGenerator that remembers which generated address belongs to which key."""

	def __init__(self, *args, **kwargs) -> None:
		"""Forward all arguments to ``IPGenerator`` and start with an empty mapping."""
		# super() already binds self. Passing it explicitly as well would put
		# this object into the ip_chooser slot and shift every other argument.
		super().__init__(*args, **kwargs)

		self.mapping = {}

	def clear(self, clear_generated_ips = True, *args, **kwargs):
		"""Clear the generator state and, with it, the key mapping.

		Args:
			clear_generated_ips: when true, the generated addresses and the
				key mapping are both dropped.
			*args, **kwargs: forwarded to ``IPGenerator.clear``; the first
				extra positional argument therefore lands on its first
				parameter (``clear_blacklist``).
		"""
		# Extra positionals go first and self is not repeated, so callers can
		# set clear_blacklist without a "multiple values" TypeError.
		super().clear(*args, clear_generated_ips = clear_generated_ips, **kwargs)
		if clear_generated_ips:
			# The mapping only holds generated addresses; drop it together
			# with them. A new dict is bound rather than mutating the old one.
			self.mapping = {}

	def get_mapped_ip(self, key) -> str:
		"""Return the address for ``key``, generating one on first use.

		The value is whatever ``random_ip()`` returns and is stored so that
		later lookups with the same key yield the same address.
		"""
		if key not in self.mapping:
			self.mapping[key] = self.random_ip()

		return self.mapping[key]

	def __getitem__(self, item) -> str:
		"""Support ``generator[key]`` as shorthand for ``get_mapped_ip(key)``."""
		return self.get_mapped_ip(item)