import os
import json
import zlib
from base64 import b64decode

from google.cloud import bigquery
from google.oauth2 import service_account

# Fully-qualified BigQuery destination table, in "project.dataset.table" form.
# Taken from the optional ``table_name`` environment variable, consistent with
# how the service-account settings are supplied. When the variable is unset,
# this falls back to the placeholder, which must be replaced before deployment.
TABLE_NAME = os.environ.get('table_name', "<project_id>.<dataset>.<table-name>")

def transform(json_payload):
    """Turn the decoded log payload into a list of BigQuery row dicts.

    Each entry of ``json_payload['logEvents']`` becomes one row holding its
    ``id`` and its ``timestamp`` divided by 1000 (presumably epoch
    milliseconds -> seconds; confirm against the producer). When an event
    carries ``extractedFields``, those key/value pairs are merged into the row
    afterwards and may overwrite ``id``/``timestamp``.

    Args:
        json_payload: Decoded payload dict containing a ``logEvents`` list.

    Returns:
        list[dict]: One row per log event, in input order.
    """
    def _to_row(event):
        # Base columns first so they keep their position even if overwritten.
        record = {'id': event['id'], 'timestamp': event['timestamp'] / 1000}
        if 'extractedFields' in event:
            record.update(event['extractedFields'])
        return record

    return [_to_row(event) for event in json_payload['logEvents']]


def _get_client_key():
    project_id = os.environ['project_id']
    private_key_id = os.environ['private_key_id']
    private_key = os.environ['private_key'].replace("\\n", "\n")
    client_email = os.environ['client_email']
    client_id = os.environ['client_id']

    client_key = {
      "type": "service_account",
      "project_id": project_id,
      "private_key_id": private_key_id,
      "private_key": private_key,
      "client_email": client_email,
      "client_id": client_id,
      "auth_uri": "https://accounts.google.com/o/oauth2/auth",
      "token_uri": "https://oauth2.googleapis.com/token",
      "auth_provider_x509_cert_url": "https://www.googleapis.com/oauth2/v1/certs",
      "client_x509_cert_url": "https://www.googleapis.com/robot/v1/metadata/x509/lambda-bigquery-stream%40hangfive-26bb4.iam.gserviceaccount.com"
    }

    return client_key


# BigQuery client cached at module level so warm Lambda invocations in the
# same execution environment reuse it instead of re-authenticating each time.
_BQ_CLIENT = None


def _get_bigquery_client():
    """Return a BigQuery client for the env-configured service account (cached)."""
    global _BQ_CLIENT
    if _BQ_CLIENT is None:
        credentials = service_account.Credentials.from_service_account_info(
            _get_client_key()
        )
        _BQ_CLIENT = bigquery.Client(
            credentials=credentials, project=credentials.project_id,
        )
    return _BQ_CLIENT


def lambda_handler(event, context):
    """Decode a CloudWatch Logs subscription event and stream its rows to BigQuery.

    ``event['awslogs']['data']`` is base64-decoded, gzip-decompressed and
    parsed as JSON. Control messages are ignored. Otherwise the log events
    are converted with ``transform`` and inserted into ``TABLE_NAME``.

    Args:
        event: Lambda event dict containing ``awslogs.data``.
        context: Lambda context object (unused).

    Returns:
        None. Insert errors are printed, not raised, so the invocation still
        completes successfully when BigQuery rejects rows.

    Raises:
        KeyError: If the event or decoded payload lacks the expected keys.
    """
    data = event['awslogs']['data']
    compressed_payload = b64decode(data)
    # wbits = 16 + MAX_WBITS makes zlib expect a gzip header and trailer.
    cloudwatch_payload = zlib.decompress(compressed_payload, 16 + zlib.MAX_WBITS)
    json_payload = json.loads(cloudwatch_payload)

    # SKIP control message -- checked before any credential/client work.
    if json_payload['messageType'] == 'CONTROL_MESSAGE':
        return

    rows_to_insert = transform(json_payload)
    if not rows_to_insert:
        # Nothing to send; avoid an API call with an empty row list.
        print("No log events to insert.")
        return

    # BQ insert
    client = _get_bigquery_client()
    errors = client.insert_rows_json(TABLE_NAME, rows_to_insert)  # Make an API request.
    if not errors:
        print("New rows have been added.")
    else:
        print("Encountered errors while inserting rows: {}".format(errors))