from functools import partial, lru_cache
import logging

from bndl.compute.dataset import Dataset, Partition
from bndl.util.funcs import identity, getter
from bndl_cassandra import partitioner
from bndl_cassandra.partitioner import estimate_size, SizeEstimate
from bndl_cassandra.session import cassandra_session
from cytoolz.itertoolz import take

logger = logging.getLogger(__name__)

def get_or_none(index, container):
    """Return ``container[index]``, or ``None`` when that index is out of range.

    Only ``IndexError`` is translated into ``None``; any other exception
    raised by the lookup (e.g. ``KeyError`` for mappings) propagates.
    """
    try:
        item = container[index]
    except IndexError:
        item = None
    return item
class CassandraCoScanDataset(Dataset):
    """Dataset that scans several Cassandra tables side by side.

    All scanned tables must live in the same cluster (equal contact points)
    and the same keyspace, and no table may be scanned twice. Rows of the
    individual tables are keyed by a leading part of their primary key; the
    per-table key extractors are kept in ``keyfuncs`` and the per-table
    post-processing of each group of rows in ``grouptransforms``.
    """

    def __init__(self, *scans, keys=None, dset_id=None):
        """Set up a co-scan over two or more table scans.

        Args:
            *scans: At least two scans. The first must be a scan object (its
                ``ctx``, ``contact_points`` and ``keyspace`` are used as the
                reference). Every following entry is either a scan object or
                a string naming a table, as ``'table'`` (looked up in the
                keyspace of the first scan) or ``'keyspace.table'``; strings
                are resolved through ``ctx.cassandra_table``.
            keys: ``None`` (or empty) to key every table by its full primary
                key, a single column name to use for every table, or a
                sequence holding one key per scan, each key being a column
                name or a sequence of column names.
            dset_id: Passed through to ``Dataset.__init__``.

        Raises:
            AssertionError: If fewer than two scans are given, the scans span
                different clusters or keyspaces, a table is scanned twice, or
                the keys do not line up with the tables' primary keys.
        """
        assert len(scans) > 1
        scans = list(scans)
        scan0 = scans[0]
        # NOTE(review): table names in ``scans`` are replaced in place below,
        # after this list has been handed to the base class as ``src`` --
        # presumably Dataset keeps a reference to the very same list; confirm.
        super().__init__(scan0.ctx, src=scans, dset_id=dset_id)

        for idx, scan in enumerate(scans[1:], 1):
            if isinstance(scan, str):
                if '.' in scan:
                    keyspace, table = scan.split('.', 1)
                else:
                    keyspace, table = scan0.keyspace, scan
                scan = self.ctx.cassandra_table(keyspace, table)
                scans[idx] = scan

            assert scan.contact_points == scan0.contact_points, \
                "only scan in parallel within the same cluster"
            assert scan.keyspace == scan0.keyspace, \
                "only scan in parallel within the same keyspace"

        # Checked only once every table name has been resolved into a scan;
        # unresolved strings have no ``table`` attribute.
        assert len({scan.table for scan in scans}) == len(scans), \
            "don't scan the same table twice"

        self.contact_points = scan0.contact_points
        self.keyspace = scan0.keyspace

        # NOTE(review): the element expression was missing in the reviewed
        # text; ``src.parts()`` is inferred from ``pcount`` taking the length
        # of the first entry -- confirm against the upstream source.
        self.srcparts = [src.parts() for src in scans]
        self.pcount = len(self.srcparts[0])

        # TODO check format (dicts, tuples, namedtuples, etc.)
        # TODO adapt keyfuncs to below

        # A single column name applies to every scanned table.
        if isinstance(keys, str):
            self.keys = keys = [keys] * len(scans)
        else:
            self.keys = keys

        with cassandra_session(self.ctx, contact_points=self.contact_points) as session:
            ks_meta = session.cluster.metadata.keyspaces[self.keyspace]
            tbl_metas = [ks_meta.tables[scan.table] for scan in scans]

            if not keys:
                # Without explicit keys the full primary key is the join key,
                # which only works if all tables have equally long primary keys.
                primary_key_length = len(tbl_metas[0].primary_key)
                for tbl_meta in tbl_metas[1:]:
                    assert len(tbl_meta.primary_key) == primary_key_length, \
                        "can't co-scan without keys with varying primary key length"
                self.keyfuncs = [partial(take, primary_key_length)] * len(scans)
                # Keyed by the whole primary key: each group is reduced to its
                # first row, or None when the table has no row for the key.
                self.grouptransforms = [partial(get_or_none, 0)] * len(scans)
            else:
                assert len(keys) == len(scans), \
                    "provide a key for each table scanned or none at all"
                self.keyfuncs = []
                self.grouptransforms = []
                for key, scan, tbl_meta in zip(keys, scans, tbl_metas):
                    if isinstance(key, str):
                        key = (key,)
                    keylen = len(key)

                    assert len(tbl_meta.partition_key) <= keylen, \
                        "can't co-scan over a table keyed by part of the partition key"
                    assert tuple(key) == tuple(c.name for c in tbl_meta.primary_key)[:keylen], \
                        "the key columns must be the first part (or all) of the primary key"
                    assert scan._select is None or tuple(key) == tuple(scan._select)[:keylen], \
                        "select all columns or the primary key columns in the order as they " \
                        "are defined in the CQL schema"

                    # The key is the first ``keylen`` fields of each row.
                    self.keyfuncs.append(partial(take, keylen))
                    if keylen == len(tbl_meta.primary_key):
                        # Key covers the whole primary key: keep the first row
                        # of the group (or None) instead of a list of rows.
                        self.grouptransforms.append(partial(get_or_none, 0))
                    else:
                        # Key is a proper prefix of the primary key: keep the
                        # list of rows as is.
                        self.grouptransforms.append(identity)

    def coscan(self, *others, keys=None):
        """Return a new co-scan over this dataset's sources plus ``others``.

        Args:
            *others: One or more additional scans (or table names) to add.
            keys: Keys for the new co-scan; see ``__init__``.
        """
        assert len(others) > 0
        return CassandraCoScanDataset(*tuple(self.src) + others, keys=keys)

    def parts(self):
        """Build one ``CassandraCoScanPartition`` per partitioned range.

        The size estimates of all source tables are summed and handed to
        ``partitioner.partition_ranges``; for every resulting range a
        ``CassandraScanPartition`` is created per source scan and these are
        bundled into a single co-scan partition.
        """
        # Imported here rather than at module level -- presumably to avoid a
        # circular import with bndl_cassandra.dataset; confirm.
        from bndl_cassandra.dataset import CassandraScanPartition

        with cassandra_session(self.ctx, contact_points=self.contact_points) as session:
            size_estimates = sum(
                (estimate_size(session, self.keyspace, src.table) for src in self.src),
                SizeEstimate(0, 0, 0),
            )

        partitions = partitioner.partition_ranges(self.ctx, self.contact_points, self.keyspace,
                                                  size_estimates=size_estimates)

        return [
            CassandraCoScanPartition(self, idx, [CassandraScanPartition(scan, idx, *part)
                                                 for scan in self.src])
            for idx, part in enumerate(partitions)
        ]
class CassandraCoScanPartition(Partition):
    """Partition of a co-scan, bundling one scan partition per scanned table."""

    def __init__(self, dset, idx, scans):
        """Store the per-table scan partitions that make up this partition.

        Args:
            dset: The owning dataset; it must provide ``keyfuncs`` and
                ``grouptransforms`` when the partition is computed.
            idx: Index of this partition.
            scans: The scan partitions, one per scanned table.
        """
        super().__init__(dset, idx)
        self.scans = scans

    def _locality(self, workers):
        """Delegate the locality decision to the first scan partition."""
        first_scan = self.scans[0]
        return first_scan._locality(workers)

    def _compute(self):
        """Yield ``(key, groups)`` for every key seen in any of the scans.

        ``groups`` holds one entry per scan, in scan order: the rows of that
        scan with the given key, passed through the scan's group transform.
        """
        extractors = self.dset.keyfuncs
        transforms = self.dset.grouptransforms

        row_sources = [part.compute() for part in self.scans]
        width = len(row_sources)

        # key -> one list of rows per scan, positioned by scan index
        joined = {}
        for position, rows in enumerate(row_sources):
            extract = extractors[position]
            for row in rows:
                key = tuple(extract(row))
                slots = joined.get(key)
                if slots is None:
                    slots = [[] for _ in range(width)]
                    joined[key] = slots
                slots[position].append(row)

        for key, groups in joined.items():
            # Transform each scan's rows in place before handing the group out.
            for position, transform in zip(range(len(groups)), transforms):
                groups[position] = transform(groups[position])
            yield key, groups