#!/usr/bin/env python3

import argparse
import datetime
import getpass
import os
import re
import stat
import subprocess
import sys
import time

'''
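bl_magicblast.py - Run magicblast on the single or paired-end read files listed in a TSV file,
                   and optionally extract matching/mismatching reads using seqkit grep

Synopsis: bl_magicblast.py tsvfile outdir threads [--skgrep match|mismatch|both] --database dbtype filename [magicblast arguments]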


@modified: April 4, 2021
@author: Brian Fristensky
@contact: Brian.Fristensky@umanitoba.ca  
'''

# Prefix available for messages written by this program.
PROGRAM = "bl_magicblast.py : "

# One-line usage summary. It mirrors the arguments declared in
# Parameters.read_args(): three positionals, the optional --skgrep, the
# two-valued --database, and any further arguments, which are passed to magicblast.
USAGE = "\n\tUSAGE: bl_magicblast.py tsvfile outdir threads [--skgrep match|mismatch|both] --database dbtype filename [magicblast arguments]"

# When True, parsed parameters and the list of read files are echoed to stdout.
DEBUG = True
if DEBUG :
    print('bl_magicblast: Debugging mode on')


# - - - - - - - - - - - - - Utility functions - - - - - - - - - - - - - - - -
def chmod_ar(filename):
    """
    Add read permission for owner, group and others to filename.

    Existing permission bits are kept. Nothing happens if filename
    does not exist.
    """
    if not os.path.exists(filename):
        return
    read_bits = stat.S_IREAD | stat.S_IRGRP | stat.S_IROTH
    current_mode = os.stat(filename).st_mode
    os.chmod(filename, current_mode | read_bits)

		
def chmod_arx(filename):
    """
    Add read and execute/search permission for owner, group and others
    to a file or directory.

    Existing permission bits are kept. Nothing happens if filename
    does not exist.
    """
    if not os.path.exists(filename):
        return
    read_bits = stat.S_IREAD | stat.S_IRGRP | stat.S_IROTH
    exec_bits = stat.S_IEXEC | stat.S_IXGRP | stat.S_IXOTH
    current_mode = os.stat(filename).st_mode
    os.chmod(filename, current_mode | read_bits | exec_bits)

def LocalHostname():
    """
    Return a name for the local machine.

    Candidates are tried in this order, and the first one that is neither
    empty nor starting with 'localhost' is returned:
        1. the HOSTNAME environment variable
        2. the node name reported by platform.uname()
    If both are rejected, socket.gethostname() is returned when it contains
    a dot; otherwise the name is resolved with socket.gethostbyaddr().
    """
    import socket
    import platform

    def is_usable(candidate) :
        # None, "" and anything beginning with 'localhost' are rejected.
        return bool(candidate) and not candidate.startswith("localhost")

    hostname = os.getenv('HOSTNAME')
    if is_usable(hostname) :
        return hostname

    hostname = platform.uname()[1]
    if is_usable(hostname) :
        return hostname

    # Last resort: ask the socket layer, resolving short names through the resolver.
    hostname = socket.gethostname()
    if '.' not in hostname :
        hostname = socket.gethostbyaddr(hostname)[0]
    return hostname


def SendEmail(From,To,Subject,Text) :
    """
    Send a message through the SMTP server on localhost.

    Adapted from:
    http://stackoverflow.com/questions/882712/sending-html-email-using-python

    From    - sender address
    To      - recipient address, or a list of recipient addresses
    Subject - subject line
    Text    - message body; sent both as the plain-text part and, wrapped
              in a minimal HTML page, as the HTML part

    Delivery is best-effort: SMTP and network errors are reported on stdout
    and not raised. Other exceptions (programming errors, KeyboardInterrupt)
    propagate to the caller.
    """
    import smtplib
    from email.mime.multipart import MIMEMultipart
    from email.mime.text import MIMEText

    Host = 'localhost'

    msg = MIMEMultipart('alternative')
    msg['Subject'] = Subject
    # Without From/To headers many mail servers reject or flag the message.
    msg['From'] = From
    msg['To'] = To if isinstance(To, str) else ", ".join(To)
    Html = """\
        <html>
          <head></head>
          <body>
            <p>
            %s
            </p>
          </body>
        </html>
        """ % (Text,)
    msg.attach(MIMEText(Text, 'plain'))
    msg.attach(MIMEText(Html, 'html'))

    try:
        # The context manager closes the connection even if sendmail() fails.
        with smtplib.SMTP(Host) as server:
            server.sendmail(From, To, msg.as_string())
    except (smtplib.SMTPException, OSError):
        print("Error: unable to send email")
    else:
        print("Successfully sent email")


# - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
class Parameters:
    """
    Wrapper class for command line parameters.

    Attributes:
        TSVFILE    - name of the TSV file listing the read files
        OUTDIR     - directory for output files
        THREADS    - number of threads, kept as a string so that it can be
                     placed directly on a command line
        SKGREP     - list of seqkit grep outputs to produce; any of
                     "match" and "mismatch", or empty for none
        DBTYPE     - first value given to --database
        DBFILE     - second value given to --database
        EMAIL      - not set from the command line; always ""
        MBLASTARGS - arguments not recognized by this program
    """
    def __init__(self):
        """
        Set every attribute to its default, then call read_args() to fill
        in the values from the command line.
        """
        self.TSVFILE = ""
        self.OUTDIR = ""
        self.THREADS = "1"
        self.SKGREP = []
        self.DBTYPE = "blast"
        self.DBFILE = ""
        self.EMAIL = "" # not used
        self.MBLASTARGS = []
        self.read_args()

        if DEBUG :
            print('------------ Parameters from command line ------')
            print('    TSVFILE: ' + self.TSVFILE)
            print('    OUTDIR: ' + self.OUTDIR)
            print('    THREADS: ' + self.THREADS)
            print('    SKGREP: ' + str(self.SKGREP))
            print('    DBTYPE: ' + self.DBTYPE)
            print('    DBFILE: ' + self.DBFILE)
            print('    MBLASTARGS: ' + str(self.MBLASTARGS))
            print()

    def read_args(self):
        """
        Read command line arguments (sys.argv) into this object.

        tsvfile, outdir, threads and --database are required. If any is
        missing, argparse prints a usage message and exits (SystemExit).
        Arguments that are not recognized are not an error: they are
        collected, in order, in MBLASTARGS.
        """
        parser = argparse.ArgumentParser()
        parser.add_argument("tsvfile", help="input file")
        parser.add_argument("outdir", help="directory for all output files")
        parser.add_argument("threads", help="number of threads to use")
        parser.add_argument("--skgrep", default="", help="match|mismatch|both")
        # --database is required: without it there is nothing to search, and
        # the two values could not be unpacked below.
        parser.add_argument("--database", nargs=2, required=True,
                            metavar=("dbtype", "filename"), help="dbtype filename")

        # parse_known_args() is used rather than parse_args() so that
        # unrecognized arguments are returned instead of being rejected.
        args, extra = parser.parse_known_args()

        self.TSVFILE = args.tsvfile
        self.OUTDIR = args.outdir
        self.THREADS = args.threads

        # Any --skgrep value other than the three below selects no output.
        skgrep_outputs = {
            "match" : ["match"],
            "mismatch" : ["mismatch"],
            "both" : ["match","mismatch"],
        }
        self.SKGREP = list(skgrep_outputs.get(args.skgrep, []))

        self.DBTYPE, self.DBFILE = args.database
        self.MBLASTARGS = extra

# - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
class TSVFiles :
    """
    Reads a TSV file listing paired-end and/or single-end read files.

    Attributes:
        READPAIRS - list with one entry per input line; each entry is a
                    list holding one (single-end) or two (paired-end)
                    file names
    """
    def __init__(self):
        """
        Start with an empty READPAIRS list.
        """
        self.READPAIRS = []

    def ReadTSVfile(self,FN) :
        """
        Append the read files named in TSV file FN to READPAIRS.

        Paired-end files are on lines such as

        leftreadfile.fq.gz<TAB>rightreadfile.fq.gz

        Single-end files have each file on a separate line

        reads1.fq.gz
        reads2.fq.gz
        reads3.fq.gz

        Blank lines and lines beginning with '#' are skipped. Double quotes
        are removed. Only the first two TAB-separated fields are used, and
        blank fields are ignored.
        """
        TAB = '\t'
        # The file is read line by line and closed even if parsing fails.
        with open(FN,"r") as F :
            for line in F :
                line = line.strip()
                if not line or line.startswith('#') :
                    continue
                # Get rid of double quotes that enclose fields when some
                # programs write output, and then split by TABs.
                tokens = line.replace('"','').split(TAB)
                fnames = [name for name in (t.strip() for t in tokens[:2]) if name]
                if fnames :
                    self.READPAIRS.append(fnames)
        if DEBUG :
            print(str(self.READPAIRS))


# - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
def _fn_parts(FN) :
    """
    Return [basename, extension] for filename FN.

    A trailing ".gz" is ignored, the directory part is dropped, and the
    extension is returned without its dot ("" if there is none).
    """
    if FN.endswith(".gz") : # remove .gz extension
        FN = FN[:-3]
    Root, Ext = os.path.splitext(FN)
    return [os.path.basename(Root), Ext[1:]]


def _common_prefix(s1,s2) :
    """Return the longest leading string shared by s1 and s2 (may be empty)."""
    return os.path.commonprefix([s1,s2])


def _count_lines(FN,skip_comments=False) :
    """Count the lines in FN, optionally leaving out lines that begin with '#'."""
    with open(FN,"r") as f :
        return sum(1 for line in f if not (skip_comments and line.startswith('#')))


def _count_sequences(FN,IsFastq) :
    """Count the sequences in FN: 4 lines per fastq record, one '>' line per fasta record."""
    if IsFastq :
        return _count_lines(FN) // 4
    with open(FN,"r") as f :
        return sum(1 for line in f if line.startswith('>'))


def _write_hit_names(TFN,NFN) :
    """Write the first field of each non-comment, non-blank line of TFN to NFN, one per line."""
    with open(TFN,"r") as TFile, open(NFN,"w") as NFile :
        for line in TFile :
            if line.startswith('#') :
                continue
            fields = line.split()
            if fields :
                NFile.write(fields[0] + '\n')


def _run_logged(COMMAND,LOGFILE) :
    """
    Run COMMAND with stdout and stderr sent to LOGFILE.

    Returns the exit status, or None if the program could not be started,
    in which case the reason is written to LOGFILE.
    """
    # Flush first so that text already written appears before the program's output.
    LOGFILE.flush()
    try :
        return subprocess.call(COMMAND,stdout=LOGFILE,stderr=LOGFILE)
    except OSError as err :
        LOGFILE.write("Unable to run " + COMMAND[0] + ": " + str(err) + "\n")
        return None


def RunMagicblast(P,PR,LOGFILE) :
    """
    Run magicblast on one single-end read file or one pair of read files
    and, if requested, extract reads with seqkit grep.

    P       - object providing DBTYPE, DBFILE, THREADS, MBLASTARGS and SKGREP
    PR      - list of one read file name, or two for paired-end reads
    LOGFILE - open, writable file; it is also passed to the child processes
              as stdout/stderr, so it must be backed by a real file descriptor

    Output files are named after the read files:
        <base>.magicblast.tsv        magicblast tabular output
        <base>.magicblast.nam        names taken from the magicblast output
        <read>_<db><+|->.<ext>       reads that match (+) / do not match (-)
    For paired files <base> is the common leading part of the two basenames.

    A failure of magicblast or seqkit is written to LOGFILE; it is not raised.
    Returns None.
    """
    # NOTE(review): output file names carry no directory, so they are created
    # in the current working directory rather than in the output directory
    # given on the command line - confirm this is intended.

    # ----- Step 1: Parse read file name(s) into components to be used for output filename.

    PAIRED = len(PR) == 2
    TFN0 = _fn_parts(PR[0])

    if PAIRED :
        TFN1 = _fn_parts(PR[1])
        if TFN0[1] != TFN1[1] : # the two files have different extensions
            print(" ".join([">>>Paired files ",PR[0],"and",PR[1]," may be of different types."]))
            print("Also, make sure that all files have an extension indicating the type of file.")
            print("For a fastq file, examples might include: reads.fq, reads.fq.gz, reads.fastq, reads.fastq.gz")
            print("For a fasta file, examples might include: reads.fsn, reads.fsn.gz etc.")
            print("Fasta file extensions could be almost anything et. txt, txt.gz")
            print(">>>Aborting magicblast.py")
            return
        BaseName = _common_prefix(TFN0[0],TFN1[0])
    else :
        BaseName = TFN0[0]

    Ext = TFN0[1]
    # Anything that is not named as fastq is treated as fasta.
    IsFastq = Ext in ["fastq","fq"]
    MBoutfile = ".".join([BaseName,"magicblast","tsv"])
    DBName = _fn_parts(P.DBFILE)[0]

    # ----- Step 2: Construct the magicblast command as a list of arguments.
    COMMAND=["magicblast"]

    if PAIRED :
        COMMAND.extend(["-query", PR[0], "-query_mate", PR[1]])
    else :
        COMMAND.extend(["-query", PR[0]])
    COMMAND.extend(["-infmt", "fastq" if IsFastq else "fasta"])

    if P.DBTYPE == "blast" :
        COMMAND.extend(["-db", DBName])
    else :
        COMMAND.extend(["-subject", P.DBFILE])

    OUTFMT = "tabular"
    COMMAND.extend(["-outfmt",OUTFMT,"-out",MBoutfile])
    COMMAND.extend(["-num_threads", P.THREADS])

    COMMAND.extend(P.MBLASTARGS)

    print("COMMAND: ",COMMAND)

    # ----- Step 3: Run magicblast.
    LOGFILE.write("======== Magicblast on " + LocalHostname() + " ==========" + "\n")
    LOGFILE.write('COMMAND: ' + str(COMMAND) + '\n')
    StartTime = datetime.datetime.now()
    LOGFILE.write('Start time: ' + str(StartTime) + '\n')
    LOGFILE.write('\n')
    MBretcode = _run_logged(COMMAND,LOGFILE)

    if MBretcode != 0 :
        LOGFILE.write("magicblast failed with return code: " + str(MBretcode) + "\n")
        LOGFILE.write('\n')
    # A failed run may not have created the output file at all.
    if os.path.exists(MBoutfile) :
        # Lines beginning with '#' are skipped, as they are when hit names are extracted.
        NumHits = _count_lines(MBoutfile,skip_comments=True)
        LOGFILE.write(MBoutfile + "\t" + str(NumHits)+ "\tHits" + "\n")
        LOGFILE.write('\n')

    if MBretcode != 0 :
        LOGFILE.write("Aborting bl_magicblast.py \n")
        LOGFILE.write('\n')
    elif P.SKGREP :
        # ----- Step 4: Run seqkit grep to extract reads from the read files
        # which match/mismatch names in the magicblast output.
        # The name file is the same for every output choice, so write it once.
        NameFile = BaseName + ".magicblast.nam"
        _write_hit_names(MBoutfile,NameFile)

        for OutputChoice in P.SKGREP :
            com1 = ["seqkit","grep","-j",P.THREADS,"-f",NameFile]
            if OutputChoice == "mismatch" :
                com1.append("-v")
                MatchFlag = "-"
                LOGFILE.write("\t" + "- - - - - - seqkit grep: Mismatch - - - - - " + '\n')
            else :
                MatchFlag = "+"
                LOGFILE.write("\t" + "- - - - - - seqkit grep: Match - - - - - " + '\n')

            for ReadFile in PR :
                Components = _fn_parts(ReadFile)
                SKoutfile = Components[0] + "_" + DBName + MatchFlag + "." + Components[1]
                # com1 + [...] builds a new list, leaving com1 unchanged for the next file.
                COMMAND = com1 + ["-o",SKoutfile,ReadFile]
                print("COMMAND: " + str(COMMAND))
                SKretcode = _run_logged(COMMAND,LOGFILE)
                if SKretcode != 0 :
                    LOGFILE.write("seqkit grep return code: " + str(SKretcode) + "\n")
                    LOGFILE.write('\n')
                else :
                    # fasta output may be line-wrapped, so records are counted
                    # by header line rather than by assuming 2 lines each.
                    HitsWritten = _count_sequences(SKoutfile,IsFastq)
                    LOGFILE.write("\t" + SKoutfile + "\t" + str(HitsWritten)+ "\tsequences extracted from " + ReadFile + "\n")
                    LOGFILE.write('\n')
                LOGFILE.flush()

    FinishTime = datetime.datetime.now()
    LOGFILE.write('Finish time: ' + str(FinishTime) + '\n')
    ElapsedTime = FinishTime - StartTime
    LOGFILE.write('Elapsed time: ' + str(ElapsedTime) + '\n')
    LOGFILE.write('\n')

#======================== MAIN PROCEDURE ==========================
def main():
    """
    Read the command line, run magicblast on every entry of the TSV file,
    and write a log to bl_magicblast.log in the output directory.
    """
    # Read parameters from command line
    P = Parameters()

    TF = TSVFiles()
    if P.TSVFILE != "" :
        TF.ReadTSVfile(P.TSVFILE)

    # Create output directory, including missing parent directories,
    # if it doesn't already exist.
    if P.OUTDIR != "" :
        os.makedirs(P.OUTDIR, exist_ok=True)

    LOGFN = os.path.join(P.OUTDIR,"bl_magicblast.log")
    # The log is closed even if a run raises an exception.
    with open(LOGFN,'w') as LOGFILE :
        LOGFILE.write('\n')

        # Run magicblast. READPAIRS is empty if no TSV file was read.
        for PR in TF.READPAIRS :
            RunMagicblast(P,PR,LOGFILE)

    # Notify user when job is done, if an email address is set.
    # P.EMAIL is not filled in from the command line at present.
    if P.EMAIL != "" :
        Sender = getpass.getuser() + '@' + LocalHostname()
        Subject = 'bl_magicblast.py completed'
        # Send the log that was just written, which is in the output
        # directory and not necessarily in the current directory.
        with open(LOGFN,'r') as LOGFILE :
            Message = 'bl_magicblast.py: Completed<br>' + "".join(line + '<br>' for line in LOGFILE)
        SendEmail(Sender,[P.EMAIL],Subject,Message)


# Run main() only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()

