# Source code for cluster.preprocess.pre_node_feed_text2seq
from cluster.preprocess.pre_node_feed import PreNodeFeed
import os,h5py
class PreNodeFeedText2Seq(PreNodeFeed):
    """
    Feeder that walks paired HDF5 files for a text2seq node.

    ``self.input_paths`` is used here as a pair of lists: index 0 holds the
    encode-side file paths and index 1 the decode-side file paths.
    ``self.pointer`` selects the file pair currently being read; it is not
    initialised in this class (presumably by the base class or
    ``_init_node_parm`` -- confirm).
    """

    def run(self, conf_data):
        """
        Run the base feeder setup, then record how many file pairs exist.

        :param conf_data: node configuration; must contain the key ``'node_id'``
        :return: None
        """
        super(PreNodeFeedText2Seq, self).run(conf_data)
        # NOTE(review): the longer list sets the size, so __getitem__ raises
        # IndexError on the shorter side if the two lists differ in length
        # -- confirm the encode/decode files are always produced in pairs.
        self.file_list_size = max(len(self.input_paths[0]),
                                  len(self.input_paths[1]))
        self._init_node_parm(conf_data['node_id'])

    def has_next(self):
        """
        Check whether the file pointer still refers to an available file pair.

        :return: True while ``self.pointer`` is below ``self.file_list_size``
        """
        return self.file_list_size > self.pointer

    def next(self):
        """
        Move the file pointer forward by one, unless it is already at the end.

        :return: None
        """
        if self.has_next():
            self.pointer += 1

    def len(self):
        """
        Return the number of file pairs recorded by ``run``.

        :return: ``self.file_list_size``
        """
        return self.file_list_size

    def __getitem__(self, key):
        """
        Read the same row range from the current encode and decode files.

        :param key: object with ``start`` and ``stop`` attributes (e.g. a slice)
        :return: tuple ``(encode, decode)`` of the two chunks read
        """
        encode = self._convert_data_format(self.input_paths[0][self.pointer], key)
        decode = self._convert_data_format(self.input_paths[1][self.pointer], key)
        return encode, decode

    def _convert_data_format(self, file_path, index):
        """
        Read one chunk of the ``'rawdata'`` dataset from an HDF5 file.

        :param file_path: path of the HDF5 file, opened read-only
        :param index: object with ``start`` and ``stop`` attributes (e.g. a
                      slice); its ``step`` is not used
        :return: ``rawdata[index.start:index.stop]``
        :raises Exception: errors from opening or reading the file propagate
                           with their original type
        """
        # The context manager closes the file on every path, including when
        # opening fails; the previous try/finally referenced an unbound
        # ``h5file`` in that case and masked the real error.
        with h5py.File(file_path, mode='r') as h5file:
            return h5file['rawdata'][index.start:index.stop]

    def data_size(self):
        """
        Return the length of the ``'rawdata'`` dataset in the current file.

        :return: result of ``len()`` on the ``'rawdata'`` dataset
        :raises Exception: errors from opening or reading the file propagate
                           with their original type
        """
        # input_paths is a pair of path lists (see __getitem__), so the
        # current file has to be looked up inside one of them rather than
        # by indexing input_paths with the pointer directly.
        # NOTE(review): the encode-side file is used; this assumes the
        # decode-side file has the same number of rows -- confirm.
        with h5py.File(self.input_paths[0][self.pointer], mode='r') as h5file:
            return h5file['rawdata'].len()