#!/usr/bin/python

import codecs, sys, lxml, os, re
import lxml.etree as etree

## sys.path.append( os.path.split( sys.argv[ 0])[ 0])
## from starparser import writeOutput

###
# from ch_show import show
###

##############################
## EDIT DEFAULTS BELOW HERE ##
##############################

## LEGEND:   '.' suppress output altogether
##           '-' output to stdout or input from stdin
##          '--' output to stderr

## Each option has a default (shown by usage()) and a current value that
## main() may override from the command line.
defaultInputStream = '-'        ## read the document from stdin
defaultValidateXml = True       ## DTD-validate unless told otherwise
defaultErrorsStream = '--'      ## error reports go to stderr
defaultPrettyStream = '.'       ## no pretty-printed output
inputStream = defaultInputStream
validateXml = defaultValidateXml
errorsStream = defaultErrorsStream
prettyStream = defaultPrettyStream
errorsStreamFO = sys.stderr     ## used by usage() until main() opens -errors
## xmlAsParsedFileName = None

##############################
## EDIT DEFAULTS ABOVE HERE ##
##############################

defaultEncoding = 'UTF-8'
## inputEncoding may be overridden by -inputEncoding; outputEncoding has no option.
inputEncoding = outputEncoding = defaultEncoding


#######################################################
attOrderList = [
    'expression',
    'tokenDisp',
    'containedIn',
    'binary',
    'primitiveSet',
    'symbol',
    'symFirst',
    'symLast',
    'layerNumber',
    'roleNumber',
    'operator',
    'opFirst',
    'opLast',
    'layerMark',
    'parameterIdentifier',
    'piFirst',
    'piLast',
    'prologue',
    'epilogue',
    'first',
    'last',
    'id',
    'implicit',
    'parser',
]
## Order in which attributes are emitted by the pretty-printer:
## attOrder maps each name in attOrderList to its 1-based position.
attOrder = { attName: position for position, attName in enumerate( attOrderList, 1)}

## Rank given to attribute names missing from attOrderList, so they sort after
## every listed attribute instead of making the sort raise KeyError.
_unlistedAttrRank = len( attOrderList) + 1

def iemlAttrSort( a):
    """
    Sort key for attribute names.

    Returns the 1-based position of *a* in attOrderList, or
    len(attOrderList) + 1 for a name that is not listed.  Unlisted names keep
    their relative order because Python's sort is stable.
    """
    return attOrder.get( a, _unlistedAttrRank)


#######################################################
def usage( emsg, exitval):
    """
    Write the command-line help text to the errors stream and exit.

    emsg:    error message shown before and after the help text; '' for plain
             help output (-h).
    exitval: process exit status passed to sys.exit().

    Never returns.
    """
    if len( emsg) > 0:
        writeOutput( errorsStreamFO, emsg, '-errors')

    writeOutput( errorsStreamFO, """
Usage: %s [options]

       This program validates and/or pretty-prints
       Star-XML documents.

Input options:
--------------

     -i  <input file containing a Star-XML document>
           Use '-' for stdin.  (Default: '%s')
           A named file is decoded with the encoding given by
           -inputEncoding.  If stdin is used, its contents are
           instead decoded with the 'unicode_escape' codec.

  -inputEncoding <encoding name>  Default: '%s'
           Encoding of the -i input file (not used for stdin).

Processing options:
-------------------

  -validateXml   <'True' or 'False'>  Default: '%s'.  If 'True', check
                   whether the input document, and the pretty-printed
                   output when -pretty is used, conform to the DTD.
                   The DTD must be provided in the input document's
                   internal subset.  If the document is valid, the exit
                   status will be 0.  If invalid, the exit status will
                   be nonzero, and the XML parser's error report(s)
                   will be output on the -errors stream (see below).
                   If 'False', only well-formedness is checked; a
                   document that is not well-formed still yields a
                   nonzero exit status.

Output streams:
---------------
  Note: Special <filename>s are available for outputs:
     Use  '-' for stdout (standard output)
     Use '--' for stderr (standard error output)
     Use  '.' to suppress default output.

  -errors  <filename> Default: '%s'
             For error reports.

  -pretty  <filename> Default: '%s'
             For XML output that has been formatted for human
             inspection.  If '.', no pretty-print processing will be
             done (saves time!).

Other options:
--------------

         -h  Show this help information and exit.

""" % (
    os.path.basename( sys.argv[ 0]),
    defaultInputStream,
    defaultEncoding,
    defaultValidateXml,
    defaultErrorsStream,
    defaultPrettyStream,
),
        '-errors',
    )
    ## Presumably repeated so the message is still visible after the long help text.
    if len( emsg) > 0:
        writeOutput( errorsStreamFO, emsg + '\n', '-errors')
    sys.exit( exitval)
    
#######################################################
def trueOrFalse( arg):
    """
    Convert a command-line 'True'/'False' value (case-insensitive) to a bool.

    Any other value is reported through usage(), which exits with status 1.
    """
    normalized = arg.lower()
    if normalized in ( 'true', 'false'):
        return normalized == 'true'
    usage( 'Expected \'True\' or \'False\', but found "%s" instead.' % ( arg), 1)
        
#######################################################
def openStream( fileName, rOrW, encodingName=defaultEncoding):
    """
    Return a File Object ("FO") open for reading or writing.

    fileName:     a path, or one of the special names
                    '.'  -> None (stream suppressed),
                    '-'  -> sys.stdin for a read mode, sys.stdout for a write mode,
                    '--' -> sys.stderr.
    rOrW:         mode string; must start with 'r' or 'w' when fileName is '-'.
    encodingName: encoding used by codecs.open() for real files (ignored for
                  the special names).

    Exits with status 1 if fileName is '-' and rOrW is neither a read nor a
    write mode.
    """
    if fileName == '.':
        return None
    if fileName == '-':
        if rOrW.startswith( 'w'):
            return sys.stdout
        if rOrW.startswith( 'r'):
            return sys.stdin
        ## Report on stderr directly: the errors stream may not be open yet.
        sys.stderr.write( 'internal error: unsupported mode %r for "-"\n' % ( rOrW,))
        sys.exit( 1)
    if fileName == '--':
        return sys.stderr

    return codecs.open( os.path.abspath( fileName), rOrW, encodingName)

#######################################################
def writeOutput( FO, buf, streamOptionName,):
    """
    Write buf to FO and flush it.

    FO:               None (output suppressed -- nothing is written),
                      sys.stdout / sys.stderr, or a file object returned by
                      openStream().
    buf:              the text to write; newlines are written unchanged.
    streamOptionName: the command-line option that selects FO (e.g. '-errors',
                      '-pretty'); kept for interface compatibility.

    The encoding of stdout/stderr may not cover every character in buf.  If
    writing raises UnicodeEncodeError, the text is written again as pure ASCII
    with the unencodable characters replaced by XML numeric character
    references (&#NNN;).
    """
    if FO is None:
        return  ## do nothing in this case

    if FO is sys.stdout or FO is sys.stderr:
        try:
            FO.write( buf)
        except UnicodeEncodeError:
            FO.write( buf.encode( 'us-ascii', 'xmlcharrefreplace').decode( 'us-ascii'))
    else:
        FO.write( buf)
    FO.flush()

## #######################################################
## def writeOutput( FO, buf):

##     if FO == sys.stdout or FO == sys.stderr:
##         FO.write( buf.encode( 'unicode_escape'))
##     elif FO == None:
##         return  ## do nothing in this case
##     else:
##         FO.write( buf)
##     FO.flush()

#######################################################
def readInput( FO):
    """
    Return the whole contents of FO, or None when FO is None (input '.').

    Data read from sys.stdin is decoded with the 'unicode_escape' codec;
    any other file object (as returned by openStream) is read as-is.

    NOTE(review): the usage text says stdin is assumed to be
    'raw_unicode_escape', but 'unicode_escape' also expands backslash escape
    sequences found in the input -- confirm which is intended.  Under
    Python 3, sys.stdin.read() returns str, which has no .decode().
    """
    if FO is None:
        return None  ## input suppressed
    contents = FO.read()
    if FO == sys.stdin:
        contents = contents.decode( 'unicode_escape')
    return contents

#######################################################
## Matches a DOCTYPE declaration that has an internal subset ("[ ... ]").  The
## non-greedy body stops at the first "]>" rather than the last one in the
## whole document (e.g. the end of a CDATA section).
_internalSubsetRE = re.compile( r'<!DOCTYPE[^\[>]*\[.*?\]\s*>', re.DOTALL)

def _optionValue( argIndex, optionName):
    """Return sys.argv[argIndex] (optionName's value); exit via usage() if it is missing."""
    if argIndex >= len( sys.argv):
        usage( 'missing value for option "%s"' % ( optionName), 1)
    return sys.argv[ argIndex]

def _reportErrorAndExit( errorReport):
    """Write errorReport to the -errors stream and exit with status 1."""
    writeOutput( errorsStreamFO, errorReport, '-errors')
    sys.exit( 1)  ## this is important: nonzero status signals failure

def _parseOrExit( xmlText, validate):
    """
    Parse xmlText with lxml and return the root element.

    The text is DTD-validated when validate is true.  On a syntax or
    validation error, or on input that lxml rejects with ValueError (e.g. a
    unicode string carrying an XML encoding declaration), the error is reported
    on the -errors stream and the process exits with status 1.
    """
    parser = etree.XMLParser( dtd_validation=validate)
    try:
        return etree.fromstring( xmlText, parser)
    except ( etree.XMLSyntaxError, ValueError) as e:
        _reportErrorAndExit( '%s' % ( e))

def main():
    """
    Command-line entry point.

    Steps:
      1. Parse sys.argv into the option globals.
      2. Open the -i, -errors and -pretty streams.
      3. Parse the input, DTD-validating it if -validateXml is True.
      4. If -pretty is not '.', pretty-print the tree.  When validating,
         re-validate the pretty text against the input's internal DTD subset,
         then write it.

    Always ends with sys.exit(): 0 on success, 1 on any reported error.
    """
    global inputStream, validateXml, errorsStream, prettyStream
    global inputStreamFO, errorsStreamFO, prettyStreamFO
    global inputEncoding, outputEncoding

    argCtr = 1
    while argCtr < len( sys.argv):
        arg = sys.argv[ argCtr]
        option = arg.lower()
        if option.startswith( '-h'):
            usage( '', 0)
        elif option == '-i':
            argCtr += 1
            inputStream = _optionValue( argCtr, arg)
        elif option.startswith( '-val'):
            argCtr += 1
            validateXml = trueOrFalse( _optionValue( argCtr, arg))
        elif option.startswith( '-err'):
            argCtr += 1
            errorsStream = _optionValue( argCtr, arg)
        elif option.startswith( '-pr'):
            argCtr += 1
            prettyStream = _optionValue( argCtr, arg)
        elif option.startswith( '-inputencoding'):
            argCtr += 1
            inputEncoding = _optionValue( argCtr, arg)
        else:
            usage( 'unrecognized argument: "%s"' % ( arg), 1)
        argCtr += 1

    inputStreamFO = openStream( inputStream, 'r', inputEncoding)
    ## -errors is written to, so it must be opened for writing: a read mode
    ## would map '-' to stdin and fail for a file that does not exist yet.
    errorsStreamFO = openStream( errorsStream, 'w', outputEncoding)
    prettyStreamFO = openStream( prettyStream, 'w', outputEncoding)

    xmlStringToParse = readInput( inputStreamFO)
    root = _parseOrExit( xmlStringToParse, validateXml)

    if prettyStreamFO is not None:
        s = prettyXml( root)

        if validateXml:
            ## Re-validate the pretty-printed text against the same internal
            ## DTD subset to make sure the formatting did not break validity.
            MO = _internalSubsetRE.search( xmlStringToParse)
            if MO is None:
                _reportErrorAndExit(
                    'Cannot validate the pretty-printed output: no DOCTYPE with an '
                    'internal DTD subset was found in the input document.\n'
                )
            _parseOrExit( '<?xml version="1.0"?>\n%s\n%s' % ( MO.group( 0), s), True)

        writeOutput( prettyStreamFO, s, '-pretty')

    sys.exit( 0)  ## this is important


#######################################################
## Matches a string (e.g. one line) consisting solely of an XML 1.0 declaration
## with an encoding="..." pseudo-attribute.  Group 1 is the text before the
## encoding part; group 3 is the text after it.
xmlEncodingRE = re.compile( r'(^<\?xml version="1\.0")( *encoding=".*")( *\?>$)')
def removeXmlEncoding( MO):
    """re.sub() callback: rebuild the XML declaration without its encoding="..." part."""
    before, after = MO.group( 1, 3)
    return before + after

#######################################################
def prettyXml( child):
    """
    Return a pretty-printed serialization of element *child* and its subtree.

    _prettyXml() appends to the module-level string prettyXMLMsg and may
    leave a pending close marker ('>' or '/>') in the module-level tagClose.
    Any such leftover marker is appended here, followed by a newline.
    """
    global prettyXMLMsg, tagClose

    prettyXMLMsg, tagClose = '', ''
    _prettyXml( child, 0)
    if tagClose:
        prettyXMLMsg += tagClose + '\n'
    return prettyXMLMsg

#######################################################
def _escapeAttrValue( value):
    """Escape &, <, > and " so *value* can sit inside a double-quoted attribute."""
    return value.replace( '&', '&amp;').replace( '<', '&lt;').replace( '>', '&gt;').replace( '"', '&#34;')

def _tokenDispString( firstIndent, fullIndent, attrName, value):
    """
    Format the tokenDisp attribute, one output line per non-empty line of value.

    The first line is prefixed with firstIndent.  Continuation lines are
    indented by fullIndent plus len(attrName) + 2 spaces (the width of
    'attrName="').  The closing quote follows the last line.  A value with no
    non-empty lines yields attrName="" instead of malformed output.
    """
    valueLines = [ line for line in value.split( '\n') if len( line) > 0]
    if not valueLines:
        return '%s%s=""\n' % ( firstIndent, attrName)
    continuationIndent = fullIndent + ' ' * ( len( attrName) + 2)
    outLines = [ '%s%s="%s' % ( firstIndent, attrName, _escapeAttrValue( valueLines[ 0]))]
    for valueLine in valueLines[ 1:]:
        outLines.append( '%s%s' % ( continuationIndent, _escapeAttrValue( valueLine)))
    return '%s"\n' % ( '\n'.join( outLines))

def attrString( indentLength, attrDict):
    """
    Return the attributes in attrDict formatted one per line, for a start tag.

    indentLength: number of spaces before each attribute after the first.  The
                  first attribute is preceded by a single space, since it
                  follows the tag name.
    attrDict:     mapping of attribute name -> value (e.g. an lxml .attrib).

    Attributes are ordered by iemlAttrSort and their values are escaped.
    Every line, including the last, ends with '\\n'.  The tokenDisp attribute
    (specific to starparser.py) is split over several lines; see
    _tokenDispString.
    """
    parts = []
    indentStr = ' '
    fullIndent = ' ' * indentLength
    for attrName in sorted( attrDict.keys(), key = iemlAttrSort):
        value = attrDict[ attrName]
        if attrName == 'tokenDisp':  ## this -if- is specific to starparser.py
            parts.append( _tokenDispString( indentStr, fullIndent, attrName, value))
        else:
            parts.append( '%s%s="%s"\n' % ( indentStr, attrName, _escapeAttrValue( value)))
        indentStr = fullIndent
    return ''.join( parts)



## Number of spaces per nesting level in the pretty-printed output.
INDENT = 4
#######################################################
def _prettyXml(child, indent = 0):
    """
    Recursively append a pretty-printed serialization of element *child*
    (start tag, attributes, text, descendants, end tag and tail) to the
    module-level string prettyXMLMsg.

    child:  an lxml element.
    indent: nesting depth; each level is INDENT spaces.

    The closing character(s) of the most recently written tag ('>' or '/>')
    are not written immediately.  They are held in the module-level tagClose
    and emitted, after a newline and indentation, just before whatever comes
    next: a child's start tag, text, an end tag, or a tail.  A marker still
    pending when the recursion finishes is written by prettyXml().

    Returns nothing; the result accumulates in prettyXMLMsg.

    NOTE(review): comments and processing instructions returned by
    getchildren() are formatted here as if they were elements.  lxml gives
    them a non-string .tag, so the output would be malformed -- confirm such
    nodes cannot occur in the input.
    """
    global prettyXMLMsg, tagClose

    pcdata = ''
    if child.text:
        pcdata = child.text

    tail = ''
    if child.tail:
        tail = child.tail

    # starttag stuff here

    # Flush a pending close marker onto a new line, indented to this element's
    # level, immediately before this element's '<'.
    if tagClose == '':
        tagCloseStr = ''
    else:
        if not prettyXMLMsg.endswith( '\n'):
            prettyXMLMsg = '%s\n' % ( prettyXMLMsg)
        tagCloseStr = '%s%s' % (
            ' '*indent*INDENT,
            tagClose
        )
        tagClose = ''

    prettyXMLMsg = '%s%s<%s' % (
        prettyXMLMsg,
        tagCloseStr,
        child.tag,
    )
    # Width of the current last line plus one.  attrString() uses it to indent
    # the second and later attributes so they line up under the first one,
    # which is written after a single space.
    indx = prettyXMLMsg.rfind( '\n')
    if indx == -1:
        lastLineLength = len( prettyXMLMsg) + 1
    else:
        lastLineLength = ( len( prettyXMLMsg) - indx)

    prettyXMLMsg = '%s%s' % (
        prettyXMLMsg,
        attrString( lastLineLength, child.attrib),
    )
    # No children and no text: self-closing tag, no end tag needed.
    if len( child.getchildren()) == 0 and pcdata == '':
        tagClose = '/>'
        endTagNeeded = False
    else:
        tagClose = '>'
        endTagNeeded = True

    if len( pcdata) > 0:
        # tagClose is always '>' here (pcdata is non-empty, see above), so
        # the else branch below appears unreachable.
        if len( tagClose) > 0:
            if not prettyXMLMsg.endswith( '\n'):
                prettyXMLMsg = '%s\n' % ( prettyXMLMsg)
            prettyXMLMsg = '%s%s%s%s' % (
                prettyXMLMsg,
                ' ' * ( indent + 1) * INDENT,
                tagClose,
                pcdata.replace( '&', '&amp;').replace( '<', '&lt;').replace( '>', '&gt;'),
            )
            tagClose = ''
        else:
            prettyXMLMsg = '%s%s' % (
                prettyXMLMsg,
                pcdata.replace( '&', '&amp;').replace( '<', '&lt;').replace( '>', '&gt;'),
            )

    # internal elements here
    for thisChild in child.getchildren():
        _prettyXml( thisChild, indent + 1)

    # endtag stuff here
    if endTagNeeded:
        # Flush the marker left by the last child (if any) before '</tag'.
        if len( tagClose) > 0:
            if not prettyXMLMsg.endswith( '\n'):
                prettyXMLMsg = '%s\n' % ( prettyXMLMsg)
            prettyXMLMsg = '%s%s%s' % (
                prettyXMLMsg,
                ' ' * ( indent) * INDENT,
                tagClose,
            )
            tagClose = ''
        prettyXMLMsg = '%s</%s' % ( 
            prettyXMLMsg,
            child.tag,
        )
        # The end tag's '>' is deferred like any other close marker.
        tagClose = '>'

    # Tail text follows this element.  The pending marker is flushed first;
    # afterwards tagClose is '', so a following sibling's start tag is appended
    # directly after the tail text.
    if len( tail) > 0:
        if len( tagClose) > 0:
            if not prettyXMLMsg.endswith( '\n'):
                prettyXMLMsg = '%s\n' % ( prettyXMLMsg)
            prettyXMLMsg = '%s%s%s' % (
                prettyXMLMsg,
                ' ' * ( indent + 1) * INDENT,
                tagClose,
            )
            tagClose = ''
        prettyXMLMsg = '%s%s' % (
            prettyXMLMsg,
            tail.replace( '&', '&amp;').replace( '<', '&lt;').replace( '>', '&gt;')
        )



#######################################################
if __name__ == '__main__':
    ## Script entry point; main() always terminates via sys.exit().
    main()

