# Source code for bndl_cassandra.dataset

from functools import partial
import difflib
import functools
import logging

from cassandra import protocol
from cassandra.protocol import ErrorMessage
from cassandra.query import tuple_factory, named_tuple_factory, dict_factory

from bndl.compute.dataset import Dataset, Partition, NODE_LOCAL
from bndl.util import funcs
from bndl.util.callsite import callsite
from bndl.util.retry import do_with_retry
from bndl_cassandra import partitioner
from bndl_cassandra.coscan import CassandraCoScanDataset
from bndl_cassandra.session import TRANSIENT_ERRORS


logger = logging.getLogger(__name__)


def _did_you_mean(msg, word, possibilities):
    matches = difflib.get_close_matches(word, possibilities, n=2)
    if matches:
        msg += ', did you mean ' + ' or '.join(matches) + '?'
    return msg
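

# For illustration (hypothetical names), with difflib's default cutoff:
#
#   >>> _did_you_mean('Table events not found', 'events', ['event', 'users'])
#   'Table events not found, did you mean event?'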


def get_table_meta(session, keyspace, table):
    try:
        keyspace_meta = session.cluster.metadata.keyspaces[keyspace]
    except KeyError as e:
        msg = 'Keyspace %s not found' % (keyspace,)
        msg = _did_you_mean(msg, keyspace, session.cluster.metadata.keyspaces.keys())
        raise KeyError(msg) from e
    try:
        return keyspace_meta.tables[table]
    except KeyError as e:
        msg = 'Table %s.%s not found' % (keyspace, table)
        msg = _did_you_mean(msg, table, keyspace_meta.tables.keys())
        raise KeyError(msg) from e
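

# Example of the enriched KeyError (sketch; assumes a session whose cluster
# metadata contains a keyspace 'events' but not 'event'):
#
#   >>> get_table_meta(session, 'event', 'visits')
#   Traceback (most recent call last):
#       ...
#   KeyError: 'Keyspace event not found, did you mean events?'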


class _CassandraDataset(Dataset):
    def __init__(self, ctx, keyspace, table, contact_points=None):
        super().__init__(ctx)
        self.keyspace = keyspace
        self.table = table
        self.contact_points = contact_points
        self._row_factory = named_tuple_factory
        self._protocol_handler = None
        self._select = None
        self._limit = None

    def _session(self):
        client_protocol_handler = (getattr(protocol, self._protocol_handler)
                                   if self._protocol_handler else None)
        return self.ctx.cassandra_session(contact_points=self.contact_points,
                                          row_factory=self._row_factory,
                                          client_protocol_handler=client_protocol_handler)

    @property
    @functools.lru_cache(1)
    def meta(self):
        with self._session() as session:
            return get_table_meta(session, self.keyspace, self.table)

    def as_tuples(self):
        return self._with(_row_factory=tuple_factory)

    def as_dicts(self):
        return self._with(_row_factory=dict_factory)

    def select(self, *columns):
        return self._with(_select=columns)

    def limit(self, num):
        return self._with(_limit=int(num))

    @property
    def query(self):
        select = ', '.join(self._select) if self._select else '*'
        limit = ' limit %s' % self._limit if self._limit else ''
        query = ('select {select} '
                 'from {keyspace}.{table} '
                 'where {where}{limit} '
                 'allow filtering')
        query = query.format(
            select=select,
            keyspace=self.keyspace,
            table=self.table,
            where=self._where,
            limit=limit
        )
        return query


def _arrays_to_df(pk_cols_selected, arrays):
    '''Convert a dict of numpy arrays (a query page) to a pandas DataFrame.'''
    import pandas as pd
    from bndl.compute.dataframes import DataFrame
    if len(pk_cols_selected) == 0:
        index = None
    elif len(pk_cols_selected) == 1:
        name = pk_cols_selected[0]
        index = pd.Index(arrays.pop(name), name=name)
    else:
        index = [arrays.pop(name) for name in pk_cols_selected]
        index = pd.MultiIndex.from_arrays(index, names=pk_cols_selected)
    return DataFrame(arrays, index)
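

# Conceptual sketch of _arrays_to_df on a single query page, in plain
# pandas (made-up column names and data):
#
#   >>> import numpy as np, pandas as pd
#   >>> page = {'id': np.array(['a', 'b']), 'comments': np.array([3, 4])}
#   >>> index = pd.Index(page.pop('id'), name='id')   # selected pk column
#   >>> pd.DataFrame(page, index=index)
#       comments
#   id
#   a          3
#   b          4
#
# The actual implementation wraps the arrays in bndl's own DataFrame type
# rather than pandas.DataFrame.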


class CassandraScanDataset(_CassandraDataset):
    def __init__(self, ctx, keyspace, table, contact_points=None):
        '''
        Create a scan across keyspace.table.

        :param ctx:
            The compute context.
        :param keyspace: str
            Keyspace of the table to scan.
        :param table: str
            Name of the table to scan.
        :param contact_points: None or str or [str, str, str, ...]
            None to use the default contact points, a list of contact points
            or a comma separated string of contact points.
        '''
        super().__init__(ctx, keyspace, table, contact_points)
        self._where_tokens = ('token({partition_key_column_names}) > ? and '
                              'token({partition_key_column_names}) <= ?').format(
            partition_key_column_names=', '.join(c.name for c in self.meta.partition_key)
        )
        self._where_clause = ''
        self._where_values = ()

    @property
    def _where(self):
        return ' and '.join(filter(None, (self._where_tokens, self._where_clause)))

    def where(self, clause, *values):
        return self._with(_where_clause=clause, _where_values=values)
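
    # Sketch of how select/where/limit compose into CQL (hypothetical table
    # ks.tbl with partition key 'id'; the first two '?' are the token bounds
    # bound per token range at scan time, further '?' come from where()):
    #
    #   >>> scan = ctx.cassandra_table('ks', 'tbl')
    #   >>> scan.select('id', 'comments').where('comments > ?', 10).limit(5).query
    #   'select id, comments from ks.tbl where token(id) > ? and
    #    token(id) <= ? and comments > ? limit 5 allow filtering'
    #
    # (the result is a single string; wrapped here for readability)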

    @callsite()
    def count(self, push_down=None):
        if push_down is True or (not self.cached and push_down is None):
            return self.select('count(*)').as_tuples().map(funcs.getter(0)).sum()
        else:
            return super().count()
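
    # Usage sketch: count() is pushed down as 'select count(*)' per token
    # range unless the dataset is cached (or push_down is False); assuming
    # the usual bndl Dataset.cache():
    #
    #   >>> ctx.cassandra_table('ks', 'tbl').count()          # push-down
    #   >>> ctx.cassandra_table('ks', 'tbl').cache().count()  # from cache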

    def as_dataframe(self):
        '''
        Create a bndl.compute.dataframes.DistributedDataFrame out of a
        Cassandra table scan. When primary key fields are selected, they are
        used to compose a (multilevel) index.

        Example::

            >>> df = ctx.cassandra_table('ks', 'tbl').as_dataframe()
            >>> df.collect()
                                                 comments
            id          timestamp
            ZIJr6BDGCeo 2014-10-09 19:28:43.657         1
                        2015-01-12 20:24:49.947         4
                        2015-01-13 02:24:30.931        39
            kxcT9VOI0oU 2015-01-12 14:24:16.378         1
                        2015-01-12 20:24:49.947         5
                        2015-01-13 02:24:30.931         8
                        2015-02-04 10:29:58.118         4
            A_egyclRPOw 2015-12-16 13:50:53.210         1
                        2015-01-18 18:28:19.556         2
                        2015-01-22 22:28:33.358         4
                        2015-01-27 02:28:59.578         6
                        2015-01-31 06:29:07.937         7
        '''
        import pandas as pd
        from bndl.compute.dataframes import DistributedDataFrame

        if self._select:
            pk_cols_selected = [c.name for c in self.meta.primary_key
                                if c.name in self._select]
        else:
            pk_cols_selected = [c.name for c in self.meta.primary_key]

        to_df = partial(_arrays_to_df, pk_cols_selected)

        # determine index names
        index = pk_cols_selected or [None]

        # determine column names
        if self._select:
            columns = self._select
        else:
            columns = self.meta.columns
        columns = [c for c in columns if c not in pk_cols_selected]

        # the scan yields dicts with column names and numpy arrays per query page
        arrays = self._with(_row_factory=tuple_factory,
                            _protocol_handler='NumpyProtocolHandler')
        # convert each page to a DataFrame and concatenate per partition
        dfs = arrays.map(to_df).map_partitions(pd.concat)
        return DistributedDataFrame(dfs, index, columns)

    def span_by(self, *cols):
        '''
        Span by groups rows in a Cassandra table scan by a subset of the
        primary key.

        This is useful for tables with clustering columns: rows in a
        Cassandra table scan are returned clustered by partition key and
        sorted by clustering columns. This is exploited to efficiently
        (without shuffle) group rows by a part of the primary key.

        Example::

            >>> tbl = ctx.cassandra_table('ks', 'tbl')
            >>> tbl.span_by().collect()
            [('ZIJr6BDGCeo',
                                                   comments
              id          timestamp
              ZIJr6BDGCeo 2014-10-09 19:28:43.657         1
                          2015-01-12 20:24:49.947         4
                          2015-01-13 02:24:30.931         9),
             ('kxcT9VOI0oU',
                                                   comments
              id          timestamp
              kxcT9VOI0oU 2015-01-12 14:24:16.378         1
                          2015-01-12 20:24:49.947         2
                          2015-01-13 02:24:30.931         5
                          2015-02-04 10:29:58.118         8),
             ('A_egyclRPOw',
                                                   comments
              id          timestamp
              A_egyclRPOw 2015-12-16 13:50:53.210         1
                          2015-01-18 18:28:19.556         2
                          2015-01-22 22:28:33.358         4
                          2015-01-27 02:28:59.578         6
                          2015-01-31 06:29:07.937         7)]

        A Cassandra table scan spanned by part of the primary key consists
        of pandas.DataFrame objects, and thus allows for easy per-group
        analysis::

            >>> for key, count in tbl.span_by().map_values(lambda e: e.count()).collect():
            ...     print(key, ':', count)
            ...
            ZIJr6BDGCeo : comments    3
            dtype: int64
            kxcT9VOI0oU : comments    4
            dtype: int64
            A_egyclRPOw : comments    5
            dtype: int64
        '''
        from bndl.compute.dataframes import DataFrame

        pk_cols = [c.name for c in self.meta.primary_key]
        if not cols:
            cols = pk_cols[:-1]
        else:
            if not all(col in pk_cols for col in cols):
                raise ValueError('Can only span a cassandra table scan by '
                                 'columns from the primary key')

        if self._select:
            if len(self._select) < len(cols):
                raise ValueError('Span by on a Cassandra table requires at '
                                 'least selection of the partition key')
            elif not all(a == b for a, b in zip(self._select, cols)):
                raise ValueError('The columns to span by should have the same '
                                 'order as those selected')

        if not all(a == b for a, b in zip(pk_cols, cols)) or \
           len(cols) >= len(pk_cols):
            raise ValueError('The columns to span by should be a subset '
                             'of and have the same order as the primary key')

        levels = list(range(len(cols)))
        return self.as_dataframe().map_partitions(partial(DataFrame.groupby, level=levels))

    def itake(self, num):
        if not self.cached and not self._limit:
            return self.limit(num).itake(num)
        else:
            return super().itake(num)

    def coscan(*scans, keys=None):
        # note: no explicit self; the dataset this method is called on is
        # simply the first of the scans to co-scan
        return CassandraCoScanDataset(*scans, keys=keys)
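
    # Usage sketch for coscan (a co-scan joins tables per Cassandra
    # partition, without a shuffle; assumes hypothetical tables ks.users
    # and ks.visits sharing their partition key):
    #
    #   >>> users = ctx.cassandra_table('ks', 'users')
    #   >>> visits = ctx.cassandra_table('ks', 'visits')
    #   >>> users.coscan(visits).collect()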

    def parts(self):
        partitions = partitioner.partition_ranges(self.ctx, self.contact_points,
                                                  self.keyspace, self.table)
        return [
            CassandraScanPartition(self, i, *part)
            for i, part in enumerate(partitions)
        ]


class CassandraScanPartition(Partition):
    def __init__(self, dset, part_idx, replicas, token_ranges,
                 size_estimate_mb, size_estimate_keys):
        super().__init__(dset, part_idx)
        self.replicas = set(replicas)
        self.token_ranges = token_ranges
        self.size_estimate_mb = size_estimate_mb
        self.size_estimate_keys = size_estimate_keys

    def _fetch_token_range(self, session, token_range):
        try:
            query = session.prepare(self.dset.query)
        except ErrorMessage as exc:
            raise Exception('Unable to prepare query %s, error: %s'
                            % (self.dset.query, str(exc)))

        query.consistency_level = self.dset.ctx.conf.get('bndl_cassandra.read_consistency_level')

        if logger.isEnabledFor(logging.INFO):
            logger.info('executing query %s for token_range %s',
                        query.query_string.replace('\n', ''), token_range)

        timeout = self.dset.ctx.conf.get('bndl_cassandra.read_timeout')
        params = token_range + self.dset._where_values
        resultset = session.execute(query, params, timeout=timeout)

        # drain all pages; the next page is requested before the current
        # rows are collected so that fetching and processing overlap
        results = []
        while True:
            has_more = resultset.response_future.has_more_pages
            if has_more:
                resultset.response_future.start_fetching_next_page()
            results.extend(resultset.current_rows)
            if has_more:
                resultset = resultset.response_future.result()
            else:
                break
        return results

    def _compute(self):
        retry_count = max(0, self.dset.ctx.conf.get('bndl_cassandra.read_retry_count'))
        retry_backoff = self.dset.ctx.conf.get('bndl_cassandra.read_retry_backoff')

        with self.dset._session() as session:
            logger.debug('scanning %s token ranges', len(self.token_ranges))
            for token_range in self.token_ranges:
                yield from do_with_retry(
                    partial(self._fetch_token_range, session, token_range),
                    retry_count, retry_backoff, TRANSIENT_ERRORS)

    def _locality(self, workers):
        # a worker is node-local if it runs on one of the replicas of this
        # partition's token ranges
        return (
            (worker, NODE_LOCAL)
            for worker in workers
            if worker.ip_addresses() & self.replicas
        )
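

# The scan above is driven by these configuration keys (read through
# ctx.conf in _fetch_token_range and _compute); a tuning sketch, assuming
# bndl's ctx.conf supports item assignment (values are illustrative):
#
#   >>> ctx.conf['bndl_cassandra.read_consistency_level'] = 'LOCAL_QUORUM'
#   >>> ctx.conf['bndl_cassandra.read_timeout'] = 120
#   >>> ctx.conf['bndl_cassandra.read_retry_count'] = 3
#   >>> ctx.conf['bndl_cassandra.read_retry_backoff'] = 2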