#!/usr/bin/env python3

# Write validated Illumina reads to one output file and rejected reads to another.
"""
    fixfq.py - from a fastq file (or each file named in a list, with --filelist), create files containing valid reads and bad reads

@modified: May 23 2021
@author: Brian Fristensky
@contact: brian.fristensky@umanitoba.ca

"""

# An earlier version, fixfq-bigmem.py, read the entire input file into memory before processing.
# That version was only about 3.5% faster than this version, which reads a line at a time
# from the input file as it works. This version uses only a trivial amount of RAM, so it's
# safer for personal computers.


import argparse
import datetime
import os
import platform
import re
import sys

# - - - - - - -  GLOBAL VARIABLES - - - - - - - -
# Prefix for printed messages: this script's file name followed by ": ".
PROGRAM = "%s: " % os.path.basename(sys.argv[0])
USAGE = "\n\t USAGE: fixfq.py [--filelist] infile"
DEBUG = True   # when True, print the parsed arguments and the quality-score regex
NL = "\n"
# Placeholder stored in every Read field until the real line has been read.
BLANKLINE = "".join((" ", NL))

# A sequence line may contain only the upper-case bases A, G, C, T and N.
Pattern = r'[AGCTN]+'
NUCLEOTIDES = re.compile(Pattern)

# - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
class Parameters:
    """
    Wrapper class for command line parameters.

    Attributes:
        IFN      -- the positional 'infile' argument.
        FILELIST -- True if --filelist was given, meaning IFN names a file
                    that lists one fastq filename per line.
    """

    def __init__(self, argv=None):
        """
        Initialize the arguments, then call read_args() to fill in their
        values from the command line.

        argv -- optional list of argument strings to parse instead of
                sys.argv[1:]; None (the default) parses the real command line.
        """
        self.IFN = ""
        self.FILELIST = False  # True if --filelist was set.
        self.read_args(argv)

    def read_args(self, argv=None):
        """
        Read command line arguments into this Parameters object.

        argv -- optional list of argument strings; None means sys.argv[1:].

        A malformed command line makes argparse print a usage message and
        raise SystemExit, so no further error handling is needed here.
        """
        parser = argparse.ArgumentParser()
        parser.add_argument("--filelist", dest="filelist", action="store_true",
                            help="if set, file contains list of filenames")
        parser.add_argument("infile", action="store", help="input file")

        args = parser.parse_args(argv)
        self.FILELIST = args.filelist
        self.IFN = args.infile

        if DEBUG:
            print("FILELIST: " + str(self.FILELIST))
            print("IFN: " + self.IFN)


# --------------------------------------------
def ReadList(F) :
    """
    Return the filenames listed in file F, one per line.

    Surrounding whitespace is stripped from each line, and blank lines are
    skipped. Otherwise a trailing empty line in the list would become an
    empty filename that fails to open later. The list file is closed even if
    reading it raises.
    """
    with open(F, "r") as lfile:
        return [name for name in (line.strip() for line in lfile) if name]

# --------------------------------------------
class Phred :
    """
    Implement Phred scores as dictionaries.

    Illumina18  -- maps each quality character '!'..'~' (ASCII 33-126, i.e.
                   Phred+33) to the character whose code is its score
                   (chr(0)..chr(93)).
    transtable  -- the same mapping as a str.translate() table.
    Illumina18regex -- compiled regex matching one or more of those characters.
    """
    def __init__(self):
        """
          Initializes arguments
        """
        chars = [chr(n) for n in range(33, 127)]
        self.Illumina18 = {ch: chr(ord(ch) - 33) for ch in chars}
        self.transtable = str.maketrans(self.Illumina18)
        # The printable range contains regex metacharacters ('\\', ']', '^',
        # '-', ...). Without escaping, the backslash escaped the following ']'
        # and was silently dropped from the character class, even though '\\'
        # is a legal quality symbol.
        Illumina18Pattern = "[" + re.escape("".join(chars)) + "]+"
        if DEBUG:
            print(Illumina18Pattern)
        self.Illumina18regex = re.compile(Illumina18Pattern)

# --------------------------------------------
class Read :
    """
    Sequencing read. Assumes standard Illumina 4-line read:
    header, sequence, separator, quality.

    Each field holds the raw line as read, including its line terminator.
    A field whose line was never read keeps the BLANKLINE placeholder.
    Valid is True only after GetLines found all four lines and ParseRead
    did not reject them.
    """
    def __init__(self):
        self.Id = BLANKLINE
        self.Seq = BLANKLINE
        self.Sep = BLANKLINE
        self.Qual = BLANKLINE
        self.Valid = False

    def GetLines (self, line, lnum, fh=None) :
        """
        Collect the lines of one presumptive read, starting at header 'line'.

        Reads up to three more lines from fh. It stops early at end of file
        ("") or when a sequence/separator line starts with "@", which is
        taken to be the header of the next read. Valid becomes True only when
        all four lines were found.

        line -- the header line already read by the caller
        lnum -- running line counter; incremented once per line read here
        fh   -- file object to read from; defaults to the module-level 'infile'

        Returns (first line not consumed by this read, updated lnum).
        """
        src = infile if fh is None else fh
        self.Id = line

        line = src.readline()
        lnum += 1
        if line == "" or line[0] == "@":
            return line, lnum
        self.Seq = line

        line = src.readline()
        lnum += 1
        if line == "" or line[0] == "@":
            return line, lnum
        self.Sep = line

        line = src.readline()
        lnum += 1
        # Since "@" is also a legal quality symbol, we can't rule
        # out the fourth line just because it starts with "@".
        if line != "":
            self.Qual = line
            self.Valid = True
            line = src.readline()
            lnum += 1
        return line, lnum

    def ParseRead(self) :
        """
        Validate a presumptive read, clearing Valid if any check fails.

        - the sequence must consist only of A/G/C/T/N (NUCLEOTIDES);
        - the separator line must start with "+";
        - the quality line must be as long as the sequence, and every
          character must be a Phred+33 quality symbol.

        Only the line terminator is ignored in these checks. A final read
        whose quality line has no trailing newline is therefore no longer
        rejected for a length mismatch.
        """
        seq = self.Seq.rstrip("\r\n")
        qual = self.Qual.rstrip("\r\n")
        if not NUCLEOTIDES.fullmatch(seq):
            self.Valid = False
        if not self.Sep.startswith("+"):
            self.Valid = False
        if len(qual) != len(seq):
            self.Valid = False
        # Check EVERY quality character. The keys of phred.Illumina18 are
        # exactly the legal symbols. The old regex .match() only looked at
        # the first character.
        elif not set(qual) <= phred.Illumina18.keys():
            self.Valid = False

    def WriteRead(self,outfile) :
        """Write the four stored lines, unchanged, to outfile."""
        outfile.writelines((self.Id, self.Seq, self.Sep, self.Qual))

#========================    MAIN   =============================

if __name__ == "__main__":
    print("========== " + PROGRAM + " ==========")
    P = Parameters()

    # Read.ParseRead looks this object up as a module-level name.
    phred = Phred()

    # Either the single fastq file named on the command line, or every file
    # listed in it when --filelist was given.
    if P.FILELIST:
        FILES = ReadList(P.IFN)
    else:
        FILES = [P.IFN]

    start_time = datetime.datetime.now()
    print("Start time: " + str(start_time))
    for F in FILES:
        print("---------- " + F + " ----------")
        basename = os.path.splitext(F)[0]
        GOODFN = basename + '_valid.fq'
        BADFN = basename + '_bad.fq'

        # latin-1 decodes every possible byte and encodes it back unchanged.
        # Corrupted (non-ASCII) input therefore cannot raise
        # UnicodeDecodeError, and junk bytes reach the bad-read file intact.
        # 'infile' must remain a module-level name because Read.GetLines
        # reads from it. 'with' closes all three files even on error.
        with open(F, "r", encoding="latin-1") as infile, \
                open(GOODFN, 'w', encoding="latin-1") as goodfile, \
                open(BADFN, 'w', encoding="latin-1") as badfile:
            line = infile.readline()
            lnum = 0  # running line counter, advanced here and by GetLines
            goodreads = 0
            badreads = 0

            print(PROGRAM + "Writing output files...")
            # Get presumptive reads one at a time, and validate them.
            while line != "":
                if line[0] != "@":
                    # A line that cannot start a read goes straight to the bad file.
                    badfile.write(line)
                    line = infile.readline()
                    lnum += 1
                    continue

                r = Read()
                try:
                    line, lnum = r.GetLines(line, lnum)
                except Exception as e:
                    # The read position is now unknown; stop rather than
                    # loop on the same header or mis-split later reads.
                    print(PROGRAM + "GetLines error: " + str(e))
                    raise

                # GetLines marks the read Valid only if all 4 lines were found.
                if r.Valid:
                    try:
                        r.ParseRead()
                    except Exception as e:
                        # Best effort: report it, treat the read as bad, keep going.
                        print(PROGRAM + "ParseRead error: " + str(e))
                        r.Valid = False
                if r.Valid:
                    r.WriteRead(goodfile)
                    goodreads += 1
                else:
                    r.WriteRead(badfile)
                    badreads += 1

        print("Total valid reads written to " + GOODFN + ": " + str(goodreads))
        print("Total bad reads written to " + BADFN + ": " + str(badreads))
        print(" ")

    finish_time = datetime.datetime.now()
    print("Finish time: " + str(finish_time))
    time_elapsed = finish_time - start_time
    # platform.node() also works on Windows, where os.uname() does not exist.
    print("Elapsed time on " + platform.node() + ": "
          + str(time_elapsed.total_seconds()) + " seconds")



