#!/usr/bin/env python

import sys
import json
import traceback
import click
from queue import Queue
import paho.mqtt.client as mqtt
from Crypto.PublicKey import RSA
from Crypto.Cipher import AES, PKCS1_OAEP
from base64 import b64decode
from random import sample
from cockroachpoker.game import Suit
from cockroachpoker.signed_message import sign


@click.command()
@click.option('--mqtt-host', default='euterpe3')
@click.option('--mqtt-port', default=1885)
@click.option('--mqtt-user')
@click.option('--mqtt-pass')
@click.argument('name')
@click.argument('game')
def play_cockroach_poker(
    name, game, mqtt_host, mqtt_port, mqtt_user, mqtt_pass):
  """Play Cockroach Poker as player NAME in game GAME over MQTT.

  Connects to the MQTT broker, announces the player together with a freshly
  generated RSA public key, then handles game messages (prompting on stdin
  when it is this player's turn) until a loser is announced.
  """
  # All game topics share this prefix; the two per-player topics add the
  # player's name as an extra segment ('<root><name>/hand', '.../to-pass').
  topic_root = 's1m0n.r3dd1ng@gmail.com/cockroach-poker/{}/'.format(game)

  # ---- helpers -------------------------------------------------------------

  def dummy_string():
    """Return a random shuffle of a fixed filler string.

    Used as the payload of 'will-pass' messages, which carry no game data.
    Presumably shuffled so that repeated signed messages differ - TODO confirm.
    """
    s = 'Nothing to see here!'
    return ''.join(sample(s, k=len(s)))

  def sign_and_publish(client, state, topic, msg):
    """JSON-encode msg, sign it as this player and publish it.

    topic is relative to topic_root. The envelope returned by sign() is
    itself JSON-encoded before being published.
    """
    client.publish(
        topic_root + topic,
        json.dumps(sign(state['name'], state['key'], json.dumps(msg))))

  def input_choice(prompt, options):
    """Prompt on stdin until the reply is one of options, then return it."""
    while True:
      v = input(prompt)
      if v in options:
        return v
      print('It must be one of:', options)

  def show_table(table):
    """Print the public table: tabled cards per player, the current play
    and who acts next."""
    print('Table:')
    print('\ttabled:')
    for n, cs in table['tabled'].items():
      print('\t\t', n, cs)
    print('\tplayed:', table['played'])
    print('\tnext:', table['next'])

  def show_hand(hand):
    """Print this player's hand in sorted order."""
    print('Your hand:', sorted(hand))

  # ---- turn actions (run when table['next'] names this player) -------------

  def do_play(client, state):
    """Start a round: choose a card from the hand, a claim and a recipient."""
    card = input_choice('Which card will you play?', set(state['hand']))
    claim = input_choice('What will you claim it is?', Suit.values)
    # Any player listed on the table except ourselves may receive the card.
    to_player = input_choice(
        'Who will you pass it to?',
        list(set(state['table']['tabled'].keys()) - {state['name']}))
    sign_and_publish(
        client, state, 'play',
        {'to': to_player, 'card': card, 'claim': claim})

  def do_call(client, state):
    """Answer the current claim: publishes true for 'y', false for 'n'."""
    sign_and_publish(client, state, 'call',
                     'y' == input_choice('Agree [y|n]?', ['y', 'n']))

  def do_call_or_pass(client, state):
    """Show the incoming claim, then either call it or announce a pass."""
    played = state['table']['played']
    print(played['from'][-1], "claims it's a", played['claim'])
    action = input_choice('Will you [c]all or [p]ass?', ['c', 'p'])
    if action == 'c':
      do_call(client, state)
    else:
      # The payload is filler; only the 'will-pass' topic matters.
      sign_and_publish(client, state, 'will-pass', {'dummy': dummy_string()})

  def do_pass(client, state):
    """Look at the received card and pass it on with a claim of our choice."""
    print('The card is', state['to-pass'])
    claim = input_choice('What will you claim it is?', Suit.values)
    # Exclude ourselves and everyone listed in played['from'] (those the
    # card has already come from).
    to_player = input_choice(
        'Who will you pass it to?',
        list(set(state['table']['tabled'].keys()) -
             {state['name']} -
             set(state['table']['played']['from'])))
    # TODO: should encrypt this (at least the card)
    sign_and_publish(client, state, 'pass', {
        'to': to_player,
        'card': state['to-pass'],
        'claim': claim})
    state['to-pass'] = None

  # Maps table['next']['action'] to the interactive handler for it.
  action_handlers = {
      'PLAY': do_play,
      'CALL_OR_PASS': do_call_or_pass,
      'PASS': do_pass,
      'CALL': do_call
      }

  # ---- incoming message handlers (run by the main loop, see below) ---------

  def on_to_pass(client, state, card):
    """Remember the card we were given to pass on (private topic)."""
    state['to-pass'] = card

  def on_hand(client, state, hand):
    """Replace our hand with the one sent to us (private topic)."""
    state['hand'] = hand

  def on_table(client, state, table):
    """Show the new table and, if it is our turn, run the matching action."""
    show_table(table)
    state['table'] = table
    show_hand(state['hand'])
    if table['next'] and table['next']['player'] == state['name']:
      print('Your turn to', table['next']['action'])
      action_handlers[table['next']['action']](client, state)

  def on_call(client, state, msg):
    """Report a call published on the 'call' topic."""
    print('{} called {}'.format(msg['from'], msg['message']))

  # Main-loop flag; cleared by on_loser to end the game.
  no_loser = True

  def on_loser(client, state, loser):
    """Announce the loser and stop the main loop."""
    # 'nonlocal', not 'global': no_loser is a local of play_cockroach_poker.
    # 'global' would set an unrelated module variable and the main loop
    # below would never terminate.
    nonlocal no_loser
    if loser == state['name']:
      print('You lost the game! ;-(')
    else:
      print(loser, 'lost the game!')
    no_loser = False

  # Keyed by the last topic segment.
  message_handlers = {
      'loser': on_loser,
      'to-pass': on_to_pass,
      'hand': on_hand,
      'table': on_table,
      'call': on_call
      }

  # Handlers are not run inside the paho callback; they are queued here and
  # executed by the main loop at the end of this function.
  q = Queue()

  def on_message(client, state, message):
    """paho on_message callback: decode (and decrypt) a message and queue
    its handler for the main loop.

    Messages on this player's private topics are hybrid-encrypted: a JSON
    object with an RSA-OAEP-wrapped AES session key plus an AES-EAX nonce,
    ciphertext and tag, all base64-encoded.
    """
    try:
      topic_parts = message.topic.split('/')
      payload = message.payload.decode('utf-8')
      # Only '<root><name>/...' topics are encrypted. Match the full prefix
      # rather than the second-to-last segment, so a player whose name equals
      # the game name does not try to decrypt public messages.
      if message.topic.startswith(topic_root + state['name'] + '/'):
        enc_msg = json.loads(payload)
        cipher_rsa = PKCS1_OAEP.new(state['key'])
        session_key = cipher_rsa.decrypt(
            b64decode(enc_msg['session_key'].encode('utf-8')))
        cipher_aes = AES.new(
            session_key,
            AES.MODE_EAX,
            b64decode(enc_msg['nonce'].encode('utf-8')))
        payload = cipher_aes.decrypt_and_verify(
            b64decode(enc_msg['ciphertext'].encode('utf-8')),
            b64decode(enc_msg['tag'].encode('utf-8'))).decode('utf-8')
      q.put(
          (message_handlers[topic_parts[-1]],
           (client, state, json.loads(payload))))
    except Exception:
      print(traceback.format_exc())
      raise

  # NOTE(review): the 4-argument on_connect signature is the paho-mqtt 1.x
  # (callback API v1) form - confirm against the installed paho version.
  def on_connect(client, state, flags, rc):
    """paho on_connect callback: subscribe to the game's public topics and
    this player's private ones."""
    print('on_connect: rc=', rc)
    client.subscribe(topic_root + 'table')
    client.subscribe(topic_root + 'loser')
    client.subscribe(topic_root + 'call')
    client.subscribe(topic_root + '{}/hand'.format(state['name']))
    client.subscribe(topic_root + '{}/to-pass'.format(state['name']))

  # Fresh key pair per session. The private half signs outgoing messages and
  # decrypts private ones; the public half is sent in the join message,
  # presumably so the server can encrypt private messages to us.
  key = RSA.generate(2048)
  client = mqtt.Client(
      userdata={'name': name, 'hand': set(), 'table': None, 'key': key})
  client.on_connect = on_connect
  client.on_message = on_message
  if mqtt_user:
    client.username_pw_set(mqtt_user, mqtt_pass)
  print('connection to', mqtt_host, mqtt_port)
  client.connect(mqtt_host, mqtt_port, 60)
  # NOTE(review): 'join' is published before on_connect has subscribed to
  # the private topics; if the server replies on them immediately, those
  # replies could be missed - confirm server timing.
  print('publishing to', topic_root + 'join', {'name': name})
  client.publish(
      topic_root + 'join',
      json.dumps(
          {'name': name,
           'pub_key': key.publickey().exportKey().decode('utf-8')}))
  client.loop_start()
  try:
    while no_loser:
      fn, args = q.get()
      fn(*args)
  finally:
    # Disconnect cleanly (also on Ctrl-C), then join paho's network thread
    # so the disconnect is actually flushed before the process exits.
    client.disconnect()
    client.loop_stop()


if "__main__" == __name__:
    # Script entry point: hand control to the click command, which reads
    # NAME, GAME and the --mqtt-* options from the command line.
    play_cockroach_poker()
