#!/usr/bin/python

from __future__ import print_function
import keyword
import optparse
import os
import sys
from collections import defaultdict
try:
    from StringIO import StringIO
except ImportError:
    from io import StringIO
from cprotobuf.plugin_pb import CodeGeneratorRequest, CodeGeneratorResponse, FieldDescriptorProto, FileDescriptorSet
from sys import stdin, stdout
# The request read from stdin and the response written to stdout are
# serialized messages, so prefer the streams' underlying ``buffer`` objects
# when they exist.  Streams without a ``buffer`` attribute (e.g. Python 2
# file objects) are used as they are.
try:
    stdin.buffer
except AttributeError:
    pass
else:
    stdin = stdin.buffer
    stdout = stdout.buffer

# Maps FieldDescriptorProto.TYPE_* constants of non-message fields to the
# type-name string that typename() emits (quoted) as the first argument of
# ``Field(...)`` in the generated code.  TYPE_MESSAGE is handled separately
# by typename(); any other type missing from this table makes the lookup in
# typename() raise KeyError.
CPP_TYPES = {
    FieldDescriptorProto.TYPE_DOUBLE:    'double',
    FieldDescriptorProto.TYPE_FLOAT:     'float',
    FieldDescriptorProto.TYPE_INT32:     'int32',
    FieldDescriptorProto.TYPE_INT64:     'int64',
    FieldDescriptorProto.TYPE_UINT32:    'uint32',
    FieldDescriptorProto.TYPE_UINT64:    'uint64',
    FieldDescriptorProto.TYPE_FIXED32:   'fixed32',
    FieldDescriptorProto.TYPE_FIXED64:   'fixed64',
    FieldDescriptorProto.TYPE_SFIXED32:  'sfixed32',
    FieldDescriptorProto.TYPE_SFIXED64:  'sfixed64',
    FieldDescriptorProto.TYPE_SINT32:    'sint32',
    FieldDescriptorProto.TYPE_SINT64:    'sint64',
    FieldDescriptorProto.TYPE_BOOL:      'bool',
    FieldDescriptorProto.TYPE_STRING:    'string',
    FieldDescriptorProto.TYPE_BYTES:     'bytes',
    FieldDescriptorProto.TYPE_ENUM:      'enum',
}


# Names that are always escaped by fieldname(), in addition to the reserved
# words of the running interpreter.  'static' is not a Python keyword; it is
# kept here so that already generated attribute names stay the same.
keywords = ['static', 'class']


def fieldname(s):
    """Return a name for field *s* that is usable as a class attribute.

    A trailing underscore is appended when *s* is listed in ``keywords`` or
    is a reserved word of the running Python interpreter, because a line such
    as ``from = Field(...)`` would make the generated module unparsable.
    Any other name is returned unchanged.
    """
    if s in keywords or keyword.iskeyword(s):
        return s + '_'
    return s


def convert_default_value(v, t):
    """Render the descriptor default *v* as Python source text.

    v: the default value as stored (as a string) in the field descriptor.
    t: the field's ``FieldDescriptorProto.TYPE_*`` constant.

    String and bytes defaults are returned as quoted literals produced by
    ``repr``.  For every other type ``'true'``/``'false'`` are turned into
    ``'True'``/``'False'`` and any other value is returned unchanged.
    """
    # String-like types are checked first so that a string default spelled
    # "true" or "false" stays a quoted literal instead of becoming a bool.
    if t == FieldDescriptorProto.TYPE_STRING:
        return repr(v)
    if t == FieldDescriptorProto.TYPE_BYTES:
        # bytes(text) raises TypeError on Python 3, so encode text explicitly.
        # NOTE(review): the value is emitted exactly as it appears in the
        # descriptor; any escape sequences in it are not decoded here.
        if not isinstance(v, bytes):
            v = v.encode('utf-8')
        return repr(v)
    if v == 'false':
        return 'False'
    if v == 'true':
        return 'True'
    return v


def real_message_name(type_name):
    """Return the unqualified message name for a descriptor *type_name*.

    When *type_name* has three or more dot-separated parts (for example
    ``'.pkg.Message'``) only the last part is returned.  Shorter names are
    returned whole, minus one leading dot, so ``'.Message'`` (a type declared
    without a package) yields ``'Message'`` and can be compared directly with
    descriptor names.
    """
    parts = type_name.rsplit('.', 2)
    if len(parts) > 2:
        return parts[2]
    # At most two parts: the original text is intact, only a leading dot
    # (the no-package case) has to go.
    if type_name.startswith('.'):
        return type_name[1:]
    return type_name


def typename(type_name, type, seen_messages):
    """Return the source text used as the type argument of ``Field(...)``.

    Message types resolve to their short name, which is emitted bare when it
    is already in *seen_messages* and wrapped in single quotes otherwise.
    Every other type is looked up in ``CPP_TYPES`` and emitted quoted.
    """
    if type != FieldDescriptorProto.TYPE_MESSAGE:
        return "'%s'" % CPP_TYPES[type]

    short_name = real_message_name(type_name)
    # Strip a leading dot if one is present (the case without a package).
    if short_name[0] == '.':
        short_name = short_name[1:]

    if short_name not in seen_messages:
        return "'%s'" % short_name
    return short_name


def write_field(fp, desc, seen_messages):
    """Write one ``name = Field(...)`` class-body line for *desc* to *fp*.

    fp: writable text stream receiving the generated line.
    desc: field descriptor; its name, number, type, type_name, label,
        options.packed and default_value are read.
    seen_messages: names of messages already written, forwarded to
        typename() to decide whether a message type is quoted.

    Optional fields get ``required=False``, repeated fields ``repeated=True``
    (plus ``packed=True`` when the packed option is set) and a non-empty
    default value is emitted as ``default=...``.
    """
    ftype = typename(desc.type_name, desc.type, seen_messages)
    fname = fieldname(desc.name)
    # Format the fixed part on its own: the keyword arguments appended below
    # may contain a literal '%' (e.g. a string default of '100%') and must
    # therefore never be fed through %-formatting.
    txt = '    %-15s = Field(%s,\t%d' % (fname, ftype, desc.number)

    args = []
    if desc.label == FieldDescriptorProto.LABEL_OPTIONAL:
        args.append('required=False')
    elif desc.label == FieldDescriptorProto.LABEL_REPEATED:
        args.append('repeated=True')
        if desc.options.packed:
            args.append('packed=True')
    # LABEL_REQUIRED needs no extra argument.

    if desc.default_value:
        args.append('default=%s' % convert_default_value(desc.default_value, desc.type))

    if args:
        txt += ', ' + ', '.join(args)
    print(txt + ')', file=fp)


def write_message(fp, message_descriptor, seen_messages):
    """Write the ``ProtoEntity`` subclass for *message_descriptor* to *fp*.

    The class header is followed by the message's enums (indented by four
    spaces) and one line per field.  A message without fields gets a ``pass``
    statement.  A blank line terminates the class.
    """
    print('class %s(ProtoEntity):' % message_descriptor.name, file=fp)

    for enum_descriptor in message_descriptor.enum_type:
        write_enum(fp, enum_descriptor, 4)

    fields = message_descriptor.field
    if not fields:
        print('    pass', file=fp)
    else:
        for field_descriptor in fields:
            write_field(fp, field_descriptor, seen_messages)

    print(file=fp)


def write_enum(fp, descriptor, indent=0):
    """Write the values of enum *descriptor* to *fp* as ``NAME=number`` lines.

    A ``# enum <name>`` comment line comes first.  Every line is prefixed
    with *indent* spaces.
    """
    pad = ' ' * indent
    print('{0}# enum {1}'.format(pad, descriptor.name), file=fp)
    for member in descriptor.value:
        print('%s%s=%d' % (pad, member.name, member.number), file=fp)


def sort_messages(desces):
    """Yield message descriptors so that dependencies come before dependants.

    A message depends on every *other* message of *desces* that it uses as a
    field type; self-references are ignored.  Messages without an outstanding
    dependency are yielded one at a time.  When only mutually dependent
    messages are left, they are yielded in the remaining dictionary order.
    """
    pending = dict((desc.name, desc) for desc in desces)
    deps = defaultdict(set)

    for desc in desces:
        for field_desc in desc.field or ():
            if field_desc.type != FieldDescriptorProto.TYPE_MESSAGE:
                continue
            dep_name = real_message_name(field_desc.type_name)
            if dep_name in pending and dep_name != desc.name:
                deps[desc.name].add(dep_name)

    while pending:
        ready = next((name for name in pending if name not in deps), None)
        if ready is None:
            # Every remaining message waits on another one (a cycle).
            # Snapshot the values before emptying the dict: popping entries
            # while iterating the live dict raises RuntimeError on Python 3.
            leftovers = list(pending.values())
            pending.clear()
            for desc in leftovers:
                yield desc
            return

        yield pending.pop(ready)

        # The emitted message no longer blocks anyone; entries whose
        # dependency set becomes empty are dropped so they count as ready.
        for dependant in list(deps):
            deps[dependant].discard(ready)
            if not deps[dependant]:
                del deps[dependant]


def main(proto_files):
    """Generate one Python module per package and return the response.

    proto_files: iterable of file descriptors; each one's package, name,
        enum_type and message_type are read.

    Returns a ``CodeGeneratorResponse`` holding one file per package, named
    ``<package with dots replaced by underscores>_pb.py``; descriptors
    without a package are collected in ``protos_pb.py``.
    """
    response = CodeGeneratorResponse()

    packages = defaultdict(list)
    for desc in proto_files:
        packages[desc.package].append(desc)

    for pkg, desces in packages.items():
        f = response.file.add()
        if not pkg:
            pkg = 'protos'
        f.name = '%s_pb.py' % pkg.replace('.', '_')

        # Tracked per output module: a bare class name is only valid in the
        # module that defines it, because the generated code imports nothing
        # but ProtoEntity and Field.
        seen_messages = set()

        fp = StringIO()
        print('# coding: utf-8', file=fp)
        print('from cprotobuf import ProtoEntity, Field', file=fp)

        for desc in desces:
            print('# file: %s' % desc.name, file=fp)
            for descriptor in desc.enum_type:
                write_enum(fp, descriptor)
            for descriptor in sort_messages(desc.message_type):
                write_message(fp, descriptor, seen_messages)
                # Later fields may now reference this class without quotes.
                seen_messages.add(descriptor.name)

        f.content = fp.getvalue()

    return response

if __name__ == '__main__':
    # With command-line arguments every argument is read as a serialized
    # FileDescriptorSet and the generated modules are written to the output
    # directory.  Without arguments a CodeGeneratorRequest is read from stdin
    # and the serialized response is written to stdout.
    if len(sys.argv) > 1:
        parser = optparse.OptionParser()
        parser.add_option("-d", "--output_directory", dest="output_directory",
                          help="directory to write output modules.", metavar="DIRECTORY", default='.')
        opt, args = parser.parse_args()

        for input_file in args:
            request = FileDescriptorSet()
            # Context managers close the files deterministically instead of
            # leaving that to garbage collection.
            with open(input_file, 'rb') as src:
                request.ParseFromString(src.read())
            response = main(request.file)
            for out_file in response.file:
                out_path = os.path.join(opt.output_directory, out_file.name)
                with open(out_path, 'w') as dst:
                    dst.write(out_file.content)
    else:
        request = CodeGeneratorRequest()
        request.ParseFromString(stdin.read())
        response = main(request.proto_file)
        stdout.write(response.SerializeToString())
