#!/usr/bin/env python3

# Write validated Illumina reads to output files.
"""
    fixfq.py - from a FASTQ file (or each file named in a list given with --filelist), create files containing valid reads and bad reads

@modified: May 19 2021
@author: Brian Fristensky
@contact: brian.fristensky@umanitoba.ca

"""

import argparse
import datetime
import gc
import os
import platform
import re
import sys

# - - - - - - -  GLOBAL VARIABLES - - - - - - - -
# Prefix that precedes every progress/diagnostic message printed by the script.
PROGRAM = "{}: ".format(os.path.basename(sys.argv[0]))
USAGE = "\n\t USAGE: fixfq.py [--filelist] infile"
DEBUG = True  # when True, extra diagnostic messages are printed
NL = "\n"
# Placeholder written in place of any line missing from an incomplete read.
BLANKLINE = " {}".format(NL)

# A sequence line may contain only upper-case A, G, C, T and N.
Pattern = r'[AGCTN]+'
NUCLEOTIDES = re.compile(Pattern)

# - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
class Parameters:
    """
    Wrapper class for command line parameters.

    Attributes:
        IFN      -- the positional ``infile`` argument.
        FILELIST -- True if ``--filelist`` was given, i.e. IFN names a file
                    that lists the input filenames.
    """

    def __init__(self, argv=None):
        """
        Initialize the attributes to their defaults, then call read_args()
        to fill them in from the command line.

        argv -- optional list of argument strings for argparse; None (the
                default) makes argparse read sys.argv[1:].
        """
        self.IFN = ""
        self.FILELIST = False  # True if --filelist was set.
        self.read_args(argv)

    def read_args(self, argv=None):
        """
        Read command line arguments into this Parameters object.

        argv -- optional argument list; None means sys.argv[1:].

        Invalid or missing arguments are handled by argparse itself: it
        prints a usage message and exits by raising SystemExit, so there is
        no ValueError to catch here.
        """
        parser = argparse.ArgumentParser()
        parser.add_argument("--filelist", dest="filelist", action="store_true",
                            help="if set, file contains list of filenames")
        # A required positional ignores `default`, so none is given.
        parser.add_argument("infile", help="input file")

        args = parser.parse_args(argv)
        self.FILELIST = args.filelist  # store_true already yields a bool
        self.IFN = args.infile

        if DEBUG :
            print("FILELIST: " + str(self.FILELIST))
            print("IFN: " + self.IFN)


# --------------------------------------------
def ReadList(F) :
    """
    Return the list of filenames stored in file F, one per line.

    Surrounding whitespace is stripped from each line and blank lines are
    skipped, so e.g. a trailing empty line does not yield an empty filename
    that would fail when later opened. The list file is closed even if
    reading it raises.
    """
    with open(F, "r") as lfile:
        stripped = (line.strip() for line in lfile)
        return [name for name in stripped if name]

# --------------------------------------------
class Phred :
    """
    Phred quality lookup tables for Phred+33 encoded quality strings
    (the attribute name suggests Illumina 1.8+).

    Illumina18 -- dict mapping each quality character chr(33)..chr(126) to
                  the character whose code point is its Phred score, i.e.
                  chr(ord(ch) - 33).
    transtable -- the same mapping as a str.translate() table.
    """
    def __init__(self):
        offset = 33
        self.Illumina18 = {chr(code): chr(code - offset)
                           for code in range(offset, 127)}
        self.transtable = str.maketrans(self.Illumina18)


# Module-level shared instance (not referenced elsewhere in the visible file).
pdict = Phred()




# --------------------------------------------
class Read :
    """
    Sequencing read. Assumes standard Illumina 4-line read:
    header ('@...'), sequence, separator ('+...'), quality.

    Id/Seq/Sep/Qual hold the raw lines (with their line terminators).
    Lines that could not be collected keep the BLANKLINE placeholder so that
    WriteRead still emits four lines. Valid is set by GetLines when all four
    lines were collected and cleared by ParseRead if validation fails.
    """
    def __init__(self):
        self.Id = BLANKLINE
        self.Seq = BLANKLINE
        self.Sep = BLANKLINE
        self.Qual = BLANKLINE
        self.Valid = False

    def GetLines(self, lnum, source=None):
        """
        Collect up to four lines of a presumptive read starting at lnum.

        lnum   -- index of the header line within the line list.
        source -- list of lines to read from; when None, the module-level
                  ``lines`` list filled in by the main program is used.

        Collection stops early at the next line starting with '@', so a
        truncated read does not swallow the following record. The quality
        line is the exception: '@' is a legal quality character, so a line
        starting with '@' is still taken as the quality string when its
        length matches the sequence length.

        Returns the index of the first line not consumed. Sets Valid to True
        only if all four lines were collected.
        """
        src = lines if source is None else source
        total = len(src)
        self.Id = src[lnum]
        lnum += 1
        if lnum < total and not src[lnum].startswith("@"):
            self.Seq = src[lnum]
            lnum += 1
            if lnum < total and not src[lnum].startswith("@"):
                self.Sep = src[lnum]
                lnum += 1
                if lnum < total and self._looks_like_qual(src[lnum]):
                    self.Qual = src[lnum]
                    self.Valid = True
                    lnum += 1
        return lnum

    def _looks_like_qual(self, line):
        """Return True if line can be taken as this read's quality string."""
        if not line.startswith("@"):
            return True
        # '@' may begin a quality string; accept it only if it lines up
        # with the sequence, otherwise it is most likely the next header.
        return len(line.strip()) == len(self.Seq.strip())

    def ParseRead(self) :
        """
        Validate a presumptive read, clearing Valid if any check fails:
        the sequence must fully match NUCLEOTIDES, the separator must start
        with '+', and the quality string must be as long as the sequence.

        Lengths are compared after stripping line terminators, so a final
        record whose quality line has no trailing newline is not rejected.
        """
        seq = self.Seq.strip()
        if not NUCLEOTIDES.fullmatch(seq):
            self.Valid = False
        if not self.Sep.startswith("+"):
            self.Valid = False
        if len(self.Qual.strip()) != len(seq):
            self.Valid = False

    def WriteRead(self,outfile) :
        """
        Write the four stored lines to outfile, adding a newline to any line
        that lacks one so consecutive lines are never merged.
        """
        for line in (self.Id, self.Seq, self.Sep, self.Qual):
            outfile.write(line if line.endswith("\n") else line + "\n")

#========================    MAIN   =============================

print("========== " + PROGRAM + " ==========")
P = Parameters()

# Either process the single named file, or every file listed in it.
if P.FILELIST :
    FILES = ReadList(P.IFN)
else:
    FILES = [P.IFN]

start_time = datetime.datetime.now()
print("Start time: " + str(start_time))
for F in FILES :
    print("---------- " + F + " ----------")
    basename = os.path.splitext(F)[0]
    GOODFN = basename + '_valid.fq'
    BADFN = basename + '_bad.fq'

    # Read the entire input file. Read.GetLines reads the module-level
    # `lines` list, so `lines`/`totallines` are kept at module scope.
    print(PROGRAM + "Reading " + F + "...")
    with open(F, "r") as infile:
        lines = infile.readlines()
    totallines = len(lines)
    print(PROGRAM + "Total lines: " + str(totallines))

    # I have gotten segmentation faults working with files of several Gb.
    # Garbage collection may help, especially on machines with a lot of other
    # processes running.
    gc.collect()
    if DEBUG :
        print(PROGRAM + "Garbage successfully collected.")

    # Get presumptive reads one at a time, and validate them.
    print(PROGRAM + "Writing output files...")
    goodreads = 0
    badreads = 0
    lnum = 0  # index of the next unprocessed line
    with open(GOODFN, 'w') as goodfile, open(BADFN, 'w') as badfile:
        while lnum < totallines :
            # A line that cannot start a read goes straight to the bad file.
            if not lines[lnum].startswith("@") :
                badfile.write(lines[lnum])
                lnum += 1
                continue

            # Fresh Read per record so no field carries over from the last one.
            r = Read()
            try:
                lnum = r.GetLines(lnum)
            except Exception as e:
                # Report, then re-raise: lnum was not advanced, so carrying
                # on would loop forever on the same line.
                print(PROGRAM + "GetLines error: " + str(e))
                raise

            # Only a read with all 4 lines is worth validating further.
            if r.Valid :
                try :
                    r.ParseRead()
                except Exception as e:
                    print(PROGRAM + "ParseRead error: " + str(e))
                    raise
            if r.Valid :
                r.WriteRead(goodfile)
                goodreads += 1
            else :
                r.WriteRead(badfile)
                badreads += 1

    print("Total valid reads written to " + GOODFN + ": " + str(goodreads))
    print("Total bad reads written to "  + BADFN + ": " + str(badreads))
    print(" ")
    # Drop this file's lines before the next readlines() so two large inputs
    # are never held in memory at the same time.
    lines = []

finish_time = datetime.datetime.now()
print("Finish time: " + str(finish_time))
time_elapsed = finish_time - start_time
# platform.node() is the portable form of os.uname()[1] (os.uname does not
# exist on Windows); total_seconds() makes the value match its "seconds" label.
print("Elapsed time on " + platform.node() + ": "
      + str(time_elapsed.total_seconds()) + " seconds")



