#!/usr/bin/env python
"""collectParameters: 

Usage:
  collectParameters [options] file1 [file2 ...]
Combine and optionally transform parameter files from PARADIGM
Options:
        -o file    output to this file instead of stdout
        -p         transform expectations to parameters
        -f         transform parameters to factors
        -q         quiet: suppress log messages written to stderr
"""
import os, os.path, sys, getopt, re

# Module-level settings. main() rebinds verbose, inputType and outputType
# from the command-line options.
verbose = True      # when true, log() writes its messages to stderr
pseudoCount = 0.0   # added to every summed expectation in main()
inputType = outputType = "parameters"

def usage(code = 0):
    """Print the module docstring as usage text and optionally exit.

    Args:
        code: process exit status. Pass None to print the text and return
            to the caller instead of exiting.
    """
    # The parenthesised single-argument form behaves the same on
    # Python 2 and Python 3.
    print(__doc__)
    if code is not None:
        sys.exit(code)

def log(msg, die = False):
    """Write msg to stderr when the module-level verbose flag is set.

    Args:
        msg: text to write; no newline is appended.
        die: when true, exit the process with status 1 afterwards
            (regardless of the verbose flag).
    """
    if verbose:
        sys.stderr.write(msg)
    if not die:
        return
    sys.exit(1)

def matchingFiles(path, substring):
    """Return the paths under path that match a "*" wildcard pattern.

    substring is a "/"-separated pattern. Within each component "*" matches
    any run of characters and every other character matches literally.

    Args:
        path: directory to search. When empty, the search starts at "/" if
            substring is absolute and at "./" otherwise.
        substring: the pattern, e.g. "outputFilesEM/*learn*".

    Returns:
        A list of matching paths, in os.listdir order.

    Raises:
        OSError: if path cannot be listed.
    """
    if len(path) == 0:
        if substring.startswith("/"):
            path = "/"
            substring = substring.lstrip("/")
        else:
            path = "./"
    components = substring.split("/")
    # Escape the literal pieces so regex metacharacters in a name (such as
    # ".") match literally; only "*" acts as a wildcard.
    literalPieces = [re.escape(piece) for piece in components[0].split("*")]
    matcher = re.compile("^%s$" % ".*".join(literalPieces))
    remainder = "/".join(components[1:])
    files = []
    for entry in os.listdir(path):
        if not matcher.match(entry):
            continue
        candidate = os.path.join(path, entry)
        if len(components) == 1:
            files.append(candidate)
        elif os.path.isdir(candidate):
            # Only directories can contain the remaining components;
            # listing a regular file would raise OSError.
            files += matchingFiles(candidate + "/", remainder)
    return files

def readStream(stream, paramData = None):
    """Read one parameter stream and append its values to paramData.

    The first line is stored under the "HEADER" key unless that key is
    already present. Each later line starting with ">" names a factor; the
    lines after it are tab-separated, and the last field of each is parsed
    as a float and appended to a 3x3 table of lists, filled row by row as
    paramData[factor][sourceIndex][targetIndex]. Blank lines are ignored.

    Args:
        stream: an iterable file-like object with a readline() method.
        paramData: a dictionary from an earlier call to accumulate into;
            a new dictionary is created when omitted.

    Returns:
        The dictionary that was updated (paramData itself when one is given).

    Raises:
        ValueError: if a value line appears before any ">" line, a factor
            has more than 9 values, or a value is not a valid float.
    """
    # A fresh dictionary per call: a mutable default would be shared by
    # every call that omits paramData.
    if paramData is None:
        paramData = {}
    header = stream.readline().rstrip("\n\r")
    if "HEADER" not in paramData:
        paramData["HEADER"] = header
    totalDim = 9
    targetDim = 3
    # Floor division keeps this an int on Python 3 as well.
    sourceDim = totalDim // targetDim
    factor = None
    sourceIndex = 0
    targetIndex = 0
    for line in stream:
        if line.startswith(">"):
            factor = line.rstrip("\n\r")
            if factor not in paramData:
                paramData[factor] = {}
                for i in range(sourceDim):
                    paramData[factor][i] = {}
                    for j in range(targetDim):
                        paramData[factor][i][j] = []
            sourceIndex = 0
            targetIndex = 0
            continue
        if not line.strip():
            continue
        if factor is None:
            raise ValueError("value line before any '>' factor line: %r" % line)
        if sourceIndex >= sourceDim:
            raise ValueError("more than %d values for factor %r" % (totalDim, factor))
        fields = line.rstrip("\r\n").split("\t")
        paramData[factor][sourceIndex][targetIndex].append(float(fields[-1]))
        targetIndex += 1
        if targetIndex == targetDim:
            sourceIndex += 1
            targetIndex = 0
    return paramData

def _combineInputs(inputData, inputType, pseudoCount):
    """Collapse each list of values read from the inputs into one number.

    For "expectations" the values are summed and pseudoCount is added; for
    any other input type each cell must hold exactly one value, which is
    used as is. The "HEADER" entry is copied unchanged.

    Raises:
        ValueError: if a "parameters" cell does not hold exactly one value.
    """
    sumValues = (inputType == "expectations")
    combineData = {}
    for factor, table in inputData.items():
        if factor == "HEADER":
            combineData["HEADER"] = table
            continue
        combineData[factor] = {}
        for i, row in table.items():
            combineData[factor][i] = {}
            for j, values in row.items():
                if sumValues:
                    combineData[factor][i][j] = sum(values) + pseudoCount
                elif len(values) != 1:
                    # An explicit check, unlike assert, survives python -O.
                    raise ValueError("expected exactly one value for %s [%s][%s], found %d"
                                     % (factor, i, j, len(values)))
                else:
                    combineData[factor][i][j] = values[0]
    return combineData


def _normalizeRows(combineData):
    """Divide every cell by the sum of its row (one row per source index).

    The "HEADER" entry is copied unchanged.

    Raises:
        ZeroDivisionError: if a row sums to zero.
    """
    outputData = {}
    for factor, table in combineData.items():
        if factor == "HEADER":
            outputData["HEADER"] = table
            continue
        outputData[factor] = {}
        for i, row in table.items():
            # The row total does not depend on j, so compute it once.
            rowTotal = sum(row.values())
            outputData[factor][i] = {}
            for j, value in row.items():
                outputData[factor][i][j] = value / rowTotal
    return outputData


def _writeParameters(out, outputData):
    """Write the header line, then each factor name and its 3x3 table, to out."""
    out.write("%s\n" % (outputData["HEADER"]))
    for factor in outputData:
        if factor.startswith(">"):
            out.write("%s\n" % (factor))
            for sourceIndex in range(3):
                for targetIndex in range(3):
                    out.write("%s\t%s\t%4f\n" % (targetIndex, sourceIndex, outputData[factor][sourceIndex][targetIndex]))


def main(args):
    """Combine parameter files and write the result.

    Options: -o FILE writes to FILE instead of stdout, -p converts summed
    expectations to parameters by normalizing each row, -f sets the output
    type to "factors", and -q turns off log output.

    Args:
        args: command-line arguments, without the program name.

    Raises:
        ValueError: if, without -p, a table cell holds more than one value.
        ZeroDivisionError: if, with -p, a row of expectations sums to zero.
    """
    global inputType, outputType, verbose

    ## parse arguments
    try:
        opts, args = getopt.getopt(args, "o:pfq")
    except getopt.GetoptError as err:
        print(str(err))
        usage(2)

    inputs = args

    if len(inputs) < 1:
        print("incorrect number of arguments")
        usage(1)

    outFile = None
    toParameters = False
    toFactors = False

    for o, a in opts:
        if o == "-o":
            outFile = a
        elif o == "-p":
            toParameters = True
        elif o == "-f":
            toFactors = True
        elif o == "-q":
            verbose = False

    if toParameters:
        inputType = "expectations"
        outputType = "parameters"
    else:
        inputType = "parameters"
        # NOTE(review): -f only changes outputType; no conversion to
        # factors is performed in this function.
        if toFactors:
            outputType = "factors"
        else:
            outputType = "parameters"

    ## read inputs
    inputData = {}
    for inputName in inputs:
        if inputName == "/dev/stdin":
            inputData = readStream(sys.stdin, paramData = inputData)
        else:
            # NOTE(review): the named input is not opened; every non-stdin
            # argument reads the files matching this fixed pattern instead.
            # Confirm whether that is intended before changing it.
            for fileName in matchingFiles("", "outputFilesEM/*learn*"):
                with open(fileName, "r") as stream:
                    inputData = readStream(stream, paramData = inputData)

    if "HEADER" not in inputData:
        # Nothing was read; stop before the output file is created or
        # truncated rather than failing later with a KeyError.
        sys.stderr.write("no parameter data found in inputs\n")
        sys.exit(1)

    ## compute combineData
    combineData = _combineInputs(inputData, inputType, pseudoCount)

    ## convert types
    if toParameters:
        outputData = _normalizeRows(combineData)
    else:
        outputData = combineData

    ## output
    if outFile is not None:
        with open(outFile, "w") as f:
            _writeParameters(f, outputData)
    else:
        _writeParameters(sys.stdout, outputData)

# Run main() with the command-line arguments when executed as a script.
if __name__ == "__main__":
    cliArgs = sys.argv[1:]
    main(cliArgs)
