#!/usr/bin/python2

import sys, subprocess
import xml.etree.ElementTree as ET
from collections import namedtuple, defaultdict

# Description of a single chunk server (CS) as parsed from the cluster stat
CSInfo = namedtuple('CSInfo', 'id host_id location tier space_total space_avail')

# Data encoding (n, k) and placement (failure domain, tier) parameters of a storage
StorageInfo = namedtuple('StorageInfo', 'cluster_name n k fail_domain tier')

class StorageMap:
	"""Keeps mapping from failure domain to the amount of storage space available.

	The domain_map attribute maps a failure domain path (a tuple of strings) to
	the sum of the space of all chunk servers located within that domain.
	"""
	# Number of leading components of the chunk server path
	# location + (host_id, id) that identify a failure domain of the given kind.
	# NOTE(review): assumes location has exactly three components
	# (room, row, rack) -- confirm against the cluster stat output.
	fail_domain_map = {
		"room" : 1,
		"row"  : 2,
		"rack" : 3,
		"host" : 4,
		"disk" : 5,
	}

	@staticmethod
	def fail_domain_valid(domain_name):
		"""Returns True if domain_name is valid failure domain name, False otherwise"""
		return domain_name in StorageMap.fail_domain_map

	def __init__(self, cs_infos, fail_domain, tier=None, total=False):
		"""Create storage map given the CSInfo iterable and failure domain name.

		cs_infos    -- iterable of CSInfo tuples
		fail_domain -- one of the fail_domain_map keys, KeyError is raised otherwise
		tier        -- if not None only chunk servers with equal tier are counted
		total       -- if true accumulate space_total instead of space_avail
		"""
		fd = StorageMap.fail_domain_map[fail_domain]
		self.domain_map = defaultdict(int)
		for cs in cs_infos:
			if tier is not None and tier != cs.tier:
				continue
			path = cs.location + (cs.host_id, cs.id)
			# Chunk servers sharing the path prefix belong to the same domain
			domain_path = path[:fd]
			self.domain_map[domain_path] += cs.space_total if total else cs.space_avail

	def physical_space(self):
		"""Returns the sum of the space accumulated over all failure domains"""
		return sum(self.domain_map.values())

	def virtual_space(self, n, k):
		"""Returns the amount of virtual space available considering the (n, k) storage scheme.

		The accumulated physical space is scaled by k / n using integer (floor)
		division. Returns 0 if there are less than n failure domains.
		Requires n >= k.
		"""
		assert n >= k
		# sorted() builds a new list whatever kind of collection values() returns
		avail_list = sorted(self.domain_map.values())
		domains = len(avail_list)
		if domains < n:
			# Not enough failure domains available
			return 0
		total, filled = 0, 0
		# Iterate available failure domains in order of increasing available space
		for avail in avail_list:
			# Only the n largest domains are checked against the stop condition,
			# the smaller ones are accumulated unconditionally
			if domains <= n:
				if filled and total // filled < avail:
					# The next domain has too much space which can't be utilized
					# So we just stop at this point
					# NOTE(review): the remaining domains contribute nothing to
					# the total -- confirm this conservative estimate is intended
					break
				filled += 1
			# Accumulate free space
			total += avail
			domains -= 1
		# Returns the available space taking redundancy factor into account
		return k * total // n

def vstorage_exec(cluster_name, cmd, options):
	"""Run 'vstorage -q [-c cluster_name] cmd' followed by the given list of options.

	The cluster selection option is omitted if cluster_name is None.
	Returns whatever the command has written to its standard output.
	"""
	cluster_opts = [] if cluster_name is None else ['-c', cluster_name]
	command = ['vstorage', '-q'] + cluster_opts + [cmd]
	return subprocess.check_output(command + options)

def get_cluster_stat_xml(cluster_name):
	"""Run the 'stat -X' command for the given cluster and return its output parsed as XML tree"""
	stat_output = vstorage_exec(cluster_name, 'stat', ['-X'])
	return ET.XML(stat_output)

def get_path_attrs_xml(path):
	"""Run the 'get-attr -X' command for the given storage path and return its output parsed as XML tree"""
	attrs_output = vstorage_exec(None, 'get-attr', ['-X', path])
	return ET.XML(attrs_output)

def get_cluster_cs_info(cluster_name):
	"""Generator yielding one CSInfo tuple per child of the 'cs_list' element of the cluster stat"""
	cs_list = get_cluster_stat_xml(cluster_name).find('cs_list')
	for node in cs_list:
		host_node = node.find('host_info')
		space_node = node.find('space')
		cs_id = node.find('id').text
		host_id = host_node.find('host_id').text
		# Dot separated location string becomes a tuple of its components
		location = tuple(host_node.find('location').text.split('.'))
		tier = int(node.find('tier').text)
		space_total = int(space_node.find('total').text)
		space_avail = int(space_node.find('avail').text)
		yield CSInfo(cs_id, host_id, location, tier, space_total, space_avail)

def get_cluster_stor_info(path):
	"""Build StorageInfo tuple from the attributes of the mounted storage folder.

	Fields that are not present in the attributes are left as None.
	"""
	root = get_path_attrs_xml(path)
	cluster_node = root.find('cluster-name')
	cluster_name = cluster_node.text if cluster_node is not None else None
	n = k = fail_domain = tier = None
	for attr in root.find('file-attributes').findall('attribute'):
		attr_name = attr.find('name').text
		if attr_name in ('encoding', 'replicas'):
			n, k = parse_storage_scheme(attr.find('value').text)
		elif attr_name == 'tier':
			tier = int(attr.find('value').text)
		elif attr_name == 'failure-domain':
			fail_domain = attr.find('value').text
	return StorageInfo(cluster_name, n, k, fail_domain, tier)

def parse_storage_scheme(storage_scheme):
	"""Parse storage scheme specification string. Returns (n, k) tuple.

	Two forms of the specification are recognized:
	  'N'   -- yields (N, 1); anything following ':' is ignored
	  'K+M' -- yields (K + M, K); anything following '/' in M is ignored
	Returns (None, None) if a numeric field can't be converted to integer.
	"""
	scheme = storage_scheme.split('+')
	try:
		if len(scheme) < 2:
			n, k = int(scheme[0].split(':')[0]), 1
		else:
			k = int(scheme[0])
			n = k + int(scheme[1].split('/')[0])
		return n, k
	except ValueError:
		# int() is the only call that may fail here; don't swallow anything
		# else (e.g. KeyboardInterrupt) as the bare except used to do
		return None, None

def cli_usage():
	"""Print the command line usage help to standard error"""
	prog = sys.argv[0]
	usage_text = """Cluster storage capacity calculator.
Usage:
  %s path_to_the_mounted_storage_folder [--total] [-v]
  %s cluster_name path_to_the_mounted_storage_folder [--total] [-v]
  %s cluster_name failure_domain storage_scheme [-t tier] [--total] [-v]
Where failure_domain is one of
  room row rack host disk
storage_scheme is either number of replicas or stripe+redundancy if encoding
is used. Prints the available storage space in bytes to standard output.
If tier parameter is specified only storage servers with matching tier will
be considered. If --total option is given the result will include
already occupied space and space reserved by the system as safety margin.
Prints additional debug information to standard error if -v option is given.
""" % (prog, prog, prog)
	print >> sys.stderr, usage_text

def cli_print_space(stor_info, total, verbose):
	"""Command line worker"""
	stor = StorageMap(get_cluster_cs_info(stor_info.cluster_name), stor_info.fail_domain, stor_info.tier, total)
	if verbose:
		if stor_info.tier is not None:
			print >> sys.stderr, 'tier [%d]' % stor_info.tier,
		print >> sys.stderr, 'storage domains available:'
		for path, avail in stor.domain_map.items():
			print >> sys.stderr, '.'.join(path), ':', avail, 'bytes'
		phy_space = stor.physical_space()
		print >> sys.stderr, phy_space, 'bytes of physical storage'
		print >> sys.stderr, 'considering (%d, %d) storage scheme' % (stor_info.n, stor_info.k)
		print >> sys.stderr, 'total' if total else 'available', 'virtual space in bytes:'

	virt_space = stor.virtual_space(stor_info.n, stor_info.k)
	print virt_space

	if verbose and phy_space:
		utilization = (50 + 100 * virt_space) / (phy_space * stor_info.k / stor_info.n)
		print >> sys.stderr, '%d%% storage utilization' % utilization

	return 0

def cli_avail_space():
	"""Command line interface.

	Parses sys.argv (the accepted forms are described by cli_usage), prints
	the calculated space and returns the exit status: 0 on success,
	1 on invalid arguments.
	"""
	args = sys.argv[1:]

	# Flags may appear anywhere; strip them to leave positional arguments only
	total = '--total' in args
	if total:
		args.remove('--total')

	verbose = '-v' in args
	if verbose:
		args.remove('-v')

	if len(args) == 1:
		# The only argument is the path, the cluster name comes from its attributes
		si = get_cluster_stor_info(args[0])
		if si.cluster_name is None:
			print >> sys.stderr, 'cluster name parameter is required'
			return 1
		return cli_print_space(si, total, verbose)

	if len(args) == 2:
		# Explicit cluster name overrides the one from the path attributes
		cluster_name, path = args
		si = get_cluster_stor_info(path)
		return cli_print_space(StorageInfo(cluster_name, si.n, si.k, si.fail_domain, si.tier), total, verbose)

	if '-t' in args:
		i = args.index('-t')
		del args[i]
		try:
			# After the deletion index i refers to the value following -t
			tier = int(args[i])
			del args[i]
		except (IndexError, ValueError):
			# Either -t was the last argument or its value is not a number
			print >> sys.stderr, 'invalid tier parameter'
			return 1
	else:
		tier = None

	if len(args) != 3:
		cli_usage()
		return 1

	cluster_name, fail_domain, storage_scheme = args
	if not StorageMap.fail_domain_valid(fail_domain):
		print >> sys.stderr, '%s is not a valid failure domain' % fail_domain
		return 1

	n, k = parse_storage_scheme(storage_scheme)
	if not n or not k:
		print >> sys.stderr, '%s is not a valid storage scheme' % storage_scheme
		return 1

	return cli_print_space(StorageInfo(cluster_name, n, k, fail_domain, tier), total, verbose)

if __name__ == '__main__':
	# Pass the status returned by the command line interface to the caller
	exit_status = cli_avail_space()
	sys.exit(exit_status)
