## class to read and manipulate DIA-NN TSV (tab-separated) report files

import os
import subprocess
import time
import json
import sys
import re
import pandas as pd
from utils.proForma import *



class tsvReader():
    """Reader for DIA-NN tab-separated report files.

    The constructor only parses the header line. ``tsvRead`` loads the data
    into a pandas DataFrame stored in ``self.dictCols`` (despite its name it
    is a DataFrame once loaded). The other methods filter, merge and query
    that DataFrame. Column names are normalised by replacing "." with "_".
    """

    def __init__(self, filename):
        """Index the header columns of *filename* and detect the report version.

        A header containing "Run.Index" is treated as DIA-NN "2.0.0". Anything
        else is treated as "1.9.0", and "Run_Index" is then aliased to the
        "Run" column so later code can always look up "Run_Index".

        Raises:
            FileNotFoundError: *filename* is not an existing file. This is a
                subclass of Exception, so callers catching Exception still work.
            KeyError: a "1.9.0" header has no "Run" column.
        """
        if not os.path.isfile(filename):
            raise FileNotFoundError(f"{filename} does not exist, please verify spelling or file location")
        self.filename = filename
        with open(self.filename, "r") as inputfile:
            header = inputfile.readline().strip().split("\t")
        header = [re.sub(r"\.", "_", x) for x in header]
        # Maps column name -> index in a tab-split data line.
        # If a name is duplicated, the last occurrence wins.
        self.colnames = {column: colIndex for colIndex, column in enumerate(header)}
        if "Run_Index" in self.colnames:
            self.diannVersion = "2.0.0"
        else:
            self.diannVersion = "1.9.0"
            self.colnames["Run_Index"] = self.colnames["Run"]
        print(self.colnames)

    ## function to return diann meta version from column names
    def getDiannVersion(self):
        """Return "2.0.0" or "1.9.0", as detected from the header in __init__."""
        return self.diannVersion

    def _appendRow(self, dictCols, dataColumns, tabline, zeroMz):
        """Append one tab-split data row to the per-column lists in *dictCols*."""
        for column in dataColumns:
            if zeroMz and column == "Precursor_Mz":
                # For "1.9.0" reports Precursor_Mz is not read from the file;
                # pad it with 0 instead.
                dictCols[column].append(0)
            else:
                dictCols[column].append(tabline[self.colnames[column]])
        modifiedSequence = tabline[self.colnames["Modified_Sequence"]]
        dictCols["proForma_key"].append(proFormaTranslateUnimodDiannToMass(modifiedSequence))
        dictCols["proForma_no_loc"].append(proFormaTranslateUnimodDiannToMass(modifiedSequence, no_pose=True))

    ## function to read tsv file from diann and populate pandas dataframe with selected column
    def tsvRead(self, columnNames, seqlist=None):
        """Load the requested columns of the report into ``self.dictCols``.

        Two derived columns are added, both computed from Modified_Sequence
        with ``proFormaTranslateUnimodDiannToMass``:
        "proForma_key", and "proForma_no_loc" (computed with ``no_pose=True``).

        Args:
            columnNames: columns to keep. "." is replaced by "_", so both the
                DIA-NN spelling and the normalised spelling are accepted. The
                float conversion at the end expects Precursor_Charge,
                Precursor_Mz, Precursor_Quantity, RT, Quantity_Quality and
                PG_Q_Value to be among them.
            seqlist: optional collection of stripped sequences. When it is
                non-empty, only rows whose Stripped_Sequence is in it are kept.
                None or empty keeps every row.

        Side effects:
            "removedline.tsv" in the current working directory is always
            (re)created. Every rejected raw line is written to it as one
            JSON-encoded string per line.
        """
        print(f"reading {self.filename}")
        columnNames = [re.sub(r"\.", "_", x) for x in columnNames]
        dictCols = {column: [] for column in columnNames}
        print(columnNames)
        # Data columns read from the file. This list is taken before the
        # derived proForma columns are added to dictCols.
        dataColumns = list(dictCols)
        zeroMz = self.diannVersion == "1.9.0"

        dictCols["proForma_key"] = []
        dictCols["proForma_no_loc"] = []
        print(dictCols.keys())

        # Build the set once so each membership test is O(1).
        # len() is used rather than truthiness so array-likes also work.
        seqFilter = set(seqlist) if seqlist is not None and len(seqlist) > 0 else None

        keptRows = 0
        start = time.time()
        with open("removedline.tsv", "w") as removedline, open(self.filename, "r") as inputfile:
            for line in inputfile:
                # Skip the header, which still carries the dotted DIA-NN names.
                if "Modified.Sequence" in line:
                    continue
                # Strip only the line terminator. strip() would also remove
                # empty leading/trailing tab fields and shift column indices.
                tabline = line.rstrip("\r\n").split("\t")
                if seqFilter is not None and tabline[self.colnames["Stripped_Sequence"]] not in seqFilter:
                    ## removing line if is not in grouping result (conta or not infered)
                    removedline.write(json.dumps(line))
                    removedline.write("\n")
                    continue
                self._appendRow(dictCols, dataColumns, tabline, zeroMz)
                keptRows += 1
                # Progress: a timed line every 100k kept rows, a "*" every 10k.
                if keptRows % 100000 == 0:
                    stop = time.time()
                    sys.stdout.write(f"* {keptRows} lines in {round(stop-start, 2)} seconde\n")
                    sys.stdout.flush()
                    start = time.time()
                elif keptRows % 10000 == 0:
                    sys.stdout.write("*")
                    sys.stdout.flush()
        stop = time.time()
        sys.stdout.write(f"* {keptRows} lines in {round(stop-start, 2)} seconde\n")
        sys.stdout.flush()
        for column, values in dictCols.items():
            print(f"{column} = {len(values)}")
        self.dictCols = pd.DataFrame(dictCols)
        self.dictCols = self.dictCols.astype({"Precursor_Charge": "float", "Precursor_Mz": "float", "Precursor_Quantity": "float", "RT": "float", "Quantity_Quality": "float", "PG_Q_Value": "float"})
        print(f'The result files contains {len(dictCols["Modified_Sequence"])} lines')

    def getMsrunfileNames(self):
        """Return the distinct values of the Run column, in order of appearance."""
        return self.dictCols['Run'].unique().tolist()

    def filteringResult(self, col_threshold):
        """Drop rows that fail any threshold in *col_threshold*.

        col_threshold maps a column name to a threshold. For Quantity_Quality,
        rows strictly above the threshold are kept. For every other column,
        rows strictly below the threshold are kept. Columns that are not in
        the DataFrame are ignored. Afterwards, any remaining row that contains
        a NaN in any column is also dropped.
        """
        before_filtering = self.dictCols.shape[0]
        keep = pd.Series(True, index=self.dictCols.index)
        for col, threshold in col_threshold.items():
            if col not in self.dictCols.columns:
                continue
            values = self.dictCols[col].astype(float)
            if col == "Quantity_Quality":
                keep &= values > threshold
            else:
                keep &= values < threshold
        self.dictCols = self.dictCols[keep].dropna()
        print(f"Number of line removed : {before_filtering - self.dictCols.shape[0]}")

    def mergePeptidoform(self):
        """Collapse rows that share run, sequence, charge, protein group and proForma_no_loc.

        Within each group, the aggregates are:
        - Quantity_Quality: minimum
        - PG_Q_Value: maximum
        - RT and Precursor_Mz: mean
        - Precursor_Quantity: sum
        - proForma_key and Modified_Sequence: space-joined list of the group's values
        """
        print(self.dictCols.shape[0])
        groupKeys = ["Run", "Run_Index", "Stripped_Sequence",
                     "Precursor_Charge", "Protein_Group", "proForma_no_loc"]
        aggregations = {"Quantity_Quality": 'min', "PG_Q_Value": 'max', 'RT': 'mean',
                        'Precursor_Mz': 'mean', 'Precursor_Quantity': 'sum',
                        "proForma_key": list, "Modified_Sequence": list}
        self.dictCols = self.dictCols.groupby(groupKeys, as_index=False).agg(aggregations)
        self.dictCols['proForma_key'] = self.dictCols['proForma_key'].str.join(" ")
        self.dictCols['Modified_Sequence'] = self.dictCols['Modified_Sequence'].str.join(" ")
        print(self.dictCols.shape[0])

    def getUniqueSequences(self):
        """Return the distinct Stripped_Sequence values, in order of appearance."""
        return self.dictCols['Stripped_Sequence'].unique().tolist()

    def getDF(self):
        """Return the current DataFrame (None of this is copied)."""
        return self.dictCols

#     def computePrecursorMZ(self,):
#         
