-
-
Save frsyuki/8279517 to your computer and use it in GitHub Desktop.
import os | |
import json | |
import httplib | |
import time | |
# Client library version; StatementClient embeds it in its User-Agent header.
VERSION = "0.1.0"
class ClientSession(object):
    """Settings shared by every statement sent to one server.

    :param server: ``host:port`` of the server; Query.start passes it to
        ``httplib.HTTPConnection`` (e.g. ``"localhost:8880"``).
    :param user: sent as the X-Presto-User header when not None.
    :param source: sent as the X-Presto-Source header when not None.
    :param catalog: sent as the X-Presto-Catalog header when not None.
    :param schema: sent as the X-Presto-Schema header when not None.
    :param debug: stored for callers; not read by the Python code in this file.
    """

    def __init__(self, server, user, source=None, catalog=None, schema=None, debug=False):
        self.server = server
        self.user = user
        self.source = source
        self.catalog = catalog
        self.schema = schema
        self.debug = debug
class StatementStats(object):
    """Statistics decoded from the "stats" object of a query result page.

    Every field defaults to None, which is also what decode_dict stores when
    the server omits the corresponding key.
    """

    # (constructor argument, JSON key) pairs read by decode_dict; the server
    # uses camelCase keys while the attributes are snake_case.
    _JSON_KEYS = (
        ("state", "state"),
        ("scheduled", "scheduled"),
        ("nodes", "nodes"),
        ("total_splits", "totalSplits"),
        ("queued_splits", "queuedSplits"),
        ("running_splits", "runningSplits"),
        ("completed_splits", "completedSplits"),
        ("user_time_millis", "userTimeMillis"),
        ("cpu_time_millis", "cpuTimeMillis"),
        ("wall_time_millis", "wallTimeMillis"),
        ("processed_rows", "processedRows"),
        ("processed_bytes", "processedBytes"),
    )

    def __init__(self, state=None, scheduled=None, nodes=None, total_splits=None,
                 queued_splits=None, running_splits=None, completed_splits=None,
                 user_time_millis=None, cpu_time_millis=None, wall_time_millis=None,
                 processed_rows=None, processed_bytes=None):
        self.state = state
        self.scheduled = scheduled
        self.nodes = nodes
        self.total_splits = total_splits
        self.queued_splits = queued_splits
        self.running_splits = running_splits
        self.completed_splits = completed_splits
        self.user_time_millis = user_time_millis
        self.cpu_time_millis = cpu_time_millis
        self.wall_time_millis = wall_time_millis
        self.processed_rows = processed_rows
        self.processed_bytes = processed_bytes
        # TODO: keep the per-stage breakdown ("rootStage") once a StageStats
        # type exists to decode it.

    @classmethod
    def decode_dict(cls, dic):
        """Build an instance (of ``cls``) from the server's camelCase dict.

        Missing keys become None; keys not listed in _JSON_KEYS are ignored.
        """
        return cls(**dict((arg, dic.get(key)) for arg, key in cls._JSON_KEYS))
class Column(object):
    """Name and type of one result column, as reported by the server."""

    def __init__(self, name, type):
        # ``type`` shadows the builtin inside __init__, but it is part of the
        # public keyword interface so it keeps its name.
        self.name = name
        self.type = type

    def __repr__(self):
        # Readable form so printing a list of columns shows names and types.
        return "%s(name=%r, type=%r)" % (self.__class__.__name__, self.name, self.type)

    @classmethod
    def decode_dict(cls, dic):
        """Build an instance from a column dict; missing keys become None."""
        return cls(name=dic.get("name"), type=dic.get("type"))
class QueryResults(object):
    """One page of a statement's results, decoded from the server's JSON."""

    def __init__(self, id, info_uri=None, partial_cache_uri=None, next_uri=None, columns=None, data=None, stats=None, error=None):
        self.id = id
        self.info_uri = info_uri
        # Decoded from the "partialCancelUri" key; see partial_cancel_uri.
        self.partial_cache_uri = partial_cache_uri
        self.next_uri = next_uri
        self.columns = columns
        self.data = data
        self.stats = stats
        self.error = error

    @property
    def partial_cancel_uri(self):
        """Same value as partial_cache_uri, named after its JSON key."""
        return self.partial_cache_uri

    @classmethod
    def decode_dict(cls, dic):
        """Build an instance from a parsed JSON response.

        Missing keys become None. ``columns`` is decoded into a real list
        (not a lazy ``map``) so it can be iterated more than once.
        """
        columns = dic.get("columns")
        stats = dic.get("stats")
        return cls(
            id=dic.get("id"),
            info_uri=dic.get("infoUri"),
            partial_cache_uri=dic.get("partialCancelUri"),
            next_uri=dic.get("nextUri"),
            columns=[Column.decode_dict(c) for c in columns] if columns is not None else None,
            data=dic.get("data"),
            stats=StatementStats.decode_dict(stats) if stats is not None else None,
            error=dic.get("error"),  # TODO: decode into a structured error
        )
class PrestoHeaders(object):
    """Names of the X-Presto-* HTTP headers.

    StatementClient sets USER, SOURCE, CATALOG and SCHEMA; the remaining
    names are not referenced by the code in this file.
    """
    PRESTO_USER = "X-Presto-User"
    PRESTO_SOURCE = "X-Presto-Source"
    PRESTO_CATALOG = "X-Presto-Catalog"
    PRESTO_SCHEMA = "X-Presto-Schema"
    PRESTO_CURRENT_STATE = "X-Presto-Current-State"
    PRESTO_MAX_WAIT = "X-Presto-Max-Wait"
    PRESTO_MAX_SIZE = "X-Presto-Max-Size"
    PRESTO_PAGE_SEQUENCE_ID = "X-Presto-Page-Sequence-Id"
class StatementClient(object):
    """Run one statement on the server and page through its results.

    The statement is POSTed to ``/v1/statement`` as soon as the object is
    created; ``advance()`` then follows each page's ``next_uri`` and
    ``close()`` sends a DELETE to ``next_uri`` while the statement has one.
    """

    # Headers sent with every request (POST, GET and DELETE).
    HEADERS = {
        "User-Agent": "presto-python/" + VERSION,
    }

    # advance() keeps retrying 503 responses for at most this many seconds.
    RETRY_TIMEOUT_SECONDS = 2 * 60 * 60

    def __init__(self, http_client, session, query):
        """Submit ``query`` to the server.

        :param http_client: connection to the server; Query.start passes an
            ``httplib.HTTPConnection``.
        :param session: a ClientSession supplying user/source/catalog/schema.
        :param query: SQL text sent as the POST body.
        :raises Exception: if the server does not answer the POST with 200.
        """
        self.http_client = http_client
        self.session = session
        self.query = query
        self.closed = False
        self.exception = None
        self.results = None
        self._post_query_request()

    def _request(self, method, uri, body=None, extra_headers=None):
        """Send one request and return ``(status, response_body)``.

        The response is always read in full: an httplib connection cannot
        send the next request until the previous response has been read.
        """
        headers = dict(StatementClient.HEADERS)
        if extra_headers:
            headers.update(extra_headers)
        self.http_client.request(method, uri, body, headers)
        response = self.http_client.getresponse()
        return response.status, response.read()

    def _post_query_request(self):
        """POST the query and store the first page of results."""
        session_headers = (
            (PrestoHeaders.PRESTO_USER, self.session.user),
            (PrestoHeaders.PRESTO_SOURCE, self.session.source),
            (PrestoHeaders.PRESTO_CATALOG, self.session.catalog),
            (PrestoHeaders.PRESTO_SCHEMA, self.session.schema),
        )
        headers = dict((name, value) for name, value in session_headers
                       if value is not None)
        status, body = self._request("POST", "/v1/statement", self.query, headers)
        if status != 200:
            raise Exception("Failed to start query: %s" % body)  # TODO error class
        self.results = QueryResults.decode_dict(json.loads(body))

    def is_query_failed(self):
        """Return True if the current page carries an error from the server."""
        return self.results.error is not None

    def is_query_succeeded(self):
        """Return True while there is no server error, no request error,
        and the client has not been closed."""
        return (self.results.error is None and self.exception is None
                and not self.closed)

    def has_next(self):
        """Return True if the current page announces another one (next_uri)."""
        return self.results.next_uri is not None

    def advance(self):
        """Fetch the page at ``next_uri`` and make it the current results.

        503 Service Unavailable responses are retried with a growing delay
        for up to RETRY_TIMEOUT_SECONDS, or until the client is closed.

        :returns: True if a new page was loaded; False if the client is
            closed or there is no next page.
        :raises Exception: on a request error, a status other than 200/503,
            or once retries are exhausted; the error is also stored in
            ``self.exception``.
        """
        if self.closed or not self.has_next():
            return False
        uri = self.results.next_uri
        start = time.time()
        attempts = 0
        while True:
            try:
                status, body = self._request("GET", uri)
            except Exception as e:
                self.exception = e
                raise  # bare raise keeps the original traceback
            if status == 200 and body:
                self.results = QueryResults.decode_dict(json.loads(body))
                return True
            if status != 503:  # only Service Unavailable is retried
                # deterministic error
                self.exception = Exception(
                    "Error fetching next at %s returned %s: %s"
                    % (uri, status, body))  # TODO error class
                raise self.exception
            attempts += 1
            # Stop once the retry budget is spent or the client was closed.
            if time.time() - start >= self.RETRY_TIMEOUT_SECONDS or self.closed:
                break
            time.sleep(attempts * 0.1)
        self.exception = Exception("Error fetching next")  # TODO error class
        raise self.exception

    def close(self):
        """Send a DELETE to ``next_uri`` if there is one, then mark the
        client closed (even if the DELETE fails). Later calls do nothing."""
        if self.closed:
            return
        try:
            if self.results.next_uri is not None:
                self._request("DELETE", self.results.next_uri)
        finally:
            self.closed = True
class QueryResultIterator(object):
    """Iterate over every row of a query, fetching further pages as needed.

    Rows come from the client's current page; once it is exhausted the
    client is advanced until a page with rows arrives or no page remains.
    """

    def __init__(self, client):
        self.client = client
        # A page may have no "data" at all (None); treat it as empty.
        self.current_data = client.results.data or []
        self.current_offset = 0

    def __iter__(self):
        return self

    def next(self):
        """Return the next row, or raise StopIteration when none remain."""
        while self.current_offset >= len(self.current_data):
            if not self.client.has_next():
                raise StopIteration
            if not self.client.advance():
                # advance() returns False once the client has been closed.
                raise StopIteration
            data = self.client.results.data
            if data:  # skip pages whose data is None or empty
                self.current_data = data
                self.current_offset = 0
        row = self.current_data[self.current_offset]
        self.current_offset += 1
        return row

    # Python 3 name of the iterator protocol method.
    __next__ = next
class Query(object):
    """Convenience wrapper around the StatementClient of one statement."""

    @classmethod
    def start(cls, session, query):
        """Connect to ``session.server``, submit ``query`` and wrap the client."""
        http_client = httplib.HTTPConnection(session.server)
        return cls(StatementClient(http_client, session, query))

    def __init__(self, client):
        self.client = client

    def _wait_for_data(self):
        """Advance until a page carries data or no page remains."""
        while self.client.has_next() and self.client.results.data is None:
            if not self.client.advance():
                # advance() returns False once the client is closed; without
                # this check the loop would never terminate.
                break

    def columns(self):
        """Return the result columns, waiting for the first page with data.

        :raises Exception: if the query failed, errored, or was closed.
        """
        self._wait_for_data()
        if not self.client.is_query_succeeded():
            self._raise_error()
        return self.client.results.columns

    def results(self):
        """Return an iterator over all result rows.

        :raises Exception: if the query has not succeeded or has no columns.
        """
        if not self.client.is_query_succeeded():
            self._raise_error()
        if self.columns() is None:
            raise Exception("Query " + str(self.client.results.id) + " has no columns")
        return QueryResultIterator(self.client)

    def _raise_error(self):
        """Raise an Exception describing why the query is not succeeding."""
        if self.client.closed:
            raise Exception("Query aborted by user")
        elif self.client.exception is not None:
            raise Exception("Query is gone: " + str(self.client.exception))
        elif self.client.is_query_failed():
            results = self.client.results
            raise Exception("Query " + str(results.id) + " failed: " + str(results.error))
if __name__ == "__main__":
    # Demo: run one query against a local server and print the results.
    session = ClientSession(server="localhost:8880", user="frsyuki",
                            catalog="native", schema="default")
    q = Query.start(session, "select * from sys.query")
    try:
        # print(x) with a single argument works on both Python 2 and 3.
        print("columns: " + str(q.columns()))
        for row in q.results():
            print(row)
    finally:
        # Send the cancel request if we stop before the statement finishes.
        q.client.close()
module PrestoClient | |
VERSION = "0.1.0" | |
require 'faraday' | |
require 'json' | |
class ClientSession | |
def initialize(options) | |
@server = options[:server] | |
@user = options[:user] | |
@source = options[:source] | |
@catalog = options[:catalog] | |
@schema = options[:schema] | |
@debug = !!options[:debug] | |
end | |
attr_reader :server | |
attr_reader :user | |
attr_reader :source | |
attr_reader :catalog | |
attr_reader :schema | |
def debug? | |
@debug | |
end | |
end | |
#class StageStats | |
# attr_reader :stage_id | |
# attr_reader :state | |
# attr_reader :done | |
# attr_reader :nodes | |
# attr_reader :total_splits | |
# attr_reader :queued_splits | |
# attr_reader :running_splits | |
# attr_reader :completed_splits | |
# attr_reader :user_time_millis | |
# attr_reader :cpu_time_millis | |
# attr_reader :wall_time_millis | |
# attr_reader :processed_rows | |
# attr_reader :processed_bytes | |
# attr_reader :sub_stages | |
# | |
# def initialize(options={}) | |
# @stage_id = options[:stage_id] | |
# @state = options[:state] | |
# @done = options[:done] | |
# @nodes = options[:nodes] | |
# @total_splits = options[:total_splits] | |
# @queued_splits = options[:queued_splits] | |
# @running_splits = options[:running_splits] | |
# @completed_splits = options[:completed_splits] | |
# @user_time_millis = options[:user_time_millis] | |
# @cpu_time_millis = options[:cpu_time_millis] | |
# @wall_time_millis = options[:wall_time_millis] | |
# @processed_rows = options[:processed_rows] | |
# @processed_bytes = options[:processed_bytes] | |
# @sub_stages = options[:sub_stages] | |
# end | |
# | |
# def self.decode_hash(hash) | |
# new( | |
# stage_id: hash["stageId"], | |
# state: hash["state"], | |
# done: hash["done"], | |
# nodes: hash["nodes"], | |
# total_splits: hash["totalSplits"], | |
# queued_splits: hash["queuedSplits"], | |
# running_splits: hash["runningSplits"], | |
# completed_splits: hash["completedSplits"], | |
# user_time_millis: hash["userTimeMillis"], | |
# cpu_time_millis: hash["cpuTimeMillis"], | |
# wall_time_millis: hash["wallTimeMillis"], | |
# processed_rows: hash["processedRows"], | |
# processed_bytes: hash["processedBytes"], | |
# sub_stages: hash["subStages"].map {|h| StageStats.decode_hash(h) }, | |
# ) | |
# end | |
#end | |
class StatementStats | |
attr_reader :state | |
attr_reader :scheduled | |
attr_reader :nodes | |
attr_reader :total_splits | |
attr_reader :queued_splits | |
attr_reader :running_splits | |
attr_reader :completed_splits | |
attr_reader :user_time_millis | |
attr_reader :cpu_time_millis | |
attr_reader :wall_time_millis | |
attr_reader :processed_rows | |
attr_reader :processed_bytes | |
#attr_reader :root_stage | |
def initialize(options={}) | |
@state = state | |
@scheduled = scheduled | |
@nodes = nodes | |
@total_splits = total_splits | |
@queued_splits = queued_splits | |
@running_splits = running_splits | |
@completed_splits = completed_splits | |
@user_time_millis = user_time_millis | |
@cpu_time_millis = cpu_time_millis | |
@wall_time_millis = wall_time_millis | |
@processed_rows = processed_rows | |
@processed_bytes = processed_bytes | |
#@root_stage = root_stage | |
end | |
def self.decode_hash(hash) | |
new( | |
state: hash["state"], | |
scheduled: hash["scheduled"], | |
nodes: hash["nodes"], | |
total_splits: hash["totalSplits"], | |
queued_splits: hash["queuedSplits"], | |
running_splits: hash["runningSplits"], | |
completed_splits: hash["completedSplits"], | |
user_time_millis: hash["userTimeMillis"], | |
cpu_time_millis: hash["cpuTimeMillis"], | |
wall_time_millis: hash["wallTimeMillis"], | |
processed_rows: hash["processedRows"], | |
processed_bytes: hash["processedBytes"], | |
#root_stage: StageStats.decode_hash(hash["rootStage"]), | |
) | |
end | |
end | |
class Column | |
attr_reader :name | |
attr_reader :type | |
def initialize(options={}) | |
@name = options[:name] | |
@type = options[:type] | |
end | |
def self.decode_hash(hash) | |
new( | |
name: hash["name"], | |
type: hash["type"], | |
) | |
end | |
end | |
class QueryResults | |
attr_reader :id | |
attr_reader :info_uri | |
attr_reader :partial_cache_uri | |
attr_reader :next_uri | |
attr_reader :columns | |
attr_reader :data | |
attr_reader :stats | |
attr_reader :error | |
def initialize(options={}) | |
@id = options[:id] | |
@info_uri = options[:info_uri] | |
@partial_cache_uri = options[:partial_cache_uri] | |
@next_uri = options[:next_uri] | |
@columns = options[:columns] | |
@data = options[:data] | |
@stats = options[:stats] | |
@error = options[:error] | |
end | |
def self.decode_hash(hash) | |
new( | |
id: hash["id"], | |
info_uri: hash["infoUri"], | |
partial_cache_uri: hash["partialCancelUri"], | |
next_uri: hash["nextUri"], | |
columns: hash["columns"] ? hash["columns"].map {|h| Column.decode_hash(h) } : nil, | |
data: hash["data"] | |
stats: StatementStats.decode_hash(hash["stats"]), | |
error: hash["error"], # TODO | |
) | |
end | |
end | |
module PrestoHeaders | |
PRESTO_USER = "X-Presto-User" | |
PRESTO_SOURCE = "X-Presto-Source" | |
PRESTO_CATALOG = "X-Presto-Catalog" | |
PRESTO_SCHEMA = "X-Presto-Schema" | |
PRESTO_CURRENT_STATE = "X-Presto-Current-State" | |
PRESTO_MAX_WAIT = "X-Presto-Max-Wait" | |
PRESTO_MAX_SIZE = "X-Presto-Max-Size" | |
PRESTO_PAGE_SEQUENCE_ID = "X-Presto-Page-Sequence-Id" | |
end | |
class StatementClient | |
HEADERS = { | |
"User-Agent" => "presto-ruby/#{VERSION}" | |
} | |
def initialize(faraday, session, query) | |
@faraday = faraday | |
@faraday.headers.merge!(HEADERS) | |
@session = session | |
@query = query | |
@closed = false | |
@exception = nil | |
post_query_request! | |
end | |
def post_query_request! | |
response = @faraday.post do |req| | |
req.url "/v1/statement" | |
if v = @session.user | |
req.headers[PrestoHeaders::PRESTO_USER] = v | |
end | |
if v = @session.source | |
req.headers[PrestoHeaders::PRESTO_SOURCE] = v | |
end | |
if catalog = @session.catalog | |
req.headers[PrestoHeaders::PRESTO_CATALOG] = catalog | |
end | |
if v = @session.schema | |
req.headers[PrestoHeaders::PRESTO_SCHEMA] = v | |
end | |
req.body = @query | |
end | |
# TODO error handling | |
if response.status != 200 | |
raise "Failed to start query: #{response.body}" # TODO error class | |
end | |
body = response.body | |
hash = JSON.parse(body) | |
@results = QueryResults.decode_hash(hash) | |
end | |
private :post_query_request! | |
attr_reader :query | |
def debug? | |
@session.debug? | |
end | |
def closed? | |
@closed | |
end | |
attr_reader :exception | |
def exception? | |
@exception | |
end | |
def query_failed? | |
@results.error != nil | |
end | |
def query_succeeded? | |
@results.error == nil && !@exception && !@closed | |
end | |
def current_results | |
@results | |
end | |
def has_next? | |
!!@results.next_uri | |
end | |
def advance | |
if closed? || !has_next? | |
return false | |
end | |
uri = @results.next_uri | |
start = Time.now | |
attempts = 0 | |
begin | |
begin | |
response = @faraday.get do |req| | |
req.url uri | |
end | |
rescue => e | |
@exception = e | |
raise @exception | |
end | |
if response.status == 200 && !response.body.to_s.empty? | |
@results = QueryResults.decode_hash(JSON.parse(response.body)) | |
return true | |
end | |
if response.status != 503 # retry on Service Unavailable | |
# deterministic error | |
@exception = StandardError.new("Error fetching next at #{uri} returned #{response.status}: #{response.body}") # TODO error class | |
raise @exception | |
end | |
attempts += 1 | |
sleep attempts * 0.1 | |
end while (Time.now - start) < 2*60*60 && !@closed | |
@exception = StandardError.new("Error fetching next") # TODO error class | |
raise @exception | |
end | |
def close | |
return if @closed | |
# cancel running statement | |
if uri = @results.next_uri | |
# TODO error handling | |
# TODO make async reqeust and ignore response | |
@faraday.delete do |req| | |
req.url uri | |
end | |
end | |
@closed = true | |
nil | |
end | |
end | |
class Query | |
def self.start(session, query) | |
faraday = Faraday.new(url: "http://#{session.server}") do |faraday| | |
faraday.request :url_encoded | |
faraday.response :logger | |
faraday.adapter Faraday.default_adapter | |
end | |
new StatementClient.new(faraday, session, query) | |
end | |
def initialize(client) | |
@client = client | |
end | |
def wait_for_data | |
while @client.has_next? && @client.current_results.data == nil | |
@client.advance | |
end | |
end | |
private :wait_for_data | |
def columns | |
wait_for_data | |
raise_error unless @client.query_succeeded? | |
return @client.current_results.columns | |
end | |
def each_row(&block) | |
wait_for_data | |
raise_error unless @client.query_succeeded? | |
if self.columns == nil | |
raise "Query #{@client.current_results.id} has no columns" | |
end | |
begin | |
if data = @client.current_results.data | |
data.each(&block) | |
end | |
@client.advance | |
end while @client.has_next? | |
end | |
def raise_error | |
if @client.closed? | |
raise "Query aborted by user" | |
elsif @client.exception? | |
raise "Query is gone: #{@client.exception}" | |
elsif @client.query_failed? | |
results = @client.current_results | |
# TODO error location | |
raise "Query #{results.id} failed: #{results.error}" | |
end | |
end | |
private :raise_error | |
end | |
class Client | |
def initialize(options) | |
@session = ClientSession.new(options) | |
end | |
def query(query) | |
Query.start(@session, query) | |
end | |
end | |
end | |
require 'pp'

# Demo: run one query against a local server and print what comes back.
presto = PrestoClient::Client.new(
  server: "localhost:8880",
  user: "frsyuki",
  catalog: "native",
  schema: "default",
  debug: true
)
query = presto.query("select * from sys.query")
p query.columns
query.each_row do |row|
  p row
end
dict.get の第二引数のデフォルト値は None なので (キーが無いときに例外を投げたい場合は dic[key] を使います)、
dic.get('key', None) は dic.get('key') でいいです。
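StatementStats.decode_dict(dic) は StatementStats(**dic) で置き換えられそうに見えますが、JSON のキーは totalSplits のような camelCase で、コンストラクタの引数 (total_splits) と名前が違います。そのため、そのまま ** で渡すと TypeError になります。キー名の対応表を用意して変換すれば、decode_dict をずっと短く書けます。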
他にもいくつか不要なものがあります。
https://gist.github.com/frsyuki/8279517#file-presto-client-py-L157
raise e
をすると、その場所から新しい例外を投げることになる(例外のコンテキストが2つになる)ので、
あえてその場所で例外を発生させたい場合を除き、 raise
とだけ書きます。
https://gist.github.com/frsyuki/8279517#file-presto-client-py-L81
columns=map((lambda d: Column.decode_dict(d)), dic["columns"]) if dic.has_key("columns") else None,
lambda を使わなくても、Column.decode_dict 自体が関数オブジェクトなので
columns=map(Column.decode_dict, dic["columns"]) if dic.has_key("columns") else None,
has_key は非推奨 (Python 3 では削除済み) なので、代わりに in を使うと
columns=map(Column.decode_dict, dic["columns"]) if "columns" in dic else None,
Column.decode_dict も削除できるので、代わりにリスト内包表記にして (ただし、サーバーが name と type 以外のキーも返すようになると、Column(**c) は TypeError になる点に注意)
columns=[Column(**c) for c in dic["columns"]] if "columns" in dic else None,
Column は named tuple にできるかもしれません。
In [1]: from collections import namedtuple
In [2]: Column = namedtuple("Column", "name type")
In [3]: Column(name="column name", type="integer")
Out[3]: Column(name='column name', type='integer')
raise Exception, "Query "+str(self.client.results.id)+" has no columns"
この raise 構文は Python 3 で削除されました。例外に引数を与えたい場合は普通にインスタンスを作って投げます
raise Exception("Query "+str(self.client.results.id)+" has no columns")
文字列フォーマットは % を使うと便利です。
raise Exception("Query %r has no columns" % self.client.results.id)
%s ではなく %r なのは、整数と文字列を見分けられるからです。
StatementClient の is_query_failed is_query_succeeded has_next は、副作用も引数もないので、
前の行に @property
をつけてプロパティにすると良いと思います。(その場合は呼び出し側も client.has_next() ではなく client.has_next のように () を外す必要があります。)
print "columns: "+str(q.columns())
これは
print "columns:", q.columns()
でいいです (カンマ区切りのオブジェクトを勝手に str() してスペース区切りで表示してくれる)
Python 3 対応をにらむなら、全ての import 文の前に
from __future__ import print_function
しておいて、
print("columns:", q.columns())
になります。
return self.results.error is None and self.exception is None and self.closed is False
真偽値との比較には、特別な理由がない限り is (is False など) を使いません。
return self.results.error is None and self.exception is None and not self.closed
あとは好き嫌いがありますが、数値型やシーケンス型以外は基本真偽値としては真になるので、
return not self.results.error and not self.exception and not self.closed
と書けて、さらに
return not (self.results.error or self.exception or self.closed)
とまとめられますね。
QueryResultIterator は Python の iterator を実装しているだけなので、ジェネレータを使えばもっと簡単に書けます。
具体的には、QueryResultIterator クラスを削除し、Query.results メソッドを直接ジェネレータにしてしまいます。(ただしジェネレータにすると、下のエラーチェックは results() を呼んだ時点ではなく、最初の要素を取り出した時点で実行されるようになります。)
def results(self):
    client = self.client
    if not client.is_query_succeeded():
        self._raise_error()
    if self.columns() is None:
        raise Exception("Query " + str(client.results.id) + " has no columns")
    while True:
        # 途中のページでは data が None (行なし) のこともあるので、空として扱う
        for r in client.results.data or []:
            yield r
        if not client.has_next():
            break
        client.advance()
Python 2 を使う場合は、 class Hoge(object): のように object を継承しましょう。
そうしないと、old-style class という 10 年以上前の古い形式のクラスになってしまいます。
(Python 3 では old-style class は撤廃され、省略可能です)