@gdementen
Created July 3, 2014 13:42
Default values patch for LIAM2, by Alexis Eidelman
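The patch below adds support for per-field default values in LIAM2: a field definition can carry a 'default' entry, and that value is used to fill the column whenever the field is absent from the input data, instead of the generic missing value (-1 for int, nan for float, False for bool). As an illustration only (field names and values invented), here is the kind of parsed field definition that Entity.from_def would accept once the patch is applied:

    # each field is a (name, definition) pair; the definition is either a
    # plain type string or a dict which may now include a 'default' entry
    fields_def = [
        ('age', {'type': 'int', 'default': 0}),   # missing rows filled with 0
        ('partner_id', 'int'),                    # falls back to -1 (missing value)
        ('is_worker', {'type': 'bool', 'default': False}),
    ]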
diff --git a/src_liam/data.py b/src_liam/data.py
index 31fe79b..44f11b6 100644
--- a/src_liam/data.py
+++ b/src_liam/data.py
@@ -35,7 +35,7 @@ def append_carray_to_table(array, table, numlines=None, buffersize=10 * MB):
class ColumnArray(object):
- def __init__(self, array=None):
+ def __init__(self, array=None, default_values=None):
columns = {}
if array is not None:
if isinstance(array, (np.ndarray, ColumnArray)):
@@ -43,18 +43,25 @@ class ColumnArray(object):
columns[name] = array[name].copy()
self.dtype = array.dtype
self.columns = columns
+ if isinstance(array, ColumnArray):
+ self.dval = array.dval
+ else:
+ self.dval = {}
elif isinstance(array, list):
for name, column in array:
columns[name] = column
self.dtype = np.dtype([(name, column.dtype)
for name, column in array])
self.columns = columns
+ self.dval = {}
else:
#TODO: make a property instead?
self.dtype = None
self.columns = columns
+ self.dval = {}
else:
self.dtype = None
+ self.dval = {}
self.columns = columns
def __getitem__(self, key):
@@ -163,7 +170,7 @@ class ColumnArray(object):
return ca
@classmethod
- def from_table(cls, table, start=0, stop=None, buffersize=10 * 2 ** 20):
+ def from_table(cls, table, start=0, stop=None, default_values={}, buffersize=10 * 2 ** 20):
# reading a table one column at a time is very slow, this is why this
# function is even necessary
if stop is None:
@@ -172,6 +179,7 @@ class ColumnArray(object):
max_buffer_rows = buffersize // dtype.itemsize
numlines = stop - start
ca = cls.empty(numlines, dtype)
+ ca.dval = default_values
buffer_rows = min(numlines, max_buffer_rows)
# chunk = np.empty(buffer_rows, dtype=dtype)
array_start = 0
@@ -220,10 +228,15 @@ class ColumnArray(object):
output_dtype = np.dtype(output_fields)
output_names = set(output_dtype.names)
input_names = set(self.dtype.names)
+ default_values = self.dval
length = len(self)
# add missing fields
for name in output_names - input_names:
- self[name] = get_missing_vector(length, output_dtype[name])
+ if name in default_values:
+ self[name] = np.empty(length, dtype=output_dtype[name])
+ self[name].fill(default_values[name])
+ else:
+ self[name] = get_missing_vector(length, output_dtype[name])
# delete extra fields
for name in input_names - output_names:
del self[name]
@@ -274,7 +287,7 @@ def assertValidType(array, wanted_type, allowed_missing=None, context=None):
wanted_type.__name__))
-def add_and_drop_fields(array, output_fields, missing_fields={}, output_array=None):
+def add_and_drop_fields(array, output_fields, default_values={}, output_array=None):
output_dtype = np.dtype(output_fields)
output_names = set(output_dtype.names)
input_names = set(array.dtype.names)
@@ -283,8 +296,8 @@ def add_and_drop_fields(array, output_fields, missing_fields={}, output_array=No
if output_array is None:
output_array = np.empty(len(array), dtype=output_dtype)
for fname in all_missing_fields:
- if fname in missing_fields.keys():
- output_array[fname] = missing_fields[fname]
+ if fname in default_values:
+ output_array[fname] = default_values[fname]
else:
output_array[fname] = get_missing_value(output_array[fname])
else:
@@ -415,7 +428,7 @@ def appendTable(input_table, output_table, chunksize=10000, condition=None,
num_chunks += 1
if output_fields is not None:
- expanded_data = np.empty(chunksize, dtype=np.dtype(output_fields))
+ expanded_data = ColumnArray.empty(chunksize, dtype=np.dtype(output_fields))
expanded_data[:] = get_missing_record(expanded_data)
def copyChunk(chunk_idx, chunk_num):
@@ -430,11 +443,11 @@ def appendTable(input_table, output_table, chunksize=10000, condition=None,
if output_fields is not None:
# use our pre-allocated buffer (except for the last chunk)
if len(input_data) == len(expanded_data):
- missing_fields = {}
+ default_values = {}
output_data = add_and_drop_fields(input_data, output_fields,
- missing_fields, expanded_data)
+ default_values, expanded_data)
else:
- output_data = add_and_drop_fields(input_data, output_fields) #, missing_fields
+ output_data = add_and_drop_fields(input_data, output_fields)
else:
output_data = input_data
@@ -472,7 +485,7 @@ def copyTable(input_table, output_node, output_fields=None,
# 1) all arrays have the same columns
# 2) we have id_to_rownum already computed for each array
def buildArrayForPeriod(input_table, output_fields, input_rows,
- input_index, start_period, missing_fields={}):
+ input_index, start_period, default_values={}):
periods_before = [p for p in input_rows.iterkeys() if p <= start_period]
if not periods_before:
id_to_rownum = np.empty(0, dtype=int)
@@ -495,7 +508,7 @@ def buildArrayForPeriod(input_table, output_fields, input_rows,
# if all individuals are present in the target period, we are done already!
if np.array_equal(present_in_period, is_present):
start, stop = input_rows[target_period]
- input_array = ColumnArray.from_table(input_table, start, stop)
+ input_array = ColumnArray.from_table(input_table, start, stop, default_values)
input_array.add_and_drop_fields(output_fields)
return input_array, period_id_to_rownum
@@ -807,9 +820,8 @@ class H5Data(DataSource):
# would be brought back to life. In conclusion, it should be
# optional.
entity.array, entity.id_to_rownum = \
- buildArrayForPeriod(table.table, entity.fields,
- entity.input_rows,
- entity.input_index, start_period)
+ buildArrayForPeriod(table.table, entity.fields, entity.input_rows,
+ entity.input_index, start_period, entity.default_values)
assert isinstance(entity.array, ColumnArray)
entity.array_period = start_period
print("done (%s elapsed)." % time2str(time.time() - start_time))
diff --git a/src_liam/entities.py b/src_liam/entities.py
index b705b0d..62b40d5 100644
--- a/src_liam/entities.py
+++ b/src_liam/entities.py
@@ -10,7 +10,7 @@ import config
from context import EntityContext, context_length
from data import mergeArrays, get_fields, ColumnArray
from expr import (Variable, GlobalVariable, GlobalTable, GlobalArray,
- expr_eval, get_missing_value)
+ expr_eval, missing_values, get_missing_value)
from exprparser import parse
from process import Assignment, Compute, Process, ProcessGroup
from registry import entity_registry
@@ -32,7 +32,7 @@ class Entity(object):
'''
fields is a list of tuple (name, type, options)
'''
- def __init__(self, name, fields, missing_fields=None, links=None,
+ def __init__(self, name, fields, missing_fields=None, default_values={}, links=None,
macro_strings=None, process_strings=None,
on_align_overflow='carry'):
self.name = name
@@ -64,6 +64,7 @@ class Entity(object):
# another solution is to use a Field class
# seems like the better long term solution
self.missing_fields = missing_fields
+ self.default_values = default_values
self.period_individual_fnames = [name for name, _ in fields]
self.links = links
@@ -114,15 +115,29 @@ class Entity(object):
fields = []
missing_fields = []
+ default_values = {}
for name, fielddef in fields_def:
if isinstance(fielddef, dict):
strtype = fielddef['type']
if not fielddef.get('initialdata', True):
missing_fields.append(name)
+
+ fieldtype = field_str_to_type(strtype, "field '%s'" % name)
+ default = fielddef.get('default', missing_values[fieldtype])
+ if type(default) != fieldtype:
+ raise Exception("the default value given to field '%s' is %s"
+ " but %s was expected" % (name, type(default), strtype))
+
else:
strtype = fielddef
- fields.append((name,
- field_str_to_type(strtype, "field '%s'" % name)))
+ fieldtype = field_str_to_type(strtype, "field '%s'" % name)
+ default = missing_values[fieldtype]
+
+ fields.append((name, fieldtype))
+ default_values[name] = default
+
link_defs = entity_def.get('links', {})
str2class = {'one2many': One2Many, 'many2one': Many2One}
@@ -131,13 +146,13 @@ class Entity(object):
for name, l in link_defs.iteritems())
#TODO: add option for on_align_overflow
- return Entity(ent_name, fields, missing_fields, links,
+ return Entity(ent_name, fields, missing_fields, default_values, links,
entity_def.get('macros', {}),
entity_def.get('processes', {}))
@classmethod
def from_table(cls, table):
- return Entity(table.name, get_fields(table), missing_fields=[],
+ return Entity(table.name, get_fields(table), missing_fields=[], default_values={},
links={}, macro_strings={}, process_strings={})
@staticmethod
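In entities.py, Entity.from_def now extracts the 'default' entry and rejects defaults whose Python type does not match the declared field type. A self-contained sketch of that validation step, with field_str_to_type and the missing_values table reduced to small stand-ins:

    str_to_type = {'bool': bool, 'int': int, 'float': float}
    missing_values = {bool: False, int: -1, float: float('nan')}

    def parse_field(name, fielddef):
        # mirrors the patched branch of Entity.from_def
        if isinstance(fielddef, dict):
            strtype = fielddef['type']
            fieldtype = str_to_type[strtype]
            default = fielddef.get('default', missing_values[fieldtype])
            if type(default) != fieldtype:
                raise Exception("the default value given to field '%s' is %s"
                                " but %s was expected"
                                % (name, type(default), strtype))
        else:
            fieldtype = str_to_type[fielddef]
            default = missing_values[fieldtype]
        return (name, fieldtype), default

    parse_field('age', {'type': 'int', 'default': 0})    # (('age', int), 0)
    parse_field('age', {'type': 'int', 'default': 0.5})  # raises Exception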
diff --git a/src_liam/expr.py b/src_liam/expr.py
index 9ec2fcf..816f398 100644
--- a/src_liam/expr.py
+++ b/src_liam/expr.py
@@ -73,7 +73,10 @@ def get_missing_vector(num, dtype):
def get_missing_record(array):
row = np.empty(1, dtype=array.dtype)
for fname in array.dtype.names:
- row[fname] = get_missing_value(row[fname])
+ if fname in array.dval:
+ row[fname] = array.dval[fname]
+ else:
+ row[fname] = get_missing_value(row[fname])
return row
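Note that the patched get_missing_record reads array.dval, which only exists on ColumnArray; this is why the appendTable hunk above switches the pre-allocated expanded_data buffer from np.empty to ColumnArray.empty. A sketch of the resulting behaviour, using a minimal stand-in object that exposes just the two attributes the function touches (dtype and dval), with -1 standing in for get_missing_value:

    import numpy as np

    class FakeColumnArray(object):
        def __init__(self, dtype, dval):
            self.dtype = dtype
            self.dval = dval   # dict mapping field name -> default value

    def get_missing_record(array):
        row = np.empty(1, dtype=array.dtype)
        for fname in array.dtype.names:
            # prefer the declared default; -1 stands in for get_missing_value
            row[fname] = array.dval[fname] if fname in array.dval else -1
        return row

    dtype = np.dtype([('age', int), ('partner_id', int)])
    get_missing_record(FakeColumnArray(dtype, {'age': 0}))   # -> [(0, -1)]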