from PyQt4 import QtGui
from PyQt4 import QtCore
from mzxml_scan import mzxml_scan
import sqlite3
import numpy
import bz2
import math


class MsData(object):
	"""In-memory SQLite store for MS scan metadata and sample run order.

	Two tables are created when an instance is built:

	* ``msdata``  -- one row per scan: sample_name, num, msLevel, rt,
	  basepeakintensity, tic and the four precursor_* columns.
	* ``msorder`` -- one row per sample: sample_name, sample_order, method,
	  well, current_location.

	The connection is opened with ``check_same_thread=False``; this class
	does no locking of its own.

	``parent`` is only used by :meth:`dumpDB`, which calls
	``parent.emit(QtCore.SIGNAL('count_loading'))`` to report progress.
	"""
	db = None

	def __init__(self, parent):
		"""Create the empty in-memory database and its two tables."""
		self.parent = parent
		self.db = sqlite3.connect(':memory:', check_same_thread=False)
		c = self.db.cursor()
		try:
			c.execute('''CREATE TABLE msdata (sample_name text, num integer, msLevel integer, rt real, basepeakintensity real, tic real, precursor_scan_num integer, precursor_mz real, precursor_intensity real, precursor_charge integer)''')
			c.execute('''CREATE TABLE msorder (sample_name text, sample_order integer, method text, well text, current_location text)''')
			self.db.commit()
		finally:
			c.close()

	#### private helpers

	@staticmethod
	def _as_text(value):
		"""Return ``value`` as text (``unicode`` on Python 2, ``str`` on 3).

		Byte strings are decoded as UTF-8 so that names are stored and
		looked up in the same form.
		"""
		if isinstance(value, bytes):
			return value.decode('utf-8')
		try:
			return unicode(value)
		except NameError:
			# Python 3: ``unicode`` does not exist, ``str`` is already text.
			return str(value)

	def _query(self, sql, params=()):
		"""Run a SELECT and return every row; the cursor is always closed."""
		c = self.db.cursor()
		try:
			c.execute(sql, params)
			return c.fetchall()
		finally:
			c.close()

	@staticmethod
	def _summarize(values):
		"""Return min/max/avg/med of ``values`` as a dict.

		All four entries are ``None`` when ``values`` is empty.
		"""
		values = list(values)
		if not values:
			return {'min': None, 'max': None, 'avg': None, 'med': None}
		return {
			'min': min(values),
			'max': max(values),
			'avg': sum(values) / float(len(values)),
			'med': numpy.median(numpy.array(values)),
		}

	#### methods that populate the tables

	def add_scan(self,  scans):
		"""Insert one ``msdata`` row per scan object in ``scans``.

		Each scan must provide the ``get_*`` accessors used below. Values are
		bound as SQL parameters, so quotes in names and ``None`` values
		(stored as NULL) are handled correctly.
		"""
		rows = [(
			self._as_text(scan.get_sample_name()),
			scan.get_num(),
			scan.get_mslevel(),
			scan.get_rt(),
			scan.get_basepeakintensity(),
			scan.get_tic(),
			scan.get_precursor_scan_num(),
			scan.get_precursor_mz(),
			scan.get_precursor_intensity(),
			scan.get_precursor_charge(),
		) for scan in scans]
		c = self.db.cursor()
		try:
			c.executemany("INSERT INTO msdata VALUES (?,?,?,?,?,?,?,?,?,?)", rows)
			self.db.commit()
		finally:
			c.close()

	def add_sample_order(self, datas):
		"""Insert sample-order rows into ``msorder``.

		Each item of ``datas`` may be:

		* a dict with keys ``DataFilename``, ``Order``, ``MethodName``,
		  ``PositionCell`` and ``currentLocation``;
		* a 4-item sequence ``(sample_name, order, method, well)``;
		* a 2-item sequence ``(sample_name, order)``.

		Missing columns are stored as empty strings. Items of any other
		shape are ignored. The order is bound as text and converted by the
		column's INTEGER affinity.
		"""
		text = self._as_text
		rows = []
		for data in datas:
			if isinstance(data, dict):
				rows.append((
					text(data['DataFilename']),
					text(data['Order']),
					text(data['MethodName']),
					text(data['PositionCell']),
					text(data['currentLocation']),
				))
			elif len(data) == 4:
				# The table has five columns; current_location is unknown here.
				rows.append((text(data[0]), text(data[1]), text(data[2]), text(data[3]), ''))
			elif len(data) == 2:
				rows.append((text(data[0]), text(data[1]), '', '', ''))
		c = self.db.cursor()
		try:
			c.executemany("INSERT INTO msorder VALUES (?,?,?,?,?)", rows)
			self.db.commit()
		finally:
			c.close()

	#### query methods

	def get_nb_ms_for_level(self, level):
		"""Return ``{sample_name: scan count}`` for the given MS level."""
		rows = self._query("select sample_name, count(*) from msdata where mslevel = ? group by sample_name", (level,))
		return dict((row[0], row[1]) for row in rows)

	def get_nb_ms_for_level_sample(self, level, sample):
		"""Return the number of scans of ``level`` for ``sample`` (0 if none)."""
		rows = self._query("select sample_name, count(*) from msdata where mslevel = ? and sample_name = ? group by sample_name", (int(level), self._as_text(sample)))
		return rows[-1][1] if rows else 0

	def get_mean_intensity_ms_for_level(self, level):
		"""Return ``{sample_name: mean basepeakintensity}`` for ``level``."""
		rows = self._query("select sample_name, avg(basepeakintensity) from msdata where mslevel = ? group by sample_name", (level,))
		return dict((row[0], row[1]) for row in rows)

	def get_mean_intensity_ms_for_level_for_sample(self, level, sample):
		"""Return the mean basepeakintensity of ``sample`` at ``level`` (0 if no scan)."""
		rows = self._query("select sample_name, avg(basepeakintensity) from msdata where mslevel = ? and sample_name = ?  group by sample_name", (level, self._as_text(sample)))
		return rows[-1][1] if rows else 0

	def get_mean_tic_ms_for_level_for_sample(self, level, sample):
		"""Return the mean tic of ``sample`` at ``level`` (0 if no scan)."""
		rows = self._query("select sample_name, avg(tic) from msdata where mslevel = ? and sample_name = ?  group by sample_name", (level, self._as_text(sample)))
		return rows[-1][1] if rows else 0

	def get_nb_precursor_charges_by_charge_for_level(self, level):
		"""Return ``{sample_name: {precursor_charge: count}}`` for ``level``."""
		rows = self._query("select sample_name, precursor_charge , count(precursor_charge) from msdata where mslevel = ? group by sample_name, precursor_charge", (level,))
		result = {}
		for sample_name, charge, count in rows:
			result.setdefault(sample_name, {})[charge] = count
		return result

	def get_nb_ms2_by_ms_for_sample(self, sample):
		"""Return ``{precursor_scan_num: number of MS2 scans}`` for ``sample``."""
		rows = self._query("select precursor_scan_num, count(precursor_scan_num) from msdata where mslevel =2 and sample_name = ? group by precursor_scan_num order by precursor_scan_num", (self._as_text(sample),))
		return dict((row[0], row[1]) for row in rows)

	def get_nb_ms2_by_ms_for_sample_ordered(self, sample):
		"""Return ``[x, y]`` lists of MS2 counts per precursor scan.

		``x`` holds ``rt / 60`` (presumably seconds converted to minutes --
		confirm against the scan source) and ``y`` the MS2 count, both
		ordered by precursor_scan_num. ``rt`` is a bare column in a grouped
		query, so it comes from one MS2 row of each group.
		"""
		rows = self._query("select precursor_scan_num, rt, count(precursor_scan_num) from msdata where mslevel =2 and sample_name = ? group by precursor_scan_num order by precursor_scan_num", (self._as_text(sample),))
		resultx = [float(row[1]) / float(60) for row in rows]
		resulty = [row[2] for row in rows]
		return [resultx, resulty]

	def get_cycle_duration_for_sample(self, sample):
		"""Return ``{MS1 scan num: rt gap to the previous MS1 scan}``.

		Only MS1 scans whose number lies strictly between the smallest and
		largest precursor_scan_num of the sample's MS2 scans are considered,
		in ascending rt order. The first gap is measured from 0.0, and gaps
		of 10 or more are dropped. Returns an empty dict when the sample has
		no MS2 scan.
		"""
		name = self._as_text(sample)
		bounds = self._query("select min(precursor_scan_num), max(precursor_scan_num) from msdata where sample_name = ? and mslevel = 2", (name,))
		min_scan, max_scan = bounds[0]
		if min_scan is None:
			return {}
		# SQL has no chained comparison: "? < num < ?" is evaluated as
		# "(? < num) < ?", so both bounds must be written out explicitly.
		rows = self._query("select num, rt from msdata where sample_name = ? and mslevel = 1 and num > ? and num < ? order by rt asc", (name, min_scan, max_scan))
		result = {}
		previous_rt = 0.0
		for num, rt in rows:
			duration = rt - previous_rt
			if duration < 10:
				result[num] = duration
			previous_rt = rt
		return result

	def get_sample_list(self):
		"""Return the distinct sample names present in ``msdata``."""
		rows = self._query("select sample_name from msdata group by sample_name")
		return [row[0] for row in rows]

	def getFilesLocation(self):
		"""Return the current_location of every ``msorder`` row."""
		rows = self._query("select current_location from msorder")
		return [row[0] for row in rows]

	def get_cycle_information_by_sample(self):
		"""Return per-sample statistics on MS2 counts and cycle durations.

		The result maps each sample name to
		``{'count': stats, 'time': stats}`` where ``stats`` has the keys
		``min``, ``max``, ``avg`` and ``med``. ``count`` summarises the
		number of MS2 scans per precursor scan, ``time`` the values of
		:meth:`get_cycle_duration_for_sample`. All entries of a ``stats``
		dict are ``None`` when there is nothing to summarise.
		"""
		result = {}
		for sample in self.get_sample_list():
			ms2_by_ms = self.get_nb_ms2_by_ms_for_sample(sample)
			cycle_duration = self.get_cycle_duration_for_sample(sample)
			result[sample] = {
				'count': self._summarize(ms2_by_ms.values()),
				'time': self._summarize(cycle_duration.values()),
			}
		return result

	def get_intensity_for_sample_by_level(self, sample, level):
		"""Return ``[x, y]`` lists for the scans of ``sample`` at ``level``.

		``x`` holds ``rt / 60`` and ``y`` the basepeakintensity, ordered by
		scan number.
		"""
		rows = self._query("select num, basepeakintensity, rt from msdata where sample_name = ? and mslevel = ? order by num asc", (self._as_text(sample), level))
		resultx = [float(row[2]) / float(60) for row in rows]
		resulty = [row[1] for row in rows]
		return [resultx, resulty]

	def get_order_number_for_sample(self, sample):
		"""Return the sample_order recorded for ``sample`` or ``None``."""
		rows = self._query("select sample_order from msorder where sample_name = ?", (self._as_text(sample),))
		return rows[-1][0] if rows else None

	def get_msdata_information(self, sample):
		"""Return a summary dict for ``sample``.

		Keys: ``samplename``, ``nb_ms1``, ``nb_ms2``, ``intensity_ms1``,
		``intensity_ms2``, ``tic_ms1``, ``tic_ms2``, ``order`` and
		``max_cycle``. ``max_cycle`` is ``None`` when no cycle duration
		could be computed for the sample.
		"""
		cycle_duration = self.get_cycle_duration_for_sample(sample)
		max_cycle = max(cycle_duration.values()) if cycle_duration else None
		return {
			'samplename': sample,
			'nb_ms1': self.get_nb_ms_for_level_sample(1, sample),
			'nb_ms2': self.get_nb_ms_for_level_sample(2, sample),
			'intensity_ms1': self.get_mean_intensity_ms_for_level_for_sample(1, sample),
			'intensity_ms2': self.get_mean_intensity_ms_for_level_for_sample(2, sample),
			'tic_ms1': self.get_mean_tic_ms_for_level_for_sample(1, sample),
			'tic_ms2': self.get_mean_tic_ms_for_level_for_sample(2, sample),
			'order': self.get_order_number_for_sample(sample),
			'max_cycle': max_cycle,
		}

	def get_nb_precursor_charges_by_charge_for_level_for_sample(self, level, sample):
		"""Return ``{precursor_charge: count}`` for ``sample`` at ``level``."""
		rows = self._query("select precursor_charge , count(precursor_charge) from msdata where mslevel = ? and sample_name = ? group by precursor_charge", (level, self._as_text(sample)))
		return dict((row[0], row[1]) for row in rows)

	def get_precursor_ms_for_sample(self, sample):
		"""Return the precursor_mz of every MS2 scan of ``sample``."""
		rows = self._query("select precursor_mz from msdata where mslevel = 2 and sample_name = ? ", (self._as_text(sample),))
		return [row[0] for row in rows]

	def get_precursor_intensity_for_sample(self, sample):
		"""Return log10 of the precursor_intensity of the MS2 scans of ``sample``.

		Scans whose intensity is NULL, zero or negative are skipped because
		their logarithm is undefined.
		"""
		rows = self._query("select precursor_intensity from msdata where mslevel = 2 and sample_name = ? ", (self._as_text(sample),))
		return [math.log10(float(row[0])) for row in rows if row[0] is not None and row[0] > 0]

	#### persistence

	def dumpDB(self, filename):
		"""Write the database as a bz2-compressed, UTF-8 encoded SQL script.

		``parent`` is sent the ``count_loading`` signal every 100 statements.
		"""
		f = bz2.BZ2File(filename, 'wb', compresslevel = 1)
		try:
			for count, line in enumerate(self.db.iterdump(), 1):
				# BZ2File is a binary stream: encode explicitly.
				f.write(('%s\n' % line).encode('utf-8'))
				if count % 100 == 0:
					self.parent.emit(QtCore.SIGNAL('count_loading'))
		finally:
			f.close()

	def loadDB(self, filename):
		"""Replace the database with the content of a file written by dumpDB.

		The script is loaded into a fresh in-memory connection which only
		replaces ``self.db`` once loading has succeeded.
		"""
		f = bz2.BZ2File(filename, 'rb')
		try:
			sql = f.read()
		finally:
			f.close()
		db = sqlite3.connect(':memory:', check_same_thread=False)
		c = db.cursor()
		try:
			c.executescript(sql.decode('utf-8'))
			db.commit()
		finally:
			c.close()
		self.db = db

	def nbEntry(self):
		"""Return the number of rows in ``msdata``."""
		return self._query("select count(*) from msdata")[0][0]

	def cleandb(self):
		"""Drop the ``msdata`` table (``msorder`` is left untouched)."""
		c = self.db.cursor()
		try:
			c.executescript('drop table if exists msdata')
			self.db.commit()
		finally:
			c.close()

	def getProjectSummary(self):
		"""Return ``{sample_name: get_msdata_information(sample_name)}``."""
		result = {}
		for sample in self.get_sample_list():
			result[sample] = self.get_msdata_information(sample)
		return result

	def test_ordered_info(self):
		"""Return the number of ``msorder`` rows with a sample_order (0 if none)."""
		return self._query("select count(sample_order) from msorder")[0][0]
		
		
		