#!/usr/bin/python3

import argparse
import filecmp
import os
import re
import subprocess
import sys
import tempfile
import time
from xml.sax.saxutils import escape

import pandas as pd
from Bio import SeqIO
from Bio.SeqUtils.ProtParam import ProteinAnalysis

sys.path.append("/usr/local/lib/groupingprotein")  # noqa: E402
# from utils.UNIMOD import UNIMOD
from diannReader.tsvReader import tsvReader
from diannReader.parquetReader import parquetReader
from diannReader.reportLog import reportLog
from utils.proForma import translateDiannUnimodToMassAminoFormat
from utils.fastaDigested import fastaDigested


def __computePepMass(row):
    """Compute the mass written for one report row.

    The result is the monoisotopic molecular weight of
    ``row['Stripped_Sequence']`` (Biopython ``ProteinAnalysis``), plus the
    ``mono_isotopique`` value of every entry of ``row['mods']['mods']``,
    plus a fixed 1.007825035 offset.

    When ``row['Precursor_Mz']`` equals the string ``"0.0"`` nothing is
    computed and that stored value is returned unchanged.

    NOTE(review): the test is against the *string* "0.0"; a numeric 0.0
    does not match it - confirm the type the readers produce.
    """
    stored_mz = row["Precursor_Mz"]
    if stored_mz == "0.0":
        return stored_mz

    hydrogen_offset = 1.007825035
    analysis = ProteinAnalysis(row['Stripped_Sequence'], monoisotopic=True)
    total_mass = analysis.molecular_weight()
    # Iterating an empty modification list is a no-op, so no length test.
    for modification in row['mods']["mods"]:
        total_mass += float(modification["mono_isotopique"])
    return total_mass + hydrogen_offset


# Module-level copy of the mass offset (the mass helper above keeps its own
# local value with the same number).
protonH = 1.007825035

# Modifications described by hand, keyed by their UniMod identifier.
Known_modifications = {
    "UniMod:4": dict(unimod="UniMod:4",
                     mass_modification="57.0214614868",
                     amino_acid="C",
                     modType="Fixed"),
}

# Report columns requested from the readers.  The "_V1" set is the reduced
# variant; the full set is the same list plus the two precursor columns.
PeptideParseColumn_V1 = {
    "Run",
    "Modified.Sequence",
    "Stripped.Sequence",
    "Precursor.Charge",
    "Precursor.Quantity",
    "RT",
    "Protein.Ids",
    "Protein.Names",
    "Q.Value",
    "PG.Q.Value",
    "Quantity.Quality",
    "Lib.Q.Value",
}
PeptideParseColumn = PeptideParseColumn_V1 | {"Precursor.Mz", "Precursor.Id"}




# Command-line interface definition.
#
# Two corrections compared with the previous definition:
#   * ``prog`` had a stray pasted fragment in the middle of the program
#     name; it is now the plain program name.
#   * ``--PG_Qvalue`` declared a default of 0.05 while its help text (and
#     the threshold hard-coded further down for filtering) said 0.01; the
#     declared default now matches the documented one.
desc = "Process Diann result to peptide result. \
    The result can then be grouped by gp-grouping program."
command = argparse.ArgumentParser(prog='gp-read-diann',
                                  description=desc,
                                  usage='%(prog)s [options] files')
command.add_argument('-q', '--Qvalue', default=0.01, type=float,
                     help='Maximun peptide qvalue threshold (default:0.01)')
command.add_argument('-p', '--PG_Qvalue', default=0.01, type=float,
                     help='Maximun protein group qvalue threshold\
                     (default:0.01)')
command.add_argument('-l', '--lib_Qvalue', default=0.01, type=float,
                     help='Maximun library qvalue threshold (default:0.01)')
# Output goes to STDOUT unless a file name is given.
command.add_argument('-o', '--outfile', nargs="?",
                     type=argparse.FileType("w"), default=sys.stdout,
                     help='Save peptide result default:STDOUT')
# May be repeated; every occurrence is appended to a list.
command.add_argument('-f', '--fasta', default=None, nargs='?',
                     type=str, action='append',
                     help='The Fasta file which is used for DIA-NN')
command.add_argument('file', metavar='file', nargs='?',
                     help='diann report file in tsv or parquet format')
command.add_argument('-r', '--reportLog', default="", type=str,
                     help='Diann log report files (can be used to get \
                     partially the parameter of Diann processing)')
command.add_argument('-d', '--debug', default=0, type=int,
                     help="set the debug levels, permit to create some \
                     log files to verifiy the processing, 0 = no log, 1 = log")
command.add_argument('-L', '--LogDirectory', default=".", type=str,
                     help="directory to save logfile. used only with \
                     debug = 1, default current directory")
# "${GP_VERSION}" is left untouched: it looks like a placeholder filled in
# by the packaging step (TODO confirm).
command.add_argument('-v', '--version', action='version',
                     version='%(prog)s ${GP_VERSION}')
# Read arguments of command line
args = command.parse_args()


# When --fasta is used, every listed database must exist on disk.
# An occurrence of "-f" without a value is stored as None by argparse
# (nargs='?'); it is reported as missing instead of crashing os.path.isfile.
if args.fasta is not None:
    not_found = False
    for fastafile in args.fasta:
        if fastafile is not None and os.path.isfile(fastafile):
            print(f"\tFound exisiting Fasta File : {fastafile}")
        else:
            print(f"\tFasta present in argument but not found : {fastafile}")
            not_found = True
    if not_found:
        print("\tplease verify the presence of the files")
        # Non-zero status so calling scripts can detect the failure;
        # sys.exit is used because the bare exit() helper comes from the
        # optional site module.
        sys.exit(1)
# Parsed DIA-NN log; stays None when --reportLog is not given.
param = None

# Optional: parse the DIA-NN log and print what was found in it.
if args.reportLog != "":
    print(f"Parsing log file : {args.reportLog}")
    try:
        param = reportLog(args.reportLog)
    except OSError as e:
        print(e)
        # The log could not be opened: stop with a failure status.
        sys.exit(1)
    param.parselog()
    print("Information retrieve from reportLog :")
    fastaFiles = param.getFastaFile()
    # Iterate over a snapshot: removing from the list being iterated would
    # skip the entry that follows every removed one.
    for fastafile in list(fastaFiles):
        if os.path.isfile(fastafile):
            print(f"\tFound exisiting Fasta File : {fastafile}")
        else:
            fastaFiles.remove(fastafile)
            print("\tFasta present in log but not found :" +
                  f" {fastafile}. It is removed.")
    # -1 is the value the getters are compared against for "not set".
    if param.getQValue() != -1:
        print(f"QValue was set to {param.getQValue()} in log File")
    else:
        print("QValue was not set to in log File, using Qvalue" +
              " argument from Command line")
    if param.getMatriceQValue() != -1:
        print(f"matrix_QValue was set to {param.getMatriceQValue()}" +
              " in log File")
    else:
        print("matrix_QValue was not set in log File, using Qvalue" +
              " argument from Command line")
    if param.isDecoyReported():
        print("Decoy sequences was reported in the result files")
    else:
        print("Decoy sequences was not reported in the result files")
    if param.isInfered():
        print("Inference was performed during DIA-NN process." +
              " Best result must be obtain without inference")
    else:
        print("Inference was not performed during DIA-NN process")
    searched_mods = param.getMods()
    if len(searched_mods) != 0:
        print("List of searched Modifications :")
        for mod in searched_mods:
            print(f"\t{mod} :")
            print("\t\tAmino Acid : " +
                  f"{searched_mods[mod]['amino_acid']}")
            print("\t\tModification Mass : " +
                  f"{searched_mods[mod]['mass_modification']}")
            print("\t\tModification type : " +
                  f"{searched_mods[mod]['modType']}")
    input_files = param.getInputFile()
    if len(input_files) != 0:
        print("List of rawFile processed : ")
        for inputfile in input_files:
            print(f"\t* {inputfile}")

fasta_for_process = []
# Thresholds used for filtering come from the command line.  They used to
# be hard-coded to 0.01, which silently ignored --Qvalue and --PG_Qvalue.
Qvalue = args.Qvalue
PG_qvalue = args.PG_Qvalue
# "QValue" (capital V) is the name read when the XML <filter> element is
# written.  It is always defined here so that running without --reportLog
# no longer ends in a NameError.
QValue = Qvalue

no_fasta_message = ("fasta file must be defined with --fasta argument" +
                    "or taken from reportLog and be valid file.")

# test concordance between args and reportLog if used
if param is not None and param != -1:
    log_qvalue = param.getQValue()
    if log_qvalue == -1:
        # Not set in the log: keep the command-line value.
        QValue = args.Qvalue
    elif args.Qvalue > log_qvalue:
        print("Qvalue parameter was less strict that the one in log file." +
              f"Qvalue from reportlog is used : {log_qvalue}")
        QValue = log_qvalue
    else:
        print("Qvalue parameter was not less strict than the one in log" +
              f" file. Qvalue from parameters is used : {args.Qvalue}")
        QValue = args.Qvalue
    # Filter with the same threshold that is reported in the output.
    Qvalue = QValue

    if len(fastaFiles) == 0:
        # Nothing usable in the log: fall back on --fasta, and stop only
        # when that is missing too.
        if args.fasta is None:
            print(no_fasta_message)
            sys.exit(1)
        fasta_for_process = list(args.fasta)
    else:
        for fasta_from_param in fastaFiles:
            if args.fasta is not None:
                the_same = any(filecmp.cmp(fasta_from_param, fasta_from_arg)
                               for fasta_from_arg in args.fasta)
                if not the_same:
                    print(f"the fasta file {fasta_from_param} present in the" +
                          "reportLog was not present or different from those" +
                          "in the settings. take the one in the reportLog")
            # The database named in the log is used in every case, once.
            fasta_for_process.append(fasta_from_param)
else:
    if args.fasta is not None:
        fasta_for_process = args.fasta
    else:
        # Neither a log nor --fasta: say why before stopping.
        print(no_fasta_message)
        sys.exit(1)
print("FastaFiles used :")
if args.debug == 1:
    for fasta in fasta_for_process:
        print(f"\t* {fasta}")

# Load every database into `accessions` and write a two-column table
# (accession <TAB> sequence with every L replaced by I) to a temporary file.
# That file is what the peptide search below greps through.
# NOTE(review): the original note here seems to say that the protein
# accessions reported by DIA-NN are not reused and that peptides are
# re-matched against the databases instead - confirm.

# delete=False: the file has to outlive this handle because it is reopened
# by name later on.
tempfasta = tempfile.NamedTemporaryFile(delete=False, mode='w+t')

print("readings database files")
accessions = {}
for fasta in fasta_for_process:
    nb_seq = 0
    with open(fasta) as handle:
        for record in SeqIO.parse(handle, "fasta"):
            if record.id in accessions:
                print(f"accession {record.id} is already present in fasta" +
                      " files check your databases")
                continue
            sequence = str(record.seq)
            # Double quotes are replaced because the description is later
            # written inside a double-quoted XML attribute.
            description = record.description.replace('"', '\'')
            accessions[record.id] = {"accession": record.id,
                                     "fullDescr": description,
                                     "sequence": sequence}
            nb_seq += 1
            fullseqLI = sequence.replace("L", "I")
            # No flush per record: close() below flushes once, which avoids
            # one write syscall per sequence.
            tempfasta.write(f"{record.id}\t{fullseqLI}\n")
    print(f"fasta file {fasta}, contains {nb_seq} non dpuplicated sequences")
tempfasta.close()
print("DIA-NN sequence search was performed using" +
      f"{len(accessions)} sequences")

# The report file is declared optional on the command line (nargs='?'), so
# it can be None; stop with a message instead of a TypeError traceback.
if args.file is None:
    print("diann report file (tsv or parquet) is required",
          file=sys.stderr)
    sys.exit(1)

if args.debug == 1:
    print("creating pep_ids log files")
    # Left open on purpose: it is closed further down, just before export.
    pep_ids = open(os.path.join(args.LogDirectory, "pep_ids.tsv"), "w")
    pep_ids.write("Modified.sequence\tSequence\tmods\tQ.Value\tLib.Q.Value" +
                  "\tPG.Q.Value\tQuantity.Quality\n")

# test if file is tsv or parquet file
extension = os.path.splitext(args.file)[-1]
uniquePep = set()

# Case-insensitive so that "REPORT.PARQUET" is not read as a tsv file.
if extension.lower() == ".parquet":
    diannData = parquetReader(args.file)
    diannData.parquetRead(PeptideParseColumn)
else:
    diannData = tsvReader(args.file)
    diannData.tsvRead(PeptideParseColumn)

# Debug export of the sequence / protein columns, same for both formats.
if args.debug == 1:
    dictDiann = diannData.getDF()
    dictDiann = dictDiann[['Stripped_Sequence', 'Protein_Ids']]
    dictDiann.to_csv(os.path.join(args.LogDirectory, "protein_ids.txt"),
                     sep="\t", encoding='utf-8', index=False, header=True)

# Find, for every identified peptide, the database entries containing it.
# One `grep -P` is run per peptide against the temporary accession table.
uniquePep = diannData.getUniqueSequences()
print(f'Number of peptide identified in result : {len(uniquePep)}')
print(tempfasta.name)
count = 0
not_found = []
print("retrieve accessions for identified peptide")
accession_found = set()
start = time.time()
# peptide -> list of accessions (used when writing the XML)
pep_prot_associations_bis = {}
# Same information in long format, turned into a DataFrame after the loop.
pep_prot_association = {"seq": [], "accessions": []}
for seq in uniquePep:
    # The table stores I for every L, so each L of the peptide must be
    # allowed to match either letter.
    seqLI = seq.replace("L", "[LI]")
    # The peptide must follow K or R, or sit at the start of the sequence
    # column (right after the tab), optionally preceded by M residues.
    pattern = f'(\tM*|[KR]){seqLI}'
    search = subprocess.run(['grep', "-P", pattern, tempfasta.name],
                            capture_output=True, text=True)
    # grep: 0 = match, 1 = no match, anything above = real failure (for
    # instance a grep without -P support).  Previously such failures were
    # indistinguishable from "peptide not found".
    if search.returncode > 1:
        sys.stderr.write(f"grep failed for peptide {seq} :" +
                         f" {search.stderr.strip()}\n")
        sys.exit(1)
    matched_accessions = []
    if len(search.stdout) > 1:
        for line in search.stdout.strip("\n").split("\n"):
            matched_accessions.append(line.split("\t")[0])
    pep_prot_associations_bis[seq] = matched_accessions
    if matched_accessions:
        for accession in matched_accessions:
            pep_prot_association["seq"].append(seq)
            pep_prot_association["accessions"].append(accession)
    else:
        not_found.append(seq)
    count += 1
    # Progress display: a star every 100 peptides, a timed line every 1000.
    if count % 1000 == 0:
        stop = time.time()
        sys.stdout.write(f"* {count} peptides searched in " +
                         f"{stop-start} seconde\n")
        sys.stdout.flush()
        start = time.time()
    elif count % 100 == 0:
        sys.stdout.write("*")
        sys.stdout.flush()
pep_prot_association = pd.DataFrame(pep_prot_association)


if args.debug == 1:
    pep_prot_association.to_csv(os.path.join(args.LogDirectory,
                                             "pep_prot_association.tsv"),
                                sep="\t", encoding='utf-8',
                                index=False, header=True)

print(f'Number of peptide not found : {len(not_found)}')
print('Number of redundant accession found : ' +
      f'{len(pep_prot_association["accessions"].unique().tolist())}')

# Filter the report, translate the modifications and compute the masses.
samples = {}
print("Remove non valide quantification")
start = time.time()
print(f"{PG_qvalue} {Qvalue}")
diannData.filteringResult({"PG_Q_Value": PG_qvalue, "Q_Value": Qvalue})
print(f"Removed non valide quantifications in {time.time()-start}")
diannDF = diannData.getDF()
print("Translate modifications from unimod to mass value")
start = time.time()
diannDF["mods"] = diannDF["Modified_Sequence"].apply(translateDiannUnimodToMassAminoFormat)
print(f"Translate modifications in {time.time()-start}")
# Keep only rows whose peptide was matched to at least one accession.
# .copy() makes the filtered frame independent, so the column assignment
# below does not act on a slice (pandas SettingWithCopyWarning).
matched_sequences = pep_prot_association["seq"].unique().tolist()
diannDF = diannDF[diannDF["Stripped_Sequence"].isin(matched_sequences)].copy()
start = time.time()
# On an empty frame DataFrame.apply(axis=1) hands back a DataFrame rather
# than a Series, which cannot be stored in a single column; there is
# nothing to compute in that case anyway.
if not diannDF.empty:
    diannDF["Precursor_Mz"] = diannDF.apply(__computePepMass, axis=1)
print(f"Compute peptide mass in {time.time()-start}")
print("Exporting peptide xml file")
if args.debug == 1:
    pep_ids.close()
def _xml_attr(value):
    """Return *value* as text that is safe inside a double-quoted attribute.

    ``&``, ``<`` and ``>`` are replaced by ``escape``; the double quote is
    added to the replacement table because the attributes written below are
    delimited with double quotes.
    """
    return escape(str(value), {'"': "&quot;"})


# Write the peptide result as XML: one <sample> per run, one <scan> per
# report row, one <psm> per accession associated with the row's peptide.
start = time.time()
outputfile = args.outfile
outputfile.write('<?xml version="1.0" encoding="utf-8" ?>\n')
outputfile.write('<peptide_result>\n')
outputfile.write('<filter evalue="'+str(QValue)+'" />\n')
for sampleName in diannData.getMsrunfileNames():
    print(f"Exporting identifiction in {sampleName}")
    # Run names and FASTA descriptions are free text: escape them so that
    # a "&" or "<" does not produce an invalid document.
    sample_attr = _xml_attr(sampleName)
    outputfile.write(f'<sample name="{sample_attr}" file="{sample_attr}">\n')
    sample = diannDF[diannDF["Run"] == sampleName]
    count = 0
    # Scan numbers restart at 1 for every sample.
    for scanid, row in enumerate(sample.itertuples(), 1):
        prot = pep_prot_associations_bis[row.Stripped_Sequence]
        row_mods = row.mods["mods"]
        outputfile.write(f'<scan num="{scanid}" z="{row.Precursor_Charge}"'
                         f' mhObs="{row.Precursor_Mz}">\n')
        # Everything of the <psm> element except the protein attribute is
        # identical for all accessions of this row.
        psm_start = (f'<psm seq="{_xml_attr(row.Stripped_Sequence)}"'
                     f' mhTheo="{row.Precursor_Mz}"'
                     f' evalue="{row.Q_Value}"')
        for current_prot in prot:
            prot_attr = _xml_attr(accessions[current_prot]['fullDescr'])
            if len(row_mods) == 0:
                outputfile.write(f'{psm_start} prot="{prot_attr}"></psm>\n')
            else:
                outputfile.write(f'{psm_start} prot="{prot_attr}">\n')
                for mod in row_mods:
                    outputfile.write(f'<mod aa="{mod["amino_acid"]}" pos="{mod["amino_acid_number"]}" mod="{mod["mono_isotopique"]}" />\n')
                outputfile.write("</psm>\n")
        outputfile.write("</scan>\n")
        count += 1
        # Progress display: a star every 1000 rows, a timed line every 10000.
        if count % 10000 == 0:
            stop = time.time()
            sys.stdout.write(f"* {count} identifications in " +
                             f"{stop-start} seconde\n")
            sys.stdout.flush()
            start = time.time()
        elif count % 1000 == 0:
            sys.stdout.write("*")
            sys.stdout.flush()
    sys.stdout.write(f" {count} identifications\n")
    outputfile.write("</sample>\n")
outputfile.write("</peptide_result>\n")
outputfile.close()
# os.remove(tempfasta.name)
