#!/usr/bin/env python3

"""
Dr. Brian Fristensky, University of Manitoba

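 Description: Wrapper for running pal2nal.pl. Sequence names in both input
 files are normalized and the DNA sequences are reordered to match the
 protein alignment before pal2nal.pl is called.

 Synopsis: pal2nal.py protfile dnafile outfile [pal2nal.pl options...]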


@modified: May 28, 2021
@author: Brian Fristensky
@contact: frist@cc.umanitoba.ca
"""

import argparse
import os.path
import shlex
import subprocess
import sys

# Script-name prefix intended for messages.
PROGRAM = "pal2nal.py: "
# Printed when the required command line arguments are missing.
USAGE = "\n\t USAGE: pal2nal.py protfile dnafile outfile"
# Process id as a string; used to give the temporary files unique names.
PID = f"{os.getpid()}"
# When True, the parsed command line parameters are echoed to stdout.
DEBUG = True

# - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
class Parameters:
    """
    Wrapper class for command line parameters.

    Attributes (filled in from sys.argv by read_args()):
        PROTFILE    -- protein alignment file name (sys.argv[1])
        DNAFILE     -- DNA sequence file name (sys.argv[2])
        OUTFILE     -- file that pal2nal.pl output is written to (sys.argv[3])
        PAL2NALARGS -- list of any remaining arguments, passed on to pal2nal.pl
    """

    def __init__(self):
        """
        Initialize all arguments to empty values, then call read_args()
        to fill in their values from the command line.
        """
        self.PROTFILE = ""
        self.DNAFILE = ""
        self.OUTFILE = ""
        self.PAL2NALARGS = []
        self.read_args()

    def read_args(self):
        """
        Read command line arguments into this Parameters object.

        If any of the three required arguments is missing, print the usage
        message and exit with status 1.
        """
        # Indexing past the end of sys.argv raises IndexError, which is what
        # signals a missing required argument.
        try:
            self.PROTFILE = sys.argv[1]
            self.DNAFILE = sys.argv[2]
            self.OUTFILE = sys.argv[3]
        except IndexError:
            print(USAGE)
            sys.exit(1)

        # Everything after the required arguments is handed to pal2nal.pl
        # unchanged; this is an empty list when there are none.
        self.PAL2NALARGS = sys.argv[4:]

        if DEBUG:
            print("PROTFILE: " + self.PROTFILE)
            print("DNAFILE: " + self.DNAFILE)
            print("OUTFILE: " + self.OUTFILE)
            print("PAL2NALARGS: " + str(self.PAL2NALARGS))

class Sequence:
    """A single named sequence."""

    def __init__(self):
        """
        Create an empty sequence record: no name, no sequence text and a
        length of zero. Callers fill in the attributes afterwards.
        """
        self.SeqLen = 0  # length of Seq, maintained by the caller
        self.Seq = ""    # sequence text
        self.Name = ""   # sequence name


class SeqData:
    """A list of sequences together with a count of how many it holds."""

    def __init__(self):
        """Create an empty collection: no sequences and a count of zero."""
        self.NumSeq = 0        # number of sequences, maintained by the caller
        self.SeqLst = list()   # the sequence objects, in the order added

def ReadFasta(SFILE, S):
    """
    Read sequences from a fasta file in the form:

    >name
    sequence
    sequence
    sequence...

    One Sequence object per entry is appended to S.SeqLst, and S.NumSeq
    (reset to 0 first) counts the entries read from this file.

    - The name is the first whitespace-delimited token of the header line,
      without the leading '>', with ':CDS' replaced by '_' (gets rid of the
      CDS tag added by BioLegato).
    - Sequence lines are stripped of surrounding whitespace and concatenated.
    - Every 'X' in a sequence is replaced by the gap character '-'; a message
      is printed for each sequence where this happens.
    - Lines before the first '>' header (e.g. blank lines) are ignored.
    """

    def _store(seq, chunks):
        """Finish one entry (join its lines, replace X) and append it to S."""
        seq.Seq = "".join(chunks)
        seq.SeqLen = len(seq.Seq)
        # Replace X with a gap char. Note this replaces every X in the
        # sequence, not only a trailing one.
        if 'X' in seq.Seq:
            print(seq.Name + ' replacing X with gap')
            seq.Seq = seq.Seq.replace('X', '-')
        S.SeqLst.append(seq)
        S.NumSeq = S.NumSeq + 1

    with open(SFILE, 'r') as in_file:
        S.NumSeq = 0
        current = None
        chunks = []
        for line in in_file:
            if line.startswith('>'):
                if current is not None:
                    _store(current, chunks)
                current = Sequence()
                tokens = line.strip().split()
                current.Name = tokens[0][1:].replace(":CDS", "_")
                chunks = []
            elif current is not None:
                chunks.append(line.strip())
            # Any other line precedes the first header and is skipped.
        if current is not None:
            _store(current, chunks)


def WriteFlat(SFILE, S, FlagChar, LineLen=50):
    """
    Write one sequence to the open text stream SFILE in a flat file format.
    The format is selected by FlagChar:

    fasta - FlagChar = '>'
    flatdna - FlagChar = '#'
    flatpro - FlagChar = '%'
    flattext - FlagChar = '"'

    The header line is FlagChar immediately followed by S.Name. The first
    S.SeqLen characters of S.Seq follow, LineLen characters per line (the
    last line may be shorter). An empty sequence produces only the header.

    Raises ValueError if LineLen is not positive (a zero width would never
    advance through the sequence).
    """
    if LineLen < 1:
        raise ValueError("LineLen must be positive, got " + str(LineLen))
    lines = [FlagChar + S.Name]
    # Bound by S.SeqLen (not len(S.Seq)) so exactly SeqLen residues are written.
    lines.extend(
        S.Seq[start:min(start + LineLen, S.SeqLen)]
        for start in range(0, S.SeqLen, LineLen)
    )
    SFILE.write("\n".join(lines) + "\n")

def RunPal2Nal(PROTFILE, DNAFILE, OUTFILE, PAL2NALARGS):
    """
    Run pal2nal.pl on PROTFILE and DNAFILE, writing its output to OUTFILE.

    This is the equivalent of the shell command

        pal2nal.pl PROTFILE DNAFILE -output fasta [PAL2NALARGS...] > OUTFILE

    but it is executed without a shell. File names and arguments containing
    spaces or shell metacharacters (';', '|', '>', '$', ...) are therefore
    passed literally instead of being interpreted.

    PAL2NALARGS (a list of strings) is joined with spaces and re-split using
    shell-style quoting rules. Options passed as a single string, such as
    "-codontable 11", are still separated into words.

    Failures to start pal2nal.pl, and a non-zero exit status, are reported
    on stderr but not raised; the caller continues either way.
    Raises ValueError if PAL2NALARGS contains unbalanced quotes.
    """
    print("running pal2nal.pl...")
    CLIST = ["pal2nal.pl", PROTFILE, DNAFILE, "-output", "fasta"]
    CLIST.extend(shlex.split(" ".join(PAL2NALARGS)))
    # Echo the command in the familiar shell form for the log.
    print(" ".join(CLIST + [">", OUTFILE]))

    try:
        with open(OUTFILE, "w") as out:
            returncode = subprocess.run(CLIST, stdout=out).returncode
    except OSError as err:
        # e.g. pal2nal.pl not on PATH, or OUTFILE cannot be created
        print("pal2nal.py: could not run pal2nal.pl: " + str(err), file=sys.stderr)
        return

    if returncode != 0:
        print("pal2nal.py: pal2nal.pl exited with status " + str(returncode),
              file=sys.stderr)
    return

def main():
    """
    Called by the script entry point unless '-test' is on the command line.

    Reads the protein alignment and DNA files named on the command line and
    writes temporary copies with shortened names, with the DNA sequences in
    the same order as the proteins. It then runs pal2nal.pl on the temporary
    copies and removes them afterwards, even if a step fails.
    """
    P = Parameters()

    # Read in sequence files. Names are shortened to the first non-blank token
    # on the name line. This eliminates some of the extra stuff sometimes written
    # to FASTA files, such as sequence length. As well, change the :CDS tags
    # added by the features program to "_". The point of these steps is to ensure
    # that the DNA and protein names are exactly the same.
    ProtAlign = SeqData()
    ReadFasta(P.PROTFILE, ProtAlign)
    DNA = SeqData()
    ReadFasta(P.DNAFILE, DNA)

    # Index DNA sequence objects by name. This makes it quick to write the
    # DNA sequences in the same order as the aligned proteins.
    ddict = {d.Name: d for d in DNA.SeqLst}

    TEMPPROTFN = PID + ".pal2nal.tempprot"
    TEMPDNAFN = PID + ".pal2nal.tempdna"
    try:
        # Write temporary protein file with shortened names
        with open(TEMPPROTFN, "w") as outfile:
            for prot in ProtAlign.SeqLst:
                WriteFlat(outfile, prot, '>')

        # Write DNA sequences in the order the names appear in the protein list
        with open(TEMPDNAFN, "w") as outfile:
            for prot in ProtAlign.SeqLst:
                if prot.Name in ddict:
                    WriteFlat(outfile, ddict[prot.Name], '>')
                else:
                    # Without a matching DNA sequence this protein cannot be
                    # back-translated; tell the user rather than drop it silently.
                    print(PROGRAM + "no DNA sequence found for " + prot.Name,
                          file=sys.stderr)

        # Run pal2nal.pl
        RunPal2Nal(TEMPPROTFN, TEMPDNAFN, P.OUTFILE, P.PAL2NALARGS)
    finally:
        # Clean up whichever temporary files were created
        for fn in (TEMPPROTFN, TEMPDNAFN):
            if os.path.exists(fn):
                os.remove(fn)

# Run only when executed as a script, and not when '-test' is on the command
# line. Importing this module (e.g. for documentation or testing) must not
# parse sys.argv or start pal2nal.pl.
if __name__ == "__main__" and '-test' not in sys.argv:
    main()
