#!/usr/bin/env python

import json
import sys
import traceback
import click
from time import sleep
import paho.mqtt.client as mqtt
from random import choice
from Crypto.PublicKey import RSA
from Crypto.Random import get_random_bytes
from Crypto.Cipher import AES, PKCS1_OAEP
from base64 import b64encode
from cockroachpoker.game import Game, Player, Action, Suit
from cockroachpoker.signed_message import verify


@click.command()
@click.option('--mqtt-host', default='euterpe3')
@click.option('--mqtt-port', default=1885)
@click.option('--mqtt-user')
@click.option('--mqtt-pass')
@click.option('--number-of-players',
    default=3,
    help='Number of players in game')
@click.argument('name')
def host_game(number_of_players, name, mqtt_host, mqtt_port, mqtt_user,
    mqtt_pass):
  """Host a Cockroach Poker game over MQTT.

  Connects to the broker and waits for ``number_of_players`` players to
  publish to ``<topic_root>join``, then deals. After that it drives the game
  from the players' ``play``, ``will-pass``, ``pass`` and ``call`` messages.

  Each player's hand is published encrypted to
  ``<topic_root><player>/hand``. The public table state goes to
  ``<topic_root>table`` after every action. When a loser is found, their name
  is published to ``<topic_root>loser`` and the client disconnects, which
  ends ``loop_forever``.

  Args:
    number_of_players: number of distinct players to wait for before dealing.
    name: game name; it is part of every topic this host uses.
    mqtt_host: broker hostname.
    mqtt_port: broker port.
    mqtt_user: optional broker username; credentials are only set if given.
    mqtt_pass: broker password used together with ``mqtt_user``.
  """
  topic_root = 's1m0n.r3dd1ng@gmail.com/cockroach-poker/{}/'.format(name)
  game = Game()
  # Player name -> RSA public key, registered on join. It is used both to
  # check signed player messages and to encrypt data meant for one player.
  pub_keys = {}

  def verify_sender(msg):
    """Check that ``msg`` is signed by the player named in ``msg['from']``.

    Raises KeyError if that player never joined.
    """
    # NOTE(review): verify()'s return value is ignored, so this relies on it
    # raising for a bad signature -- confirm in cockroachpoker.signed_message.
    verify(pub_keys[msg['from']], msg)

  def publish_table(client):
    """Publish the public table state to everyone."""
    client.publish(topic_root + 'table', json.dumps(game.table()))

  def publish_encrypted(client, player_name, topic_leaf, data):
    """Publish ``data`` (as JSON) so that only ``player_name`` can read it.

    This is hybrid encryption:
    - A fresh 16-byte AES session key encrypts the payload in EAX mode, which
      also produces an authentication tag.
    - The session key is itself encrypted with the player's RSA public key
      using PKCS#1 OAEP.
    - All binary fields are base64-encoded so the envelope is plain JSON.

    The envelope is sent to ``<topic_root><player_name>/<topic_leaf>``.
    """
    session_key = get_random_bytes(16)
    enc_session_key = PKCS1_OAEP.new(pub_keys[player_name]).encrypt(
        session_key)
    cipher_aes = AES.new(session_key, AES.MODE_EAX)
    ciphertext, tag = cipher_aes.encrypt_and_digest(
        json.dumps(data).encode('utf-8'))
    msg = {
      'ciphertext': b64encode(ciphertext).decode('utf-8'),
      'tag': b64encode(tag).decode('utf-8'),
      'session_key': b64encode(enc_session_key).decode('utf-8'),
      'nonce': b64encode(cipher_aes.nonce).decode('utf-8')
      }
    print(msg)
    client.publish(
        topic_root + '{}/{}'.format(player_name, topic_leaf),
        json.dumps(msg))

  def publish_hands(client):
    """Send every player their current hand, encrypted to them."""
    for player_name, player in game.players.items():
      publish_encrypted(client, player_name, 'hand', player.hand)

  def on_pass(client, msg):
    """Handle a player passing the card they were handed on to someone else."""
    print('on_pass', msg)
    # Check the signature as for every other player action, so that nobody
    # can pass on another player's behalf.
    verify_sender(msg)
    message = json.loads(msg['message'])
    game.pass_on(msg['from'], message['to'], message['card'], message['claim'])
    publish_table(client)

  def on_call(client, msg):
    """Handle a call on a claim, then either end the game or redeal state."""
    print('on_call:', msg)
    verify_sender(msg)
    message = json.loads(msg['message'])
    game.call(message)
    loser = game.check_loser()
    if loser:
      print(loser, 'lost.')
      client.publish(topic_root + 'loser', json.dumps(loser))
      # Ends client.loop_forever() below.
      client.disconnect()
    else:
      publish_hands(client)
      publish_table(client)

  def on_will_pass(client, msg):
    """Handle a player choosing to look at the card instead of calling."""
    print('on_will_pass:', msg)
    verify_sender(msg)
    card, _ = game.will_pass()
    # NOTE(review): the card is JSON-encoded here and then again inside
    # publish_encrypted, unlike hands, which are encoded once. The client
    # presumably decodes twice -- confirm before changing the wire format.
    publish_encrypted(client, msg['from'], 'to-pass', json.dumps(card))
    publish_table(client)

  def on_play(client, msg):
    """Handle a player offering a card from their hand with a claim."""
    print('on_play:', msg)
    verify_sender(msg)
    message = json.loads(msg['message'])
    game.play(msg['from'], message['to'], message['card'], message['claim'])
    publish_table(client)

  def on_join(client, msg):
    """Register a joining player and deal once enough players have joined."""
    print('on_join:', msg)
    player_name = msg['name']
    # The unsubscribe below is asynchronous, and on_connect re-subscribes
    # after a reconnect, so joins can still arrive once the game is full.
    # Ignore those joins, and ignore reused names: re-registering a name
    # would replace that player's public key, letting a latecomer
    # impersonate them and receive their encrypted hand.
    if player_name in pub_keys or len(game.players) >= number_of_players:
      print('Ignoring join from', player_name)
      return
    pub_keys[player_name] = RSA.importKey(msg['pub_key'].encode())
    # NOTE(review): presumably gives the joining client time to subscribe to
    # its own topics before anything is published to it. This blocks the
    # MQTT network loop for a second -- confirm it is still needed.
    sleep(1)
    game.join(player_name)
    if len(game.players) == number_of_players:
      client.unsubscribe(topic_root + 'join')
      print('The players are all here. Dealing...')
      game.deal()
      publish_hands(client)
    publish_table(client)

  # Topic leaf (last path segment of the topic) -> handler.
  # on_connect subscribes to exactly these leaves.
  message_handlers = {
      'pass': on_pass,
      'will-pass': on_will_pass,
      'call': on_call,
      'join': on_join,
      'play': on_play
      }

  def on_connect(client, userdata, flags, rc):
    """Subscribe to every topic this host handles once the broker accepts."""
    print('on_connect: rc=', rc)
    if rc != 0:
      # The broker refused the connection; there is nothing to subscribe on.
      return
    for leaf in message_handlers:
      print('subscribing to', topic_root + leaf)
      client.subscribe(topic_root + leaf)

  def on_message(client, userdata, message):
    """Decode a JSON payload and dispatch it by the topic's last segment."""
    try:
      ts = message.topic.split('/')
      print(message.topic, message.payload, ts[-1])
      message_handlers[ts[-1]](client,
          json.loads(message.payload.decode('utf-8')))
    except Exception:
      print(traceback.format_exc())
      # Deliberately re-raised: a failing handler stops the host.
      raise

  client = mqtt.Client()
  client.on_connect = on_connect
  client.on_message = on_message
  if mqtt_user:
    client.username_pw_set(mqtt_user, mqtt_pass)
  print('connection to', mqtt_host, mqtt_port)
  client.connect(mqtt_host, mqtt_port, 60)
  client.loop_forever()

if __name__ == '__main__':
  # Script entry point: called with no arguments, the click command parses
  # sys.argv itself and invokes host_game with the parsed options.
  host_game()
