#!/usr/bin/python3


## TODO: when the same peptide exists with a modification at switched positions, both forms are grouped under the same pepID, which makes it hard to retrieve the two versions from the DIA-NN output. Likewise, when a peptide has an I/L ambiguity, both sequences share the same id; in that case a single sequence should be output.

## Export the protein inference for a DIA-NN result.
## Generates peptide and protein files compatible with mcq.read.masschroq

import xml.sax

import argparse
import sys
import os
import time
import json
import pandas as pd
import re

sys.path.append("/usr/local/lib/groupingprotein")  # noqa: E402
# from utils.UNIMOD import UNIMOD
from diannReader.tsvReader import tsvReader
from diannReader.parquetReader import parquetReader
from utils.proForma import proFormaEncode2, mergedModificationString, proFormaEncode


# Simple function to generate msrun ids like i2masschroq does.
# msrunlist : the list of msrun file names
# return : a dictionary with the msrun file name as key and the msrun id as value

def createMSRunID(msrunlist):
    """Assign an msrun identifier to every msrun file name.

    An identifier is ``msrun`` + one letter + the 1-based position of the name
    in *msrunlist*. The letter encodes how many digits the position has:
    ``a`` for 1-9, ``b`` for 10-99, ``c`` for 100-999, ``d`` for 1000-9999 and
    ``e`` from 10000 upwards.

    msrunlist: iterable of msrun file names.
    Returns a dict mapping each file name to its identifier; a name that is
    listed several times keeps the identifier of its last occurrence.
    """
    digitLetters = "abcde"
    idByFile = {}
    for position, fileName in enumerate(msrunlist, start=1):
        # Clamp so that positions with more than five digits still get "e".
        letterRank = min(len(str(position)), len(digitLetters)) - 1
        idByFile[fileName] = f"msrun{digitLetters[letterRank]}{position}"
    return idByFile


def convertDictToLine(line):
    """Serialise one DIA-NN record as a tab-separated line (no newline).

    line: mapping that provides every column listed below.
    Returns the column values, formatted as text, in this fixed order.
    """
    exportedColumns = (
        "Run",
        "Modified.Sequence",
        "Stripped.Sequence",
        "Precursor.Charge",
        "Precursor.Quantity",
        "RT",
        "Quantity.Quality",
        "PG.Q.Value",
        "Protein.Group",
        "Precursor.Mz",
    )
    return "\t".join(f"{line[column]}" for column in exportedColumns)


# TODO: the XML handler needs refactoring
# SAX handler class definition
class OutputProteinListHandler(xml.sax.ContentHandler):
    """SAX handler exporting the peptide/protein table of a grouping result.

    While the grouping XML is parsed, one line per (peptide, subgroup) pair is
    written to ``protein_mcqr.tsv`` and one entry per peptide is appended to
    the lists of the caller-supplied ``proFormaDict``.
    """

    def __init__(self, outfile, proFormaDict):
        """Open the output table and prepare the collected columns.

        outfile: directory in which ``protein_mcqr.tsv`` is created.
        proFormaDict: dict filled in place; its keys ``proForma_no_loc``,
            ``proForma``, ``pepID``, ``sequence`` and ``mhTheo`` are reset to
            empty lists and receive one value per peptide.
        """
        super().__init__()
        self.outfile = open(os.path.join(outfile, "protein_mcqr.tsv"), "w")
        self.prots = {}
        self.protIds = {}
        self.seqs = {}
        self.outfile.write("peptide\tproForma\tprotein\tprotein_description\n")
        self.firstElement = True
        self.seq = ''
        self.proFormaDict = proFormaDict
        for column in ("proForma_no_loc", "proForma", "pepID", "sequence", "mhTheo"):
            self.proFormaDict[column] = []

    def startElement(self, name, attrs):
        """Collect proteins, subgroups, peptides and their modified sequences.

        Raises Exception when the root element is not ``grouping_protein``.
        """
        if self.firstElement:
            self.firstElement = False
            if name != "grouping_protein":
                raise Exception("The input is not a gp result file")
        elif name == "protein":
            self.protIds[attrs.get('id')] = attrs.get('desc')
        elif name == "group":
            # A new group starts: forget everything from the previous one.
            self.groupId = attrs.get('id')
            self.subgroup = []
            self.seqs.clear()
            self.prots.clear()
            self.seq = ""
        elif name == "subgroup":
            sub = attrs.get('id')
            self.subgroup.append(sub)
            self.prots[sub] = str(attrs.get('protIds')).split()
        elif name == "peptide":
            self.seq = attrs.get('id')
            self.subs = attrs.get('subgroupIds').split(" ")
            self.specific = len(self.subs) == 1
            self.seqs[self.seq] = {"id": attrs.get('id'),
                                   "subgroups": self.subs,
                                   "mods": [],
                                   "seqLI": attrs.get("seqLI"),
                                   "seq": "",
                                   "mhTheo": attrs.get("mhTheo")}
        elif name == "sequence_modifs":
            sequence = attrs.get("seq")
            # A missing "mods" attribute is handled like an empty one.
            rawMods = attrs.get("mods") or ""
            self.seqs[self.seq]["seq"] = sequence
            if len(rawMods) > 3:
                roundedMods = []
                # split() ignores repeated/leading/trailing blanks, which
                # would otherwise produce empty tokens.
                for mod in rawMods.split():
                    fields = mod.split(":")
                    roundedMods.append(f'{fields[0]}:{round(float(fields[1]), 2)}')
                modString = " ".join(roundedMods)
                self.seqs[self.seq]["mods"].append(
                    {"sequence": sequence,
                     "modstring": modString,
                     "proForma": proFormaEncode2(sequence, modString)})
            else:
                self.seqs[self.seq]["mods"].append(
                    {"sequence": sequence,
                     "modstring": "",
                     "proForma": sequence})

    def endElement(self, name):
        """Flush the peptides of a group when its closing tag is reached."""
        if name == "group":
            self.printSubGroup()

    def endDocument(self):
        """Close ``protein_mcqr.tsv`` once the whole document has been read."""
        self.outfile.close()

    def printSubGroup(self):
        """Write the current group's peptides and record them in proFormaDict.

        Each peptide yields one output line per subgroup it belongs to and a
        single entry in every ``proFormaDict`` list.
        """
        for pepInfo in self.seqs.values():
            # These strings do not depend on the subgroup, so they are built
            # once per peptide.
            proForma_string = pepInfo["seq"]
            proForma_key = re.sub(r'[IL]', 'J', pepInfo["seq"])
            variants = pepInfo["mods"]
            if variants and variants[0]["modstring"] != "":
                mergeMods = mergedModificationString(variants)
                # I and L are both replaced by J in the position-free key.
                proForma_key = re.sub(
                    r'[IL]', 'J', proFormaEncode(mergeMods, no_pose=True))
                proForma_string = proFormaEncode(mergeMods)
            for sub in pepInfo["subgroups"]:
                # The protein reported for a subgroup is "<subgroup id>.a1"
                # (presumably the first protein of the subgroup - confirm).
                proteinId = f'{sub}.a1'
                self.outfile.write("\t".join(
                    (pepInfo["id"], proForma_string,
                     proteinId, self.protIds[proteinId])) + "\n")
            self.proFormaDict["proForma_no_loc"].append(proForma_key)
            self.proFormaDict["proForma"].append(proForma_string)
            self.proFormaDict["pepID"].append(pepInfo["id"])
            self.proFormaDict["sequence"].append(variants[0]["sequence"])
            self.proFormaDict["mhTheo"].append(pepInfo["mhTheo"])


# DIA-NN report columns requested from the readers: this list is passed as
# first argument to both parquetRead() and tsvRead() in the main part.
ParquetcolumnName = ["Run", "Run.Index",
                     "Modified.Sequence",
                     "Stripped.Sequence",
                     "Precursor.Charge",
                     "Precursor.Mz",
                     "Precursor.Quantity",
                     "RT",
                     "Quantity.Quality",
                     "PG.Q.Value",
                     "Protein.Group"]

# Same columns without "Run.Index" and "Precursor.Mz". Only referenced by the
# commented-out legacy TSV code ("Old version of diann"); unused by the
# active code.
oldTsvColumnName = ["Run",
                    "Modified.Sequence",
                    "Stripped.Sequence",
                    "Precursor.Charge",
                    "Precursor.Quantity",
                    "RT",
                    "Quantity.Quality",
                    "PG.Q.Value",
                    "Protein.Group"]


# Header of peptide_mcqr.tsv. Only referenced by commented-out legacy code;
# the active code obtains the same names, in the same order, by selecting and
# renaming DataFrame columns.
peptide_col_names = ["group",
                     "msrun",
                     "msrunfile",
                     "peptide",
                     "mz",
                     "rt",
                     "maxintensity",
                     "area",
                     "rtbegin",
                     "rtend",
                     "label",
                     "sequence",
                     "z",
                     "mods"]

# Mass added per extra charge when the precursor m/z is derived from mhTheo.
# NOTE(review): 1.007825035 is the monoisotopic mass of a hydrogen atom; a bare
# proton weighs about 1.007276 Da - confirm which one is intended here.
protonH = 1.007825035


# main part

# Command-line interface definition.
command = argparse.ArgumentParser(
    prog='gp-output-protein-list',
    description='Export protein list on gp-grouping result.',
    usage='%(prog)s [options]')
command.add_argument(
    '-i', '--infile', nargs="?",
    type=argparse.FileType("r"), default=sys.stdin,
    help='Open gp-grouping result default:STDIN')
command.add_argument(
    '-o', '--outputDirectory',
    default=".",
    help='Export protein and peptide file for mcqr. Default value is current directory')
command.add_argument(
    '-f', '--diannFile', nargs="?",
    help='Diann result files (report.tsv or report.parquet)')
command.add_argument(
    '-q', '--quantityQuality', nargs="?",
    default=0.5, type=float,
    help='Quantity.quality threshold default value 0.5')
# The help text is rendered from the real default so the two cannot drift
# apart (it used to announce 0.5 while the default is 0.01).
command.add_argument(
    '-p', '--PGQValue', nargs="?",
    default=0.01, type=float,
    help='PG.Q.Value threshold. Default value %(default)s')

# ${GP_VERSION} is a literal placeholder in this source; it is presumably
# substituted when the tool is packaged - confirm.
command.add_argument(
    '-v', '--version', action='version',
    version='%(prog)s ${GP_VERSION}')

# Read the command-line arguments.
args = command.parse_args()


# Parse the grouping result with the SAX reader.
print("Parsing grouping result file")
start = time.time()
parser = xml.sax.make_parser()
proFormaDict = dict()
# Keep a reference to the handler: it opens protein_mcqr.tsv in its
# constructor and that file has to be closed once parsing is over.
proteinListHandler = OutputProteinListHandler(args.outputDirectory, proFormaDict)
parser.setContentHandler(proteinListHandler)
try:
    parser.parse(args.infile)
finally:
    # Closing also flushes the table to disk, even when parsing fails.
    # close() is a no-op on an already closed file.
    proteinListHandler.outfile.close()
stop = time.time()
sys.stdout.write(f"Parse Grouping file in {round(stop-start, 2)} seconde\n")
sys.stdout.flush()
# Control table of the collected peptides, written to the working directory.
proFormaDictDF = pd.DataFrame(proFormaDict)
proFormaDictDF.to_csv("control_peptideID.tsv", sep="\t", encoding='utf-8', index=False, header=True)

# Read the DIA-NN report (parquet or tsv), filter it and merge peptidoforms.
if not args.diannFile:
    # Without this check os.path.splitext(None) fails with an obscure TypeError.
    command.error("a DIA-NN report is required (-f/--diannFile)")
start = time.time()
print("Loading DIA-NN file form directory :")
print(args.diannFile)
extension = os.path.splitext(args.diannFile)[-1].lower()
msrunfile = []
# These two debug files are only written by the commented-out legacy code.
# They are still created (empty) as before, but closed right away so that no
# file handle stays open for the rest of the run.
removedline = open("non_infered_pep.tsv", "w")
removedline.close()
nonInferedLine = 0
controlDict = open("controldict.tsv", "w")
controlDict.close()
# Both readers receive the sequences collected from the grouping result
# (presumably to keep only the inferred peptides - confirm in the readers).
if extension == ".parquet":
    diannReader = parquetReader(args.diannFile)
    diannReader.parquetRead(ParquetcolumnName, proFormaDict["sequence"])
else:
    diannReader = tsvReader(args.diannFile)
    print(f"Diann version {diannReader.getDiannVersion()}")
    diannReader.tsvRead(ParquetcolumnName, proFormaDict["sequence"])
msrunfile = diannReader.getMsrunfileNames()
temp = diannReader.getDF()
temp.to_csv("peptide_before_filtering.tsv", sep="\t", encoding='utf-8', index=False, header=True)
print("Remove non valide quantification")
diannReader.filteringResult({"PG_Q_Value": args.PGQValue, "Quantity_Quality": args.quantityQuality})
print("Merge peptidoform quantifications")
diannReader.mergePeptidoform()
result = diannReader.getDF()
# Write the frame fetched AFTER filtering and merging (the control file used
# to be written from the frame fetched before filtering).
result.to_csv("peptide_after_filtering.tsv", sep="\t", encoding='utf-8', index=False, header=True)
    
    # inputfile = open(args.diannFile, "r")
    # headers= inputfile.readline().strip().split("\t")
    # msrunfiles = []
    # if "Run.Index" in headers:
    #     cols_index = {}
    #     for column in ParquetcolumnName:
    #         cols_index[column] = headers.index(column)
    #     dictDiann = {}
    #     rownum = 1
    #     start = time.time()
    #     for line in inputfile:
    #         if "Modified.Sequence" not in line:
    #             ## removing line if is not in grouping result (conta or not infered)
    #             if line["Stripped.Sequence"] in seqlist:
    #                 tabline = line.strip().split("\t")
    #                 dictDiann[str(rownum)] = {}
    #                 for column in cols_index.keys():proFormaDict
    #                     dictDiann[str(rownum)][column] = tabline[cols_index[column]]
    #                 msrunfile.append(dictDiann[str(rownum)]["Run"])
    #                 proForma = proFormaTranslateUnimodDiannToMass(dictDiann[str(rownum)]["Modified.Sequence"])
    #                 dictDiann[str(rownum)]["proForma_key"] = proForma_merged[proForma]
    #                 controlDict.write(f'{dictDiann[str(rownum)]["Modified.Sequence"]}\t{dictDiann[str(rownum)]["Stripped.Sequence"]}\t{dictDiann[str(rownum)]["proForma_key"]}\t{dictDiann[str(rownum)]["Run"]}\n')
    #                 rownum += 1
    #             else:
    #                 removedline.write(json.dumps(line))
    #                 removedline.write("\n")
    #                 nonInferedLine += 1
    #         if rownum%10000 == 0:
    #             stop = time.time()
    #             sys.stdout.write(f"* {rownum} lines in {round(stop-start, 2)} seconde\n")
    #             sys.stdout.flush()
    #             start = time.time()
    #         elif rownum%1000 == 0:
    #             sys.stdout.write("*")
    #             sys.stdout.flush()
    #     stop = time.time()
    #     sys.stdout.write(f"* {rownum} lines in {round(stop-start, 2)} seconde\n")
    #     sys.stdout.flush()
    #     print(f"The result files contains {len(dictDiann.keys())} lines")
    #     msrunfile = list(set(msrunfile))
    # else:
    #     print("Old version of diann")
    #     cols_index = {}
    #     for column in oldTsvColumnName:
    #         cols_index[column] = headers.index(column)
    #     dictDiann = {}
    #     rownum = 1
    #     start = time.time()
    #     for line in inputfile:
    #         if "Modified.Sequence" not in line:
    #             tabline = line.strip().split("\t")
    #             dictDiann[str(rownum)] = {}
    #             for column in cols_index.keys():
    #                 dictDiann[str(rownum)][column] = tabline[cols_index[column]]
    #             if dictDiann[str(rownum)]["Stripped.Sequence"] in seqlist:
    #                 dictDiann[str(rownum)]["Precursor.Mz"] = 0
    #                 msrunfile.append(dictDiann[str(rownum)]["Run"])
    #                 proForma = proFormaTranslateUnimodDiannToMass(dictDiann[str(rownum)]["Modified.Sequence"])
    #                 dictDiann[str(rownum)]["proForma_key"] = proForma_merged[proForma]
    #                 controlDict.write(f'{dictDiann[str(rownum)]["Modified.Sequence"]}\t{dictDiann[str(rownum)]["Stripped.Sequence"]}\t{dictDiann[str(rownum)]["proForma_key"]}\t{dictDiann[str(rownum)]["Run"]}\n')
    #                 rownum += 1
    #             else:
    #                 removedline.write(json.dumps(line))
    #                 removedline.write("\n")
    #                 nonInferedLine += 1
    #         if rownum%10000 == 0:
    #             stop = time.time()
    #             sys.stdout.write(f"* {rownum} lines in {round(stop-start, 2)} seconde\n")
    #             sys.stdout.flush()
    #             start = time.time()
    #         elif rownum%1000 == 0:
    #             sys.stdout.write("*")
    #             sys.stdout.flush()
    #     stop = time.time()
    #     sys.stdout.write(f"* {rownum} lines in {round(stop-start, 2)} seconde\n")
    #     sys.stdout.flush()
    #     print(f"The result files contains {len(dictDiann.keys())} lines")
    #     msrunfile = list(set(msrunfile))
# print(msrunfile)


# controlDict.close()
# Generate a unique id for every msrun file.
print("Create msrunID")
dictMsrunID = createMSRunID(msrunfile)
print(dictMsrunID)

# Attach the peptide id and theoretical mass to every quantification line,
# matching on the position-free proForma key.
pepId_mass = proFormaDictDF[["proForma_no_loc", "pepID", "mhTheo"]]
pepId_mass = pepId_mass.astype({"mhTheo": "float"})
pepId_mass.to_csv("before_merge.tsv", sep="\t", encoding='utf-8', index=False, header=True)
result = pd.merge(result, pepId_mass, on="proForma_no_loc")
print(result.shape[0])
# Column-wise arithmetic instead of a row-wise apply: same formula, but it
# scales to large reports and also works when the merged frame is empty.
result['Precursor_Mz'] = (
    (result["mhTheo"] + (result["Precursor_Charge"] - 1) * protonH)
    / result["Precursor_Charge"])
# NOTE: a run missing from dictMsrunID would yield NaN here rather than an
# exception; dictMsrunID is built above from the reader's msrun file names.
result["msrun"] = result["Run"].map(dictMsrunID)
# Multiply RT by 60 (presumably minutes to seconds - confirm against DIA-NN).
result["RT"] = result["RT"]*60
result["group"] = "G1"
result["maxintensity"] = 0
result["rtbegin"] = 0
result["rtend"] = 0
result["label"] = 0
# Select the output columns in their final order, then rename them to:
# group, msrun, msrunfile, peptide, mz, rt, maxintensity, area, rtbegin,
# rtend, label, sequence, z, mods
result = result[["group", "msrun", "Run", "pepID", "Precursor_Mz", "RT", "maxintensity",
                 "Precursor_Quantity", "rtbegin", "rtend", "label", "Stripped_Sequence",
                 "Precursor_Charge", "proForma_no_loc"]]
result = result.rename(columns={"Run": 'msrunfile', 'pepID': 'peptide', 'Precursor_Mz': 'mz', 'RT': 'rt', 'Precursor_Quantity': 'area', 'Stripped_Sequence': 'sequence', 'Precursor_Charge': 'z', 'proForma_no_loc': 'mods'})
print(result.shape[0])
outputfile = os.path.join(args.outputDirectory, "peptide_mcqr.tsv")
result.to_csv(outputfile, sep="\t", encoding='utf-8', index=False, header=True)


#
# output_Peptide = open(os.path.join(args.outputDirectory, "peptide_mcqr.tsv"), "w")
# colnames = "\t".join(peptide_col_names)
# output_Peptide.write(f'{colnames}\n')
# line_to_reprocess = dict()
# nbLineWrite = 0
# contaLine = 0
# nonInferedLine = 0
# contaFile= open("contaminant_line.tsv", "w")
# 
# toReprocess = open("to_reprocess.txt", "w")
# for line in dictDiann.keys():
#     line = dictDiann[line]
#     if (float(line["PG.Q.Value"]) < args.PGQValue) and (float(line["Quantity.Quality"])>args.quantityQuality):
#         if "conta|" not in line["Protein.Group"]:
#             peptide = line["Stripped.Sequence"]
#             seqProforma = proFormaTranslateUnimodDiannToMass(line["Modified.Sequence"])
#             if seqProforma in proForma_id.keys():
#                 pepid  = proForma_id[seqProforma]
#                 mz = (float(pepid["mhTheo"])+((int(line["Precursor.Charge"])-1)*protonH))/int(line["Precursor.Charge"])
#                 output_Peptide.write(f'G1\t{dictMsrunID[line["Run"]]}\t{line["Run"]}\t{pepid["id"]}\t{mz}\t{float(line["RT"])*60}\t0\t{line["Precursor.Quantity"]}\t0\t0\t\t{peptide}\t{line["Precursor.Charge"]}\t{seqProforma}\n')
#                 nbLineWrite += 1
#             elif line["Stripped.Sequence"] not in seqlist:
#                 #print(f'{line["Stripped.Sequence"]} from {line["Protein.Group"]} removed during inference')
#                 removedline.write(json.dumps(line))
#                 removedline.write("\n")
#                 nonInferedLine += 1
#             else:
#                 toReprocess.write(convertDictToLine(line)+"\t"+seqProforma+"\n")
#                 if line["Modified.Sequence"] in line_to_reprocess.keys():
#                     if line["Precursor.Charge"] in line_to_reprocess[line["Modified.Sequence"]].keys():
#                         line_to_reprocess[line["Modified.Sequence"]][line["Precursor.Charge"]][line["Run"]] = line
#                     else:
#                         line_to_reprocess[line["Modified.Sequence"]] = {line["Precursor.Charge"] : {line["Run"] : line}}
#                 else:
#                     line_to_reprocess[line["Modified.Sequence"]] = {line["Precursor.Charge"] : {line["Run"] : line}}
#         else:
#             contaLine += 1
#             contaFile.write(json.dumps(line)+"\n")
# json_pep = open("json_pep.json", "w")
# json_pep.write(json.dumps(proForma_id))
# json_pep.close()
# output_Peptide.close()
# print(f"number of line writed : {nbLineWrite}")
# print(f"number of line in conta : {contaLine}")
# print(f"number of line in non_infered_pep : {nonInferedLine}")
# print(f"number of line to reprocess : {len(line_to_reprocess.keys())}")
# ## creation d'un tableau avec pour cle la sequence non modifier et en valeur la sequence modifier pour les leucien isoleucine je vais remplacer les L et les I par des J
# compacted_to_reprocesse = dict()
# for proForma in line_to_reprocess.keys():
#     seq = re.sub("\(UniMod\:[0-9]*\)", "", proForma)
#     seqj = re.sub("[IL]", "J", seq)
#     # print(seq)
#     if seqj in compacted_to_reprocesse.keys():
#         compacted_to_reprocesse[seqj].append({"sequence" : seq, "proForma" :proForma})
#     else:
#         compacted_to_reprocesse[seqj] = [{"sequence" : seq, "proForma" :proForma}]
# ## ecriture des lignes dans le tableau si plusieur peptide
# print(compacted_to_reprocesse)
# reprocessed_keys = dict()
# for seqj in compacted_to_reprocesse.keys():
#     mods = []
#     #print(f"To compact: \n{compacted_to_reprocesse[seqj]}")
#     for seqj in compacted_to_reprocesse[seqj]:
#         print(seqj)
#         mods.append(translateDiannUnimodToMassAminoFormat(seqj["proForma"]))
#     print(mods)
#     modMerged = mergedModificationString(mods)
#     print(modMerged)
#     print(proFormaEncode(modMerged))
# 





