from utils.UNIMOD import UNIMOD
import re

# set of tools to generate and manipulate proForma format

# Function to generate a proForma string from a sequence and a modification string.
# C-terminal modifications are not taken into account. When an array of modstrings indicates
# that a modification has multiple possible positions, the unknown-position tag is used
# (for example "[+15.99]?" for an oxidation).
# Arguments:
# * sequence : the sequence to modify (one uppercase character per amino acid)
# * mods : space-separated modification string in the old i2masschroq format, e.g. "C3:57.0214614868 C4:57.0214614868"
# Return the modified sequence


def proFormaEncode2(sequence, mods):
    """Annotate ``sequence`` with proForma mass-shift tags.

    Args:
        sequence: peptide sequence, one character per residue.
        mods: space-separated modification string; every entry reads
            ``<residue letter><1-based position>:<mass>``, for example
            ``"C3:57.0214614868 C4:57.0214614868"``. An entry with
            position ``0`` is written in front of the sequence.

    Returns:
        The sequence with a ``[+mass]`` (positive shift) or ``[mass]``
        (zero or negative shift) tag after each modified residue, masses
        rounded to 4 decimals. ``sequence`` is returned unchanged when
        ``mods`` is shorter than two characters.
    """
    if len(mods) < 2:
        return sequence

    residues = list(sequence)
    nterm = ""
    for mod in mods.split(" "):
        if not mod:
            # Doubled or trailing spaces produce empty entries; ignore them.
            continue
        fields = mod.split(":")
        position = int(fields[0][1:])
        mass = round(float(fields[1]), 4)
        # Always build the tag from text: a zero or negative mass used to
        # stay a float and broke the string concatenation below.
        mass_text = "+" + str(mass) if mass > 0 else str(mass)
        tag = "[" + mass_text + "]"
        if position == 0:
            # NOTE(review): no "-" separator is emitted after this tag,
            # unlike proFormaEncode; kept as is for existing callers.
            nterm = tag
        else:
            residues[position - 1] += tag
    return nterm + "".join(residues)

def proFormaEncode(mergedMods, no_pose=False):
    """Build a proForma string from a merged modification description.

    Args:
        mergedMods: dict with the keys
            ``"sequenceJ"`` (the peptide sequence),
            ``"fixedMods"`` (iterable of ``<letter><1-based position>:<mass>``
            strings; position ``0`` is written as an N-terminal tag) and
            ``"unfixedMods"`` (iterable of mass strings whose position is
            unknown).
        no_pose: when true, no modification is attached to a residue; all
            masses are sorted and written in front of the sequence followed
            by the unknown-position marker ``?``.

    Returns:
        The proForma string. Fixed masses are rounded to 4 decimals;
        positive masses are written with a leading ``+``.
    """
    sequence = mergedMods["sequenceJ"]
    # An empty modstring splits into [""]; such entries carry no modification.
    fixed_mods = [mod for mod in mergedMods["fixedMods"] if mod]
    unfixed_mods = mergedMods["unfixedMods"]

    if no_pose:
        masses = [round(float(mod.split(":")[1]), 4) for mod in fixed_mods]
        masses.extend(float(unfixed) for unfixed in unfixed_mods)
        if not masses:
            return sequence
        masses.sort()
        tags = "".join(
            "[" + ("+" + str(mass) if mass > 0 else str(mass)) + "]"
            for mass in masses
        )
        return tags + "?" + sequence

    residues = list(sequence)
    nterm = ""
    for mod in fixed_mods:
        fields = mod.split(":")
        position = int(fields[0][1:])
        mass = round(float(fields[1]), 4)
        # Always build the tag from text: a zero or negative mass used to
        # stay a float and broke the string concatenation below.
        tag = "[" + ("+" + str(mass) if mass > 0 else str(mass)) + "]"
        if position == 0:
            nterm = tag + "-"
        else:
            residues[position - 1] += tag
    result = nterm + "".join(residues)

    if len(unfixed_mods) > 0:
        prefix = ""
        for unfixed in unfixed_mods:
            sign = "+" if float(unfixed) > 0 else ""
            # Each new tag is placed in front of the ones already written.
            prefix = "[" + sign + unfixed + "]" + prefix
        result = prefix + "?" + result
    return result


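# Convert DIANN UniMod annotations to proForma mass tags,
# using the UNIMOD dict to convert UniMod ids to masses.
# Arguments:
# * proFormasequence : the sequence to modify
# * no_pose : when True, the masses are written in front of the sequence with the unknown-position tag
# Return the modified sequence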

def proFormaTranslateUnimodDiannToMass(proFormasequence, no_pose=False):
    """Replace DIANN ``(UniMod:<id>)`` tags by proForma mass tags.

    Masses are read from ``UNIMOD[<"UniMod:id">]["mass_modification"]`` and
    rounded to 2 decimals; positive masses are written with a leading ``+``.

    Args:
        proFormasequence: sequence carrying ``(UniMod:<id>)`` tags.
        no_pose: when true, the tags are removed from the sequence and
            their sorted masses are written in front of it, followed by
            the unknown-position marker ``?``.

    Returns:
        The translated sequence, with every ``I`` and ``L`` replaced by
        ``J``.

    Raises:
        KeyError: if a UniMod id is missing from ``UNIMOD``.
    """
    tag_pattern = r"(\(UniMod\:[0-9]*\))"

    def format_tag(mass):
        # Always produce text: a zero or negative mass used to stay a float
        # and broke the string concatenation.
        return "[" + ("+" + str(mass) if mass > 0 else str(mass)) + "]"

    if no_pose:
        masses = sorted(
            round(UNIMOD[match.group()[1:-1]]["mass_modification"], 2)
            for match in re.finditer(tag_pattern, proFormasequence)
        )
        mod_string = "".join(format_tag(mass) for mass in masses)
        proFormasequence = re.sub(tag_pattern, "", proFormasequence)
        if mod_string != "":
            proFormasequence = mod_string + "?" + proFormasequence
    else:
        # The iterator walks the ORIGINAL string; every pass rewrites the
        # first tag still present in the working copy, i.e. the same tag.
        for match in re.finditer(tag_pattern, proFormasequence):
            unimod_id = match.group()[1:-1]
            tag = format_tag(round(UNIMOD[unimod_id]["mass_modification"], 2))
            if unimod_id == "UniMod:1" and match.start() == 0:
                # A UniMod:1 tag opening the string is moved behind the
                # first residue. A UniMod:1 found further in the sequence
                # stays on its own residue (it used to be moved to the
                # first residue as well).
                stripped = re.sub(tag_pattern, "", proFormasequence, count=1)
                proFormasequence = stripped[:1] + tag + stripped[1:]
            else:
                proFormasequence = re.sub(tag_pattern, tag, proFormasequence, count=1)
    proFormasequence = re.sub(r"[IL]", "J", proFormasequence)
    return proFormasequence


# Function to merge multiple modstrings and return the fixed-position and unknown-position modifications
# Argument:
# * modsList : a list of dicts, each with a "sequence" and a "modstring" entry


def mergedModificationString(modsList):
    """Merge several variants of one peptide into a single description.

    Args:
        modsList: non-empty list of dicts, each with a ``"sequence"``
            string and a ``"modstring"`` string of space-separated
            ``<letter><position>:<mass>`` entries (may be empty).

    Returns:
        A dict with
        ``"sequenceJ"``: the first sequence in which every position that
        differs in another variant is replaced by ``J``;
        ``"fixedMods"``: the set of modification entries present in every
        variant;
        ``"unfixedMods"``: the distinct masses (as strings, in order of
        first appearance) of the entries that are not in every variant.

    Raises:
        IndexError: if ``modsList`` is empty, or if a later sequence is
            longer than the first one.
    """
    sequences = [list(item["sequence"]) for item in modsList]
    # split() without argument drops empty entries, so a variant with an
    # empty modstring contributes no modification instead of a bogus "".
    mod_entries = [item["modstring"].split() for item in modsList]

    # Search for leucine / isoleucine ambiguity: any differing position -> J
    merged = sequences[0]
    for other in sequences[1:]:
        for i, residue in enumerate(other):
            if residue != merged[i]:
                merged[i] = "J"
    sequenceJ = "".join(merged)

    # Fixed modifications are the entries shared by every variant.
    fixedMods = set(mod_entries[0])
    for entries in mod_entries[1:]:
        fixedMods &= set(entries)

    # Masses of the remaining entries; a dict keeps first-seen order so the
    # result is the same from one run to the next.
    unfixed_masses = {}
    for entries in mod_entries:
        for entry in entries:
            if entry not in fixedMods:
                unfixed_masses.setdefault(entry.split(":")[1], None)

    return {
        "sequenceJ": sequenceJ,
        "fixedMods": fixedMods,
        "unfixedMods": list(unfixed_masses),
    }

def translateDiannUnimodToMassAminoFormat(seq):
    """Split a DIANN sequence into its bare sequence and its modifications.

    Args:
        seq: sequence carrying ``(UniMod:<id>)`` tags.

    Returns:
        A dict with
        ``"sequence"``: ``seq`` without the tags;
        ``"mods"``: one dict per tag with ``"amino_acid"``,
        ``"amino_acid_number"`` (1-based position in the bare sequence)
        and ``"mono_isotopique"`` (mass read from ``UNIMOD``, as float);
        ``"modstring"``: the same information as space-separated
        ``<amino_acid><number>:<mass>`` entries.

    Raises:
        KeyError: if a UniMod id is missing from ``UNIMOD``. The previous
            code printed a message and then either failed on an unbound
            variable or reused the mass of the preceding modification.
    """
    tag_pattern = r"(\(UniMod\:[0-9]*\))"
    sequence = re.sub(tag_pattern, "", seq)
    mods = []
    mods_string = []
    removed = 0  # total length of the tags found before the current one
    for match in re.finditer(tag_pattern, seq):
        start, end = match.span()
        # Number of residues in front of the tag; a tag with no residue in
        # front of it is attached to the first residue.
        amino_acid_number = max(start - removed, 1)
        # Read the residue from the bare sequence so that stacked tags do
        # not report a parenthesis as amino acid.
        amino_acid = sequence[amino_acid_number - 1]
        removed += end - start

        unimod_key = match.group()[1:-1]
        if unimod_key not in UNIMOD:
            raise KeyError(f"Modification {unimod_key} is not found please verifiy the id in https://www.unimod.org/")
        mono_isotopique = float(UNIMOD[unimod_key]["mass_modification"])

        mods.append({"amino_acid": amino_acid, "amino_acid_number": amino_acid_number, "mono_isotopique": mono_isotopique})
        mods_string.append(f"{amino_acid}{amino_acid_number}:{mono_isotopique}")
    return {"sequence": sequence, "mods": mods, "modstring": " ".join(mods_string)}


"P(+42.01)(+15.99)DIPFHPGREDKPQPPPEGRLPD"
def convertPepXMLModToMass(seq):
    """List the mass-shift tags of a sequence such as ``P(+42.01)(+15.99)DIPF``.

    A tag is a signed decimal number between parentheses, e.g. ``(+15.99)``
    or ``(-17.03)``.

    Args:
        seq: sequence with the tags written after the residue they modify.

    Returns:
        A list with one dict per tag, in order of appearance:
        ``"amino_acid"`` (the residue in front of the tag), ``"pos"`` (its
        1-based position once the tags are removed) and ``"mod"`` (the
        signed mass text; ``"+.98"`` is reported as ``"+0.98"``).
    """
    tag_regex = r"(\([+-][0-9]*\.[0-9]+\))"
    remaining = seq
    stripped_length = 0
    found = []
    for tag in re.finditer(tag_regex, seq):
        start, end = tag.span()
        tag_text = tag.group(1)
        position = start - stripped_length
        # Earlier tags are already gone from `remaining`, so position - 1
        # indexes the residue in front of the current tag.
        # NOTE(review): a tag at the very start gives position 0 and the
        # index -1 lookup returns the last character - confirm whether
        # leading tags can occur in the input.
        residue = remaining[position - 1]
        shift = tag_text[1:-1]
        if shift == "+.98":
            shift = "+0.98"
        found.append({"amino_acid": residue, "pos": position, "mod": shift})
        remaining = remaining.replace(tag_text, "", 1)
        stripped_length += end - start
    return found



