#!/usr/bin/python3

import re
from Bio import SeqIO
from Bio.SeqUtils.ProtParam import ProteinAnalysis
from shutil import which
import subprocess
import tempfile
import time
import xml.sax
from math import floor


import argparse
import sys
import os
import filecmp

sys.path.append("/usr/local/lib/groupingprotein")
# from utils.UNIMOD import UNIMOD
from utils.proForma import convertPepXMLModToMass

protonH = 1.007825035

## XML handler needs refactoring
# SAX content handler that converts pepXML into the peptide_result XML format
class pepXMLHandler(xml.sax.ContentHandler):
	"""SAX handler converting pepXML search results into peptide_result XML.

	The same instance may be fed several pepXML documents in a row (the
	driver reuses one parser/handler for every input file).  PSMs are
	accumulated per ``msms_run_summary`` base name.  Each time an input
	document ends, the complete ``peptide_result`` document is rewritten to
	the output file.  The file therefore always holds one result covering
	every document parsed so far, instead of repeated, concatenated copies.

	Only the first ``search_hit`` of each ``spectrum_query`` is used (assumed
	to be the top-ranked hit).  A spectrum whose kept hit has an empty or
	missing ``protein`` attribute is dropped.
	"""

	def __init__(self, outputString):
		"""Open (and truncate) the output file.

		:param outputString: path of the peptide_result file to write.
		"""
		xml.sax.ContentHandler.__init__(self)
		# The XML header written below declares utf-8, so encode as utf-8
		# regardless of the platform locale.
		self.outputString = open(outputString, "w", encoding="utf-8")
		# {run base_name: {start_scan: psm dict}}
		self.msnfiles = {}
		self.firstElement = True
		self.spectrum = ""
		# Not populated by the current parser; kept so the attributes exist.
		self.modificationList = {}
		self.amino_mods = {}
		self.nbAssign = 0
		self.mz = 0
		self.badprot = False
		# True once a search_hit has been seen for the current spectrum_query.
		self.hitSeen = False
		# True only while inside the search_hit that is being kept.
		self.inTopHit = False
		self.msrun = None

	def startDocument(self):
		"""Reset per-document state so every input file's root is validated."""
		self.firstElement = True
		self.spectrum = ""
		self.badprot = False
		self.hitSeen = False
		self.inTopHit = False

	@staticmethod
	def _proteinName(attrs):
		"""Return ``"<protein> <protein_descr>"`` with double quotes removed."""
		protein = (attrs.get('protein') or '').replace('"', '')
		descr = (attrs.get('protein_descr') or '').replace('"', '')
		return protein + " " + descr

	def startElement(self, name, attrs):
		"""Collect run, spectrum, top hit, protein and modification data.

		:raises Exception: if a document's root is not msms_pipeline_analysis.
		"""
		if self.firstElement:
			self.firstElement = False
			if name != "msms_pipeline_analysis":
				raise Exception("The input is not a pepXML file")
		elif name == "msms_run_summary":
			self.msrun = attrs.get('base_name')
			# setdefault: a run seen again keeps its earlier spectra
			self.msnfiles.setdefault(self.msrun, {})
			print(self.msrun)
		elif name == "spectrum_query":
			self.spectrum = attrs.get('start_scan')
			self.nbAssign += 1
			self.hitSeen = False
			self.inTopHit = False
			mh = float(attrs.get('precursor_neutral_mass')) + protonH
			self.msnfiles[self.msrun][self.spectrum] = {"scanNum": self.spectrum, "z": attrs.get('assumed_charge'), "mhObs": mh, "mods": [], "proteins": [], "peptide": "", "mhTheo": ""}
		elif name == "search_hit":
			if self.hitSeen:
				# Lower-ranked hits would otherwise overwrite the peptide while
				# appending their proteins to the first hit's list.
				return
			self.hitSeen = True
			self.inTopHit = True
			if (attrs.get('protein') or '') == "":
				self.badprot = True
			psm = self.msnfiles[self.msrun][self.spectrum]
			psm['peptide'] = attrs.get('peptide')
			psm['mhTheo'] = float(attrs.get('calc_neutral_pep_mass')) + protonH
			psm['proteins'].append(self._proteinName(attrs))
		elif name == "alternative_protein":
			if self.inTopHit:
				self.msnfiles[self.msrun][self.spectrum]['proteins'].append(self._proteinName(attrs))
		elif name == "modification_info":
			if self.inTopHit:
				self.msnfiles[self.msrun][self.spectrum]["mods"] = convertPepXMLModToMass(attrs.get("modified_peptide"))

	def endElement(self, name):
		"""Finish runs/spectra/hits; write the result at each document end."""
		if name == "msms_run_summary":
			print(f"Number of asignation : {self.nbAssign}")
			self.nbAssign = 0
		elif name == "search_hit":
			self.inTopHit = False
		elif name == "spectrum_query":
			if self.badprot:
				self.msnfiles[self.msrun].pop(self.spectrum, None)
				self.badprot = False
			self.spectrum = ""
		elif name == "msms_pipeline_analysis":
			self._writeResult()

	def _writeResult(self):
		"""Rewrite the whole output file from every PSM accumulated so far.

		The evalue attributes written here are hard-coded literals.
		"""
		from xml.sax.saxutils import escape

		def attr(value):
			# Escape &, <, > and " so attribute values stay well-formed XML.
			return escape(str(value), {'"': "&quot;"})

		out = self.outputString
		# Replace (not append to) any result written after a previous document.
		out.seek(0)
		out.truncate()
		out.write('<?xml version="1.0" encoding="utf-8" ?>\n<peptide_result>\n<filter evalue="0.05" />\n')
		for msrun, spectra in self.msnfiles.items():
			run = attr(msrun)
			out.write(f'<sample name="{run}" file="{run}.mzML">\n')
			print(f"writing {msrun} result in file")
			for psm in spectra.values():
				out.write(f'<scan num="{attr(psm["scanNum"])}" z="{attr(psm["z"])}" mhObs="{psm["mhObs"]}">\n')
				for protein in psm["proteins"]:
					out.write(f'<psm seq="{attr(psm["peptide"])}" mhTheo="{psm["mhTheo"]}" evalue="0.0001" prot="{attr(protein)}">\n')
					for mod in psm["mods"]:
						out.write(f'<mod aa="{mod["amino_acid"]}" pos="{mod["pos"]}" mod="{mod["mod"]}" />\n')
					out.write('</psm>\n')
				out.write('</scan>\n')
			out.write('</sample>\n')
		out.write('</peptide_result>\n')
		# Make the result durable now; the caller may never close the file.
		out.flush()


# Define the command-line interface
desc = "Process pepXML result to peptide result. \
    The result can then be grouped by gp-grouping program."
command = argparse.ArgumentParser(prog='gp-read-pepXML',
    description=desc, usage='%(prog)s [options] files')
# NOTE(review): ${GP_VERSION} looks like a build-time placeholder — confirm.
command.add_argument('-v', '--version', action='version',
    version='%(prog)s ${GP_VERSION}')
command.add_argument('-o', '--outfile', default="pep.xml",
    help='file to save result no stdout support')
command.add_argument('files', metavar='files', nargs='+',
    help='List of pepXML files to process')

args = command.parse_args()

# One parser and one handler (hence one output file) are shared by every
# input document, so all runs end up in a single peptide_result file.
parser = xml.sax.make_parser()
outputString = args.outfile
handler = pepXMLHandler(outputString)
parser.setContentHandler(handler)
try:
	for input_file in args.files:
		parser.parse(input_file)
finally:
	# Flush and release the output file explicitly rather than relying on
	# interpreter shutdown; also runs if a parse error aborts the loop.
	handler.outputString.close()

#
# for row in dictDiann:
# 	if dictDiann[row]["Stripped.Sequence"] not in uniquePep:
# 		uniquePep.add(dictDiann[row]["Stripped.Sequence"])
# print(f'Number of peptide identified in result : {len(uniquePep)}')
# print(tempfasta.name)
# pep_prot_association = {}
# count = 0
# not_found = []
# print("retrieve accessions for identified peptide")
# accession_found = set()
# start = time.time()
# for seq in uniquePep:
# 	seqLI = seq.replace("L", "I")
# 	pattern = f'(\tM*|[KR]){seqLI}'
# 	pep_prot_association[seq] = []
# 	pepfound = False
# 	search = subprocess.run(['grep', "-P", pattern, tempfasta.name], capture_output=True, text=True)
# 	if len(search.stdout) > 1:
# 		pepfound = True
# 		search = search.stdout.strip("\n").split("\n")
# 		for line in search:
# 			# print(line.split("\t")[0])
# 			pep_prot_association[seq].append(line.split("\t")[0])
# 			accession_found.add(line.split("\t")[0])
# 	if not pepfound:
# 		not_found.append(seq)
# 	count +=1
# 	if count%1000 == 0:
# 		stop = time.time()
# 		sys.stdout.write(f"* {count} peptides searched in {stop-start} seconde\n")
# 		sys.stdout.flush()
# 		start = time.time()
# 	elif count%100 == 0:
# 		sys.stdout.write("*")
# 		sys.stdout.flush()
#
# print(f'Number of peptide not found : {len(not_found)}')
# print(f'Number of redundant accession found : {len(accession_found)}')
#
# samples = {}
# # pep_ids = open("pep_ids.tsv", "w")
# for line in dictDiann:
# 	line = dictDiann[line]
# 	if (float(line["PG.Q.Value"]) < args.PG_Qvalue) and (float(line["Q.Value"]) <QValue) and (float(line["Lib.Q.Value"]) < args.lib_Qvalue):
# 		sample = line["Run"]
# 		parsed_mods = translateDiannUnimodToMassAminoFormat(line["Modified.Sequence"])
# 		# previous_match = 0
# 		# mods = []
# 		# mods_string = []
# 		# for match in re.finditer(r"(\(UniMod\:[0-9]*\))", line["Modified.Sequence"]):
# 		# 	# if match.span()[0] == 1:
# 		# 	# 	amino_acid = line["Modified.Sequence"][match.span()[0]]
# 		# 	# 	amino_acid_number = match.span()[0]+1
# 		# 	# elif match.span()[0] == 0:
# 		# 	if match.span()[0] == 0:
# 		# 		# print(match)
# 		# 		# print(f'{match.group()} : {match.span()[0]} {line["Modified.Sequence"]} {line["Modified.Sequence"][match.span()[1]]}, {match.span()[1]-match.span()[0]}')
# 		# 		amino_acid = line["Modified.Sequence"][match.span()[1]]
# 		# 		amino_acid_number = match.span()[0]+1
# 		# 	else:
# 		# 		amino_acid = line["Modified.Sequence"][match.span()[0]-1]
# 		# 		# print(f'{match.group()} : {match.span()[0]} {line["Modified.Sequence"]} {line["Modified.Sequence"][match.span()[0]-1]}, {match.span()[1]-match.span()[0]}')
# 		# 		amino_acid_number = match.span()[0]-previous_match
# 		# 	translateDiannUnimodToMassAminoFormat()
# 		if args.debug==1:
# 			pep_ids.write(f'{line["Modified.Sequence"]}\t{line["Stripped.Sequence"]}\t{" ".join(parsed_mods["mods_string"])}\t{line["Q.Value"]}\t{line["Lib.Q.Value"]}\t{line["PG.Q.Value"]}\t{line["Quantity.Quality"]}\n')
# 		if extension == ".parquet":
# 			precursorMZ = line["Precursor.Mz"]
# 		else:
# 			precursorMZ = 0
# 		if sample in samples.keys():
# 			samples[sample].append({"proteins.Ids":line["Protein.Ids"], "protein.Names":line["Protein.Names"], "sequence_mods":line["Modified.Sequence"], "sequence":line["Stripped.Sequence"], "charge":line["Precursor.Charge"], "Q.value":line["Q.Value"], "mods":parsed_mods["mods"], "precursorMZ":precursorMZ})
# 		else:
# 			samples[sample] = [{"proteins.Ids":line["Protein.Ids"], "protein.Names":line["Protein.Names"], "sequence_mods":line["Modified.Sequence"], "sequence":line["Stripped.Sequence"], "charge":line["Precursor.Charge"], "Q.value":line["Q.Value"], "mods":parsed_mods["mods"], "precursorMZ":precursorMZ}]
# if args.debug == 1:
# 	pep_ids.close()
#
# outputfile = args.outfile
# outputfile.write('<?xml version="1.0" encoding="utf-8" ?>\n')
# outputfile.write('<peptide_result>\n')
# outputfile.write('<filter evalue="'+str(QValue)+'" />\n')
# for sample in samples:
# 	outputfile.write(f'<sample name="{sample}" file="{sample}">\n')
# 	scanid = 0
# 	modlines = ""
# 	for peptiz in samples[sample]:
# 		scanid = scanid+1
# 	#	print(f"scan {scanid} : {peptiz}")
# 		current_pep = peptiz
# 		prot = pep_prot_association[current_pep["sequence"]]
# 		if current_pep["precursorMZ"] != 0:
# 			pep_mass = current_pep["precursorMZ"]
# 		else:
# 			pep_mass = ProteinAnalysis(current_pep['sequence'], monoisotopic=True).molecular_weight()
# 			if len(current_pep['mods']) != 0:
# 				for mod in current_pep['mods']:
# 					pep_mass += float(mod["mono_isotopique"])
# 			pep_mass += protonH
# 		outputfile.write(f"<scan num=\"{scanid}\" z=\"{current_pep['charge']}\" mhObs=\"{pep_mass}\">\n")
# 		for current_prot in prot:
# 			# r = re.compile(f".*{prot[i]}.*")
# 			# protRealId = list(filter(r.match, accessions.keys()))
# 			# if len(protRealId) != 1:
# 			# 	print(f"ERRROR {prot[i]}")
# 			# 	print(current_pep['sequence'])
# 			# else:
# 			# protRealId = protRealId[0]
# 			if len(current_pep['mods']) == 0:
# 				outputfile.write(f"<psm seq=\"{current_pep['sequence']}\" mhTheo=\"{pep_mass}\" evalue=\"{current_pep['Q.value']}\" prot=\"{accessions[current_prot]['fullDescr']}\"></psm>\n")
# 			else:
# 				outputfile.write(f"<psm seq=\"{current_pep['sequence']}\" mhTheo=\"{pep_mass}\" evalue=\"{current_pep['Q.value']}\" prot=\"{accessions[current_prot]['fullDescr']}\">\n")
# 				for mod in current_pep['mods']:
# 					outputfile.write(f'<mod aa="{mod["amino_acid"]}" pos="{mod["amino_acid_number"]}" mod="{mod["mono_isotopique"]}" />\n')
# 				outputfile.write("</psm>\n")
# 		outputfile.write("</scan>\n")
# 	outputfile.write("</sample>\n")
# outputfile.write("</peptide_result>\n")
# outputfile.close()
# os.remove(tempfasta.name)
