#!/usr/bin/python3

import sys, subprocess
import xml.etree.ElementTree as ET
from collections import namedtuple, defaultdict

# Debug output switch. cli_avail_space() turns it on when '-v' is given;
# all the additional output goes to standard error.
verbose = False

# Chunk server information as filled in by get_cluster_cs_info():
#   id, host_id               chunk server and host identifiers (strings)
#   location                  tuple of the dot separated location components
#   tier                      storage tier number (int)
#   avail                     integer value of the 'available' stat element,
#                             not read by the calculations in this file
#   space_total, space_avail  total and available space, reported as bytes
CSInfo = namedtuple('CSInfo', ('id', 'host_id', 'location', 'tier', 'avail', 'space_total', 'space_avail'))

# Storage encoding and placement information:
#   cluster_name  name of the cluster to query
#   n, k          storage scheme numbers as returned by parse_storage_scheme()
#   fail_domain   failure domain name, a key of StorageMap.fail_domain_map
#   tier          tier the calculation is restricted to, None for any tier
StorageInfo = namedtuple('StorageInfo', ('cluster_name', 'n', 'k', 'fail_domain', 'tier'))

class StorageMap:
	"""
	Keeps mapping from failure domain to the amount of storage space available.

	Every chunk server is placed on a path made of its location components
	followed by the host ID and the chunk server ID. A failure domain is the
	prefix of that path whose length is given by fail_domain_map; the space of
	all chunk servers sharing the prefix is summed up.

	Attributes:
	  total_map   domain path -> total space
	  avail_map   domain path -> available space
	  used_map    domain path -> used space (total minus available)
	  move_bytes  used space of the dropped hosts and domains which has to be
	              placed on the remaining domains
	"""
	# Failure domain name -> number of leading path components identifying it.
	# NOTE(review): 'host' == 4 implies a three component location — confirm.
	fail_domain_map = {
		"room" : 1,
		"row"  : 2,
		"rack" : 3,
		"host" : 4,
		"disk" : 5,
	}

	@staticmethod
	def fail_domain_valid(domain_name):
		"""Returns True if domain_name is valid failure domain name, False otherwise"""
		return domain_name in StorageMap.fail_domain_map

	def __init__(self, cs_infos, fail_domain, tier=None, drop_hosts=None):
		"""
		Create storage map given the CSInfo iterable and failure domain name.

		cs_infos     iterable of CSInfo records
		fail_domain  failure domain name, KeyError is raised if it is unknown
		tier         if not None only chunk servers of this tier are considered
		drop_hosts   collection of host IDs to exclude from the map; the used
		             space of their chunk servers is added to move_bytes
		"""
		if drop_hosts is None:
			drop_hosts = frozenset()
		fd = StorageMap.fail_domain_map[fail_domain]
		self.total_map = defaultdict(int)
		self.avail_map = defaultdict(int)
		self.used_map  = defaultdict(int)
		self.move_bytes = 0
		for cs in cs_infos:
			# The tier filter goes first: chunk servers of other tiers neither
			# provide space here nor hold data that has to be moved here, the
			# same way degrade() only accounts for the space kept in the map.
			if tier is not None and tier != cs.tier:
				continue
			used = cs.space_total - cs.space_avail
			if cs.host_id in drop_hosts:
				if verbose:
					print('dropping', cs.host_id + '.' + cs.id, ':', used, '/', cs.space_total, 'bytes used', file=sys.stderr)
				self.move_bytes += used
				continue
			path = cs.location + (cs.host_id, cs.id)
			domain_path = path[:fd]
			self.total_map[domain_path] += cs.space_total
			self.avail_map[domain_path] += cs.space_avail
			self.used_map [domain_path] += used

	def domain_map(self, total):
		"""Returns domain to space mapping for total or available space"""
		return self.total_map if total else self.avail_map

	def use_space(self, use):
		"""
		Use the given amount of available space.

		If more space is requested than is available, all the domains end up
		with no available space and the remainder is silently ignored.
		"""
		# use space in smallest domains first to simulate a kind of worst case
		for path in sorted(self.avail_map, key=self.avail_map.get):
			avail = self.avail_map[path]
			if avail < use:
				self.avail_map[path] = 0
				use -= avail
			else:
				self.avail_map[path] -= use
				break

	def degrade(self):
		"""
		Simulate dropping the largest failure domain redistributing
		its used space among remaining domains.

		The largest domain is the one with the most used space. Does nothing
		if there are no domains left.
		"""
		if not self.used_map:
			# no domains to drop
			return
		# find largest domain
		path, used = max(self.used_map.items(), key=lambda i: i[1])
		if verbose:
			print('dropping', '.'.join(path), ':', used, 'bytes used', file=sys.stderr)
		# drop it
		del self.total_map[path]
		del self.avail_map[path]
		del self.used_map [path]
		self.move_bytes += used

	def physical_space(self, total):
		"""Returns the sum of total (if total is true) or available space over all domains"""
		return sum(self.domain_map(total).values())

	def virtual_space(self, total, n, k):
		"""
		Returns the amount of virtual space available considering the (n, k) storage scheme.

		total  use total space of the domains if true, available space otherwise
		n, k   storage scheme; the result is k times the space usable per each
		       of the n pieces which have to reside in different domains

		Returns 0 if there are less than n domains. Raises ValueError unless
		n >= 1 and n >= k.
		"""
		if n < 1 or n < k:
			raise ValueError('invalid storage scheme (n=%r, k=%r)' % (n, k))
		space_list = sorted(self.domain_map(total).values())
		domains = len(space_list)
		if domains < n:
			# Not enough failure domains available
			return 0
		space, filled = 0, 0
		# Iterate available failure domains in order of increasing available space
		for avail in space_list:
			if domains <= n:
				# Same as space / filled < avail but exact for any integers
				if filled and space < avail * filled:
					# The next domain has too much space which can't be utilized
					# So we just stop at this point
					break
				filled += 1
			# Accumulate free space
			space += avail
			domains -= 1
		# filled is at least 1 here: the first of the last n domains is always counted
		per_repl = space // filled
		# Consider redistributed space from removed nodes
		use_repl = self.move_bytes // n
		if per_repl > use_repl:
			per_repl -= use_repl
		else:
			per_repl = 0
		# Returns the available space taking redundancy factor into account
		return k * per_repl

def vstorage_exec(cluster_name, cmd, options):
	"""
	Execute specified vstorage command with given list of options.

	Runs 'vstorage -q [-c cluster_name] cmd options...'. Returns command's
	stdout as bytes, or None if the command exited with non-zero status or
	could not be started at all.
	"""
	args = ['vstorage', '-q']
	if cluster_name is not None:
		args += ['-c', cluster_name]
	args.append(cmd)
	args.extend(options)
	try:
		return subprocess.check_output(args)
	except subprocess.CalledProcessError:
		return None
	except OSError as err:
		# The vstorage binary is missing or can't be executed. Report the
		# reason, the callers only know that the command has failed.
		print('failed to run vstorage:', err, file=sys.stderr)
		return None

def get_cluster_stat_xml(cluster_name):
	"""
	Run 'stat -X' for the given cluster and return its output parsed as
	an XML element. Returns None if the command has failed.
	"""
	stat_output = vstorage_exec(cluster_name, 'stat', ['-X'])
	return None if stat_output is None else ET.XML(stat_output)

def get_path_attrs_xml(path):
	"""
	Run 'get-attr -X' for the path to the mounted storage folder and return
	its output parsed as an XML element. Returns None if the command has failed.
	"""
	attrs_output = vstorage_exec(None, 'get-attr', ['-X', path])
	return None if attrs_output is None else ET.XML(attrs_output)

def get_cluster_cs_info(cluster_name):
	"""
	Query and parse cluster stat.

	Returns the list of CSInfo tuples, one per child of the 'cs_list'
	element, or None if the cluster stat can't be obtained.
	"""
	stat = get_cluster_stat_xml(cluster_name)
	if stat is None:
		return None

	def text_of(node, tag):
		# Text of the first child element with the given tag
		return node.find(tag).text

	servers = []
	for cs_node in stat.find('cs_list'):
		host_node = cs_node.find('host_info')
		space_node = cs_node.find('space')
		info = CSInfo(
			id = text_of(cs_node, 'id'),
			host_id = text_of(host_node, 'host_id'),
			location = tuple(text_of(host_node, 'location').split('.')),
			tier = int(text_of(cs_node, 'tier')),
			avail = int(text_of(cs_node, 'available')),
			space_total = int(text_of(space_node, 'total')),
			space_avail = int(text_of(space_node, 'avail'))
		)
		servers.append(info)
	return servers

def get_cluster_stor_info(path):
	"""
	Returns StorageInfo tuple given the path to the mounted storage folder.

	The fields which are not present among the folder attributes are left
	as None. Returns None if the attributes can't be obtained.
	"""
	attrs_xml = get_path_attrs_xml(path)
	if attrs_xml is None:
		return None

	name_node = attrs_xml.find('cluster-name')
	cluster_name = None if name_node is None else name_node.text

	n = k = fail_domain = tier = None
	for attribute in attrs_xml.find('file-attributes').findall('attribute'):
		attr_name = attribute.find('name').text
		if attr_name in ('encoding', 'replicas'):
			n, k = parse_storage_scheme(attribute.find('value').text)
		elif attr_name == 'tier':
			tier = int(attribute.find('value').text)
		elif attr_name == 'failure-domain':
			fail_domain = attribute.find('value').text

	return StorageInfo(
		cluster_name=cluster_name,
		n=n,
		k=k,
		fail_domain=fail_domain,
		tier=tier,
	)

def parse_storage_scheme(storage_scheme):
	"""
	Parse storage scheme specification string. Returns (n, k) tuple.

	Two forms are understood:
	  'R' or 'R:...'     replication, gives (R, 1); the text after ':' is ignored
	  'K+M' or 'K+M/...' encoding, gives (K + M, K); the text after '/' is ignored

	(None, None) is returned if the numbers can't be parsed or don't make
	a usable scheme, that is unless k >= 1 and n >= k.
	"""
	parts = storage_scheme.split('+')
	try:
		if len(parts) < 2:
			n, k = int(parts[0].split(':')[0]), 1
		else:
			k = int(parts[0])
			n = k + int(parts[1].split('/')[0])
	except ValueError:
		# int() is the only thing that may fail above
		return None, None
	if k < 1 or n < k:
		# Zero or negative counts would break the space calculation later on
		return None, None
	return n, k

def get_space_with_suffix(value):
	"""
	Format the amount of space using 1024 based suffixes K, M, G, T, P.

	The value is divided by 1024 until it gets below 1024 or the largest
	suffix is reached. The fraction is truncated to two digits for values
	below 10 and to one digit for values below 100. It is omitted for larger
	values and when its two digits are zero.
	"""
	units = ("", "K", "M", "G", "T", "P")
	last_unit = len(units) - 1
	unit, frac = 0, 0
	while unit < last_unit:
		whole, rest = divmod(value, 1024)
		if whole == 0:
			break
		# Hundredths of the new unit lost by this division
		frac = 100 * rest // 1024
		value = whole
		unit += 1

	suffix = units[unit]
	if not frac or value >= 100:
		return '%d%s' % (value, suffix)
	if value >= 10:
		return '%d.%01d%s' % (value, frac // 10, suffix)
	return '%d.%02d%s' % (value, frac, suffix)

def cli_usage():
	"""Print the command line help to standard error"""
	prog = sys.argv[0]
	help_text = """Cluster storage capacity calculator.
Usage:
  %s path_to_the_mounted_storage_folder [options]
  %s cluster_name failure_domain storage_scheme [-t tier] [options]
where
  cluster_name     the name of the cluster
  failure_domain   one of the room|row|rack|host|disk
  storage_scheme   either number of replicas or stripe+redundancy if encoding is used
Options available:
  -t TIER          only storage servers with matching tier will be considered
  --total          consider total space including already occupied
  --degraded       approximate space left after dropping the largest failure domain,
                   may be used multiple times to simulate dropping multiple domains
  --drop-host ID   approximate space left after dropping the host with particular ID,
                   may be used multiple times to simulate dropping multiple hosts
  -h               print amount of space with human readable suffix like
                   K for kilobytes, M for megabytes, G for gigabytes etc.
  -v               print additional debug information to standard error
""" % (prog, prog)
	print(help_text, file=sys.stderr)

def cli_print_space(stor_info, total, degraded, drop_hosts, human_readable):
	"""
	Command line worker.

	Queries the chunk servers of stor_info.cluster_name, builds the storage
	map, drops the largest failure domain 'degraded' times and prints the
	resulting virtual space to standard output. Returns 0 on success and 1
	if the chunk server information can't be obtained.
	"""
	def debug(*items, **kwargs):
		# Debug output goes to standard error
		print(*items, file=sys.stderr, **kwargs)

	servers = get_cluster_cs_info(stor_info.cluster_name)
	if servers is None:
		print('%s is not a valid cluster name' % stor_info.cluster_name, file=sys.stderr)
		return 1

	storage = StorageMap(servers, stor_info.fail_domain, stor_info.tier, drop_hosts)
	for _ in range(degraded):
		storage.degrade()

	n, k = stor_info.n, stor_info.k
	# Physical space is only calculated for the debug output
	physical = None
	if verbose:
		if stor_info.tier is not None:
			debug('tier [%d] ' % stor_info.tier, end='')
		debug('storage domains available:')
		for domain, space in storage.domain_map(total).items():
			debug('.'.join(domain), ':', space, 'bytes')
		physical = storage.physical_space(total)
		debug(physical, 'bytes of physical storage')
		debug('considering (%d, %d) storage scheme' % (n, k))
		debug('total' if total else 'available', 'virtual space in bytes:')

	virtual = storage.virtual_space(total, n, k)
	if human_readable:
		print(get_space_with_suffix(virtual))
	else:
		print(virtual)

	if verbose and physical:
		# Share of the ideally usable space which can actually be utilized
		debug('%d%% storage utilization' % int(100 * virtual / (physical * k / n)))

	return 0

def cli_avail_space():
	"""
	Command line interface.

	Parses sys.argv, takes the storage parameters either from the attributes
	of the mounted storage folder (single positional argument) or from the
	cluster_name, failure_domain and storage_scheme arguments, and prints the
	resulting space. Returns the process exit status: 0 on success, 1 on
	invalid arguments or failed queries.
	"""
	global verbose

	args = sys.argv[1:]
	if not args or '--help' in args:
		cli_usage()
		return 1

	total = '--total' in args
	if total:
		args.remove('--total')

	degraded = args.count('--degraded')
	for _ in range(degraded):
		args.remove('--degraded')

	drop_hosts = set()
	while '--drop-host' in args:
		i = args.index('--drop-host')
		del args[i]
		if i < len(args):
			# the argument following the option is the host ID
			drop_hosts.add(args[i])
			del args[i]
		else:
			print('missing host ID to drop', file=sys.stderr)
			return 1

	verbose = '-v' in args
	if verbose:
		args.remove('-v')

	human_readable = '-h' in args
	if human_readable:
		args.remove('-h')

	if len(args) == 1:
		# The only argument left is the path to the mounted storage folder
		si = get_cluster_stor_info(args[0])
		if si is None or si.cluster_name is None:
			print('%s is not a valid cluster path' % args[0], file=sys.stderr)
			return 1
		# The attributes may lack some of the parameters, report it here
		# rather than fail deep inside the calculation
		if not StorageMap.fail_domain_valid(si.fail_domain):
			print('%s is not a valid failure domain' % si.fail_domain, file=sys.stderr)
			return 1
		if not si.n or not si.k:
			print('storage scheme is not defined for %s' % args[0], file=sys.stderr)
			return 1
		return cli_print_space(si, total, degraded, drop_hosts, human_readable)

	tier = None
	if '-t' in args:
		i = args.index('-t')
		del args[i]
		try:
			tier = int(args[i])
		except (IndexError, ValueError):
			# the tier value is either missing or not a number
			print('invalid tier parameter', file=sys.stderr)
			return 1
		del args[i]

	if len(args) != 3:
		print('invalid parameters', file=sys.stderr)
		return 1

	cluster_name, fail_domain, storage_scheme = args
	if not StorageMap.fail_domain_valid(fail_domain):
		print('%s is not a valid failure domain' % fail_domain, file=sys.stderr)
		return 1

	n, k = parse_storage_scheme(storage_scheme)
	if not n or not k:
		print('%s is not a valid storage scheme' % storage_scheme, file=sys.stderr)
		return 1

	return cli_print_space(StorageInfo(cluster_name, n, k, fail_domain, tier),
		total, degraded, drop_hosts, human_readable)

# Run the command line interface when executed as a script; the value
# returned by cli_avail_space() becomes the process exit status
if __name__ == '__main__':
	sys.exit(cli_avail_space())
