#!/usr/bin/env python3

# SPDX-FileCopyrightText: 2012-2022 Tobias Leupold <tl at stonemx dot de>
#
# SPDX-License-Identifier: GPL-3.0-or-later

"""
    gpx2svg - Convert GPX formatted geodata to Scalable Vector Graphics (SVG)
"""

# NOTE(review): '@VERSION@' looks like a placeholder that is replaced with the real version
# number when the program is packaged -- confirm against the build scripts
__version__ = '@VERSION@'

import argparse
import sys
import math
from xml.dom.minidom import parse as parseXml
from os.path import abspath
from enum import Enum

class Projection(Enum):
    """The plane projections calcProjection() can apply to a GPS dataset"""

    Mercator = 0
    WGS84 = 1

def parseGpx(gpxFile):
    """Get the latitude and longitude data of all track segments in a GPX file

    gpxFile is whatever xml.dom.minidom.parse() accepts (a file name or an open file object). The
    special name '/dev/stdin' makes the function read from sys.stdin.

    Returns a list of track segments, each one being a list of (longitude, latitude) tuples of
    floats. Segments without any track point are left out.

    If the file can't be read, the XML can't be parsed or a track point lacks a numeric "lat" or
    "lon" attribute, an error message is printed to STDERR and the program exits with status 1.
    """

    if gpxFile == '/dev/stdin':
        gpxFile = sys.stdin

    # Get the XML information
    try:
        gpx = parseXml(gpxFile)
    except IOError as error:
        print('Error while reading file: {}. Terminating.'.format(error), file = sys.stderr)
        sys.exit(1)
    except Exception:
        # Deliberately not a bare "except:", so that KeyboardInterrupt and SystemExit still
        # propagate instead of being reported as an XML error
        print('Error while parsing XML data:', file = sys.stderr)
        print(sys.exc_info(), file = sys.stderr)
        print('Terminating.', file = sys.stderr)
        sys.exit(1)

    # Iterate over all tracks, track segments and points
    gpsData = []
    try:
        for track in gpx.getElementsByTagName('trk'):
            for trackseg in track.getElementsByTagName('trkseg'):
                trackSegData = [
                    (float(point.attributes['lon'].value), float(point.attributes['lat'].value))
                    for point in trackseg.getElementsByTagName('trkpt')
                ]
                # Leave out empty segments
                if trackSegData:
                    gpsData.append(trackSegData)
    except (KeyError, ValueError) as error:
        # KeyError: the "lat" or "lon" attribute is missing; ValueError: it's not a number
        print('Error while reading track point coordinates: {}. Terminating.'.format(error),
              file = sys.stderr)
        sys.exit(1)

    return gpsData

def calcProjection(gpsData, projection):
    """Project every coordinate of a GPS dataset onto a plane

    gpsData is a list of segments of coordinate pairs, projection a member of Projection. The
    result has the same segment structure as gpsData. An unknown projection yields an empty list.
    """

    if projection == Projection.Mercator:
        return [[mercatorProjection(point) for point in track] for track in gpsData]

    if projection == Projection.WGS84:
        # This projection works relative to the smallest x and y value of the whole dataset,
        # the upper bounds are not needed
        lowX, _highX, lowY, _highY = extentOfProjectedData(gpsData)
        return [[wgs84Projection(point, lowY, lowX) for point in track] for track in gpsData]

    # Nothing we know how to calculate
    return []

def mercatorProjection(coord):
    """Return the Mercator projection (x, y) of one coordinate pair

    coord[0] is used as the longitude and coord[1] as the latitude, both in degrees. The central
    meridian is fixed to 0, so no meridian offset shows up in the x calculation.
    """

    # Radius of the earth (according to GRS 80)
    earthRadius = 6378137.0

    longitude = coord[0]
    latitude = coord[1]

    easting = earthRadius * longitude * math.pi / 180.0

    halfLatitude = (latitude * math.pi / 180.0) / 2.0
    northing = earthRadius * math.log(math.tan((math.pi / 4.0) + halfLatitude))

    return easting, northing

def wgs84Projection(coord, minY, minX):
    """Return the WGS84 projection (x, y) of one coordinate pair

    coord[0] and coord[1] are measured relative to minX and minY. The x distance is calculated
    along the circle of latitude at minY, the y distance along the polar circumference.
    """

    # Equatorial circumference according to WGS84: 6378137.0 * 2 * math.pi
    equatorLength = 40075016.68557849
    # Polar circumference according to WGS84: 6356752.314245 * 2 * math.pi
    meridianLength = 39940652.74224401

    # Circumference of the circle of latitude at minY
    parallelLength = equatorLength * math.cos(minY * math.pi / 180.0)

    x = (coord[0] - minX) * parallelLength / 360
    y = (coord[1] - minY) * meridianLength / 360

    return x, y

def extentOfProjectedData(gpsData):
    """Return the bounding box of a dataset as minX, maxX, minY, maxY

    The dataset needs at least one coordinate in its first segment, otherwise the lookup of the
    starting value raises an IndexError.
    """

    # The very first coordinate is the starting value for all four bounds
    start = gpsData[0][0]

    allX = [start[0]] + [coord[0] for segment in gpsData for coord in segment]
    allY = [start[1]] + [coord[1] for segment in gpsData for coord in segment]

    return min(allX), max(allX), min(allY), max(allY)

def moveProjectedData(gpsData):
    """Shift a dataset so that its bounding box starts at 0,0

    Returns the shifted dataset together with the width and the height of its bounding box.
    """

    left, right, bottom, top = extentOfProjectedData(gpsData)

    # Subtracting the lower bounds moves the smallest x and y value to 0
    shiftedData = [
        [(coord[0] - left, coord[1] - bottom) for coord in segment]
        for segment in gpsData
    ]

    return shiftedData, right - left, top - bottom

def searchCircularSegments(gpsData):
    """Sort the segments of a GPS dataset into circular ones and all others

    A segment counts as circular if its first and its last point are equal. Returns the list of
    circular segments followed by the list of the remaining ones.
    """

    closedTracks = []
    openTracks = []

    for track in gpsData:
        # Pick the list this track belongs to
        destination = closedTracks if track[0] == track[-1] else openTracks
        destination.append(track)

    return closedTracks, openTracks

def combineSegmentPairs(gpsData):
    """Combine pairs of segments where one ends with the starting point of the other

    Each segment is combined with at most one other segment per call. The passed list is used up
    in the process (it is empty afterwards). The result is what searchCircularSegments() returns
    for the combined data: the circular segments and the remaining ones.
    """

    mergedData = []

    while gpsData:
        # Take one segment and look for another one starting where this one ends
        head = gpsData.pop()
        partnerIndex = next(
            (index for index, candidate in enumerate(gpsData) if head[-1] == candidate[0]),
            None
        )

        if partnerIndex is None:
            # Nothing to attach, the segment is passed on unchanged
            mergedData.append(head)
            continue

        # Join both segments. The shared point would show up twice, so it's dropped from the
        # end of the first one.
        tail = gpsData.pop(partnerIndex)
        head.pop()
        mergedData.append(head + tail)

    return searchCircularSegments(mergedData)

def combineSegments(gpsData):
    """Combine all segments of a GPS dataset that can be combined

    Returns the circular segments followed by the segments that could not be closed.
    """

    # Circular segments can't be combined with anything else, so they are set aside right away
    closedSegments, openSegments = searchCircularSegments(gpsData)

    # Keep combining pairs until a pass does not change the number of open segments anymore
    countBefore = None
    while countBefore != len(openSegments):
        countBefore = len(openSegments)
        newlyClosed, openSegments = combineSegmentPairs(openSegments)

        # Segments that became circular by combining are finished as well
        closedSegments = closedSegments + newlyClosed

    return closedSegments + openSegments

def chronologyJoinSegments(gpsData):
    """Return a dataset with one single segment containing all points in their GPX file order."""
    return [[coord for segment in gpsData for coord in segment]]

def scaleCoords(coord, height, scale):
    """Scale one coordinate pair and mirror its y value within the given height"""
    # Negating y and adding the height turns the largest y value into the smallest one
    mirroredY = coord[1] * -1 + height
    scaledX = coord[0] * scale
    scaledY = mirroredY * scale
    return scaledX, scaledY

def generateScaledSegment(segment, height, scale):
    """Yield the scaled coordinate pair of each point of a GPS data segment, one after another"""
    yield from (scaleCoords(point, height, scale) for point in segment)

def _writeSvgPathGroup(fp, segments, height, scale, closePaths):
    """Write a <g> element with one <path> per segment; write nothing if there are no segments"""

    if not segments:
        return

    # Closed paths get a trailing "Z" command, open ones just end after the last point
    pathEnd = ' Z"' if closePaths else '"'

    fp.write('<g>\n')
    for segment in segments:
        fp.write('<path d="M')
        for x, y in generateScaledSegment(segment, height, scale):
            fp.write(' {} {}'.format(x, y))
        fp.write('{} style="fill:none;stroke:black"/>\n'.format(pathEnd))
    fp.write('</g>\n')

def writeSvgData(gpsData, width, height, cmdArgs):
    """Output the SVG data -- quick 'n' dirty, without messing around with dom stuff ;-)

    gpsData is the projected dataset moved to 0,0, width and height are the dimensions of its
    bounding box. The following attributes of cmdArgs are used:

        m   maximum width or height of the output in pixels (only used if s is None)
        s   scale divisor or None. If it's set, the data is scaled by 1 / s and the unit of the
            SVG dimensions is mm instead of px
        d   if true, single points are dropped instead of being drawn as circles
        o   name of the output file, '/dev/stdout' writes to sys.stdout

    If the output file can't be opened, an error message is printed to STDERR and the program
    exits with status 1. gpsData itself is not changed.
    """

    # Pass the cmdArgs settings to nicer variables
    maxPixels = cmdArgs.m
    dropSinglePoints = cmdArgs.d
    outfile = cmdArgs.o
    scaleFactor = cmdArgs.s

    if scaleFactor is None:
        unit = 'px'
        # Calculate the scale factor we need to fit the requested maximal output size
        if width <= maxPixels and height <= maxPixels:
            scale = 1
        elif width > height:
            scale = maxPixels / width
        else:
            scale = maxPixels / height
    else:
        unit = 'mm'
        scale = 1.0 / scaleFactor

    # Split the data to single points, circular segments and straight segments
    circularSegments, straightSegments = searchCircularSegments(gpsData)
    closedSegments = []
    singlePoints = []
    for segment in circularSegments:
        # We can leave out the last point, because it's equal to the first one. This is done on
        # a copy, so that the caller's data stays untouched. A segment that consists of only one
        # point has no duplicated end point, so it's kept as it is (otherwise, it would end up
        # as an empty path).
        points = segment[:-1] if len(segment) > 1 else segment
        if len(points) == 1:
            # It's a single point, which is only kept if we don't want to drop those
            if not dropSinglePoints:
                singlePoints.append(points)
        else:
            closedSegments.append(points)

    # Open the requested output file or map to /dev/stdout
    if outfile != '/dev/stdout':
        try:
            # The XML declaration we write announces UTF-8, so we also request it here
            fp = open(outfile, 'w', encoding = 'utf-8')
        except IOError as error:
            print("Can't open output file: {}. Terminating.".format(error), file = sys.stderr)
            sys.exit(1)
    else:
        fp = sys.stdout

    try:
        # Header data
        fp.write( '<?xml version="1.0" encoding="UTF-8" standalone="no"?>\n')
        fp.write(('<!DOCTYPE svg PUBLIC "-//W3C//DTD SVG 1.1//EN" '
                  '"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd">\n'))
        fp.write( '<!-- Created with gpx2svg {} -->\n'.format(__version__))

        fp.write(('<svg xmlns="http://www.w3.org/2000/svg" version="1.1" '
                  'width="{width}{unit}" height="{height}{unit}" '
                  'viewBox="0 0 {width} {height}">\n').format(unit = unit,
                                                              width = width * scale,
                                                              height = height * scale))

        # Draw single points if requested
        if singlePoints:
            fp.write('<g>\n')
            for segment in singlePoints:
                x, y = scaleCoords(segment[0], height, scale)
                fp.write(
                    '<circle cx="{}" cy="{}" r="0.5" style="stroke:none;fill:black"/>\n'.format(
                        x, y
                    )
                )
            fp.write('</g>\n')

        # Draw all circular segments, followed by all un-closed paths
        _writeSvgPathGroup(fp, closedSegments, height, scale, True)
        _writeSvgPathGroup(fp, straightSegments, height, scale, False)

        # Close the XML
        fp.write('</svg>\n')

    finally:
        # Close the file if necessary -- also if writing the data failed
        if fp is not sys.stdout:
            fp.close()

def main(argv = None):
    """Parse the command line, convert the requested GPX data and write the resulting SVG

    argv is the list of command line arguments to parse. If it's None, argparse falls back to
    sys.argv[1:].

    For invalid option values or if there's no data to convert, an error message is printed to
    STDERR and the program exits with status 1.
    """

    # Setup the command line argument parser. All options using nargs = '?' also define their
    # default value as "const", so that a switch given without a value does not end up as None.
    cmdArgParser = argparse.ArgumentParser(
        description = 'Convert GPX formatted geodata to Scalable Vector Graphics (SVG)',
        epilog = 'gpx2svg {} - https://nasauber.de/opensource/gpx2svg/'.format(__version__)
    )
    cmdArgParser.add_argument(
        '-i', metavar = 'FILE', nargs = '?', type = str,
        default = '/dev/stdin', const = '/dev/stdin',
        help = 'GPX input file (default: read from STDIN)'
    )
    cmdArgParser.add_argument(
        '-o', metavar = 'FILE', nargs = '?', type = str,
        default = '/dev/stdout', const = '/dev/stdout',
        help = 'SVG output file (default: write to STDOUT)'
    )
    cmdArgParser.add_argument(
        '-p', choices = [ 'mercator', 'wgs84' ], default = 'mercator',
        help = 'Projection to use. The default setting is "mercator"'
    )
    group = cmdArgParser.add_mutually_exclusive_group()
    group.add_argument(
        '-m', metavar = 'PIXELS', nargs = '?', type = int, default = 3000, const = 3000,
        help = 'Maximum width or height of the SVG output in pixels (default: 3000). Can\'t be '
               'used together with the "-s" switch'
    )
    group.add_argument(
        '-s', metavar = 'SCALE', nargs = 1, type = int,
        help = 'Scale factor for the WGS84 projection. When this is set, the output will be in mm '
               'instead of px, with 1 mm representing SCALE mm of the original data (e. g. a value '
               'of 1000 will produce an SVG file with 1 mm representing 1 m). Can\'t be used '
               'together with the "-m" switch'
    )
    cmdArgParser.add_argument(
        '-d', action = 'store_true',
        help = 'Drop single points (default: draw a circle with 1px diameter)'
    )
    cmdArgParser.add_argument(
        '-r', action = 'store_true',
        help = ('"Raw" conversion: Create one SVG path per track segment, don\'t try to combine '
                'paths that end with the starting point of another path')
    )
    cmdArgParser.add_argument(
        '-j', action = 'store_true',
        help = ('Join all segments to a big one in the order of the GPX file. This can create an '
                'un-scattered path if the default combining algorithm does not work because there '
                'are no matching points across segments (implies -r)')
    )

    # Get the given arguments
    cmdArgs = cmdArgParser.parse_args(argv)

    # Map "-" to STDIN or STDOUT
    if cmdArgs.i == '-':
        cmdArgs.i = '/dev/stdin'
    if cmdArgs.o == '-':
        cmdArgs.o = '/dev/stdout'

    # Check if a given input or output file name is a relative representation of STDIN or STDOUT
    if cmdArgs.i != '/dev/stdin':
        if abspath(cmdArgs.i) == '/dev/stdin':
            cmdArgs.i = '/dev/stdin'
    if cmdArgs.o != '/dev/stdout':
        if abspath(cmdArgs.o) == '/dev/stdout':
            cmdArgs.o = '/dev/stdout'

    # A maximum size of 0 or less can't result in a usable scale
    if cmdArgs.m <= 0:
        print('The maximum output size (-m) has to be greater than 0. Terminating.',
              file = sys.stderr)
        sys.exit(1)

    # Get the requested projection and its specific arguments

    if cmdArgs.p == 'mercator':
        if cmdArgs.s is not None:
            print('The scale option can only be set when the WGS84 projection is used (-p wgs84). '
                  'Terminating.', file = sys.stderr)
            sys.exit(1)

        projection = Projection.Mercator

    elif cmdArgs.p == 'wgs84':
        projection = Projection.WGS84

        if cmdArgs.s is not None:
            # The output is scaled by 1 / SCALE later, so SCALE must not be 0 (or negative)
            if cmdArgs.s[0] <= 0:
                print('The scale factor (-s) has to be greater than 0. Terminating.',
                      file = sys.stderr)
                sys.exit(1)

            cmdArgs.s = cmdArgs.s[0] / 1000.0

    # Get the latitude and longitude data from the given GPX file or STDIN
    gpsData = parseGpx(cmdArgs.i)

    # Check if we actually _have_ data
    if gpsData == []:
        print('No data to convert. Terminating.', file = sys.stderr)
        sys.exit(1)

    # Join all segments if requested by "-j"
    if cmdArgs.j:
        gpsData = chronologyJoinSegments(gpsData)

    # Try to combine all track segments that can be combined if not requested otherwise
    # Don't execute if all segments are already joined with "-j"
    if not cmdArgs.r and not cmdArgs.j:
        gpsData = combineSegments(gpsData)

    # Calculate the requested plane projection for the GPS dataset
    gpsData = calcProjection(gpsData, projection)

    # Move the projected data to the 0,0 origin of a cartesian coordinate system
    # and get the raw width and height of the resulting vector data
    gpsData, width, height = moveProjectedData(gpsData)

    # Write the resulting SVG data to the requested output file or STDOUT
    writeSvgData(gpsData, width, height, cmdArgs)

# Only start the conversion if this file is executed as a script, not if it's imported
if __name__ == '__main__':
    main()
