Source code for materializationengine.shared_tasks

import datetime
import os
from typing import Generator, List
import time

from celery.result import AsyncResult
from celery.utils.log import get_task_logger
from dynamicannotationdb.key_utils import build_segmentation_table_name
from dynamicannotationdb.models import SegmentationMetadata
from sqlalchemy import and_, func, text
from sqlalchemy.exc import ProgrammingError
from sqlalchemy.orm.exc import NoResultFound
from materializationengine.info_client import get_relevant_datastack_info
from materializationengine.celery_init import celery
from dynamicannotationdb.models import AnalysisVersion, VersionErrorTable
from materializationengine.database import dynamic_annotation_cache, sqlalchemy_cache
from materializationengine.utils import (
    create_annotation_model,
    create_segmentation_model,
    get_config_param,
)

# Module-level logger created through Celery's task-logger factory; every
# function in this module logs through it.
celery_logger = get_task_logger(__name__)


def generate_chunked_model_ids(
    mat_metadata: dict, use_segmentation_model=False
) -> List[List]:
    """Build the id ranges used to split a materialization query into chunks.

    Parameters
    ----------
    mat_metadata : dict
        Materialization metadata. Its optional ``"chunk_size"`` entry sets
        the chunk size; when it is missing or falsy the
        ``MATERIALIZATION_ROW_CHUNK_SIZE`` config value is used instead.
    use_segmentation_model : bool, optional
        When true the id column of the segmentation model is chunked,
        otherwise the id column of the annotation model. Defaults to False.

    Returns
    -------
    List[List]
        ``[start_id, end_id]`` pairs as produced by ``chunk_ids``. Note that
        ``chunk_ids`` is a generator function, so the pairs are produced
        lazily rather than as a materialized list.
    """
    celery_logger.info("Chunking supervoxel ids")

    model_factory = (
        create_segmentation_model if use_segmentation_model else create_annotation_model
    )
    AnnotationModel = model_factory(mat_metadata)

    # Only consult the config when the metadata carries no usable chunk size.
    rows_per_chunk = mat_metadata.get("chunk_size") or get_config_param(
        "MATERIALIZATION_ROW_CHUNK_SIZE"
    )
    return chunk_ids(mat_metadata, AnnotationModel.id, rows_per_chunk)
def create_chunks(data_list: List, chunk_size: int) -> Generator:
    """Split a list into consecutive chunks of at most ``chunk_size`` items.

    Args:
        data_list (List): list to chunk
        chunk_size (int): maximum number of items per chunk; must be >= 1
            when ``data_list`` is not empty

    Yields:
        List: successive slices of ``data_list``; nothing is yielded for an
        empty list.

    Raises:
        ValueError: if ``data_list`` is non-empty and ``chunk_size`` is
            smaller than 1. Because this is a generator, the error surfaces
            when iteration starts.
    """
    total = len(data_list)
    if total == 0:
        # Previously an empty list produced a zero step and range() raised.
        return
    if chunk_size < 1:
        raise ValueError(f"chunk_size must be >= 1, got {chunk_size}")

    for start in range(0, total, chunk_size):
        yield data_list[start : start + chunk_size]
@celery.task
def workflow_failed(request, exc, traceback, mat_info, *args, **kwargs):
    """Record a failed workflow task and mark its analysis version as failed.

    Stores a ``VersionErrorTable`` row for the exception, sets
    ``valid = False`` and ``status = "FAILED"`` on the matching
    ``AnalysisVersion`` row, and logs a failure summary.

    Args:
        request: object describing the failed task; only its ``id``
            attribute is read.
        exc: the exception raised by the failed task.
        traceback: traceback of the failure, stored as the error text.
        mat_info: sequence whose first element is a dict with the keys
            ``"aligned_volume"``, ``"analysis_version"`` and ``"datastack"``.
        *args: extra positional arguments, only included in the log message.
        **kwargs: extra keyword arguments, only included in the log message.

    Returns:
        str: the formatted failure summary that was logged.

    Raises:
        Whatever the ``AnalysisVersion`` lookup raises (for example when no
        row matches the version); the session is still closed in that case.
    """
    aligned_volume = mat_info[0]["aligned_volume"]
    analysis_version = mat_info[0]["analysis_version"]
    datastack = mat_info[0]["datastack"]
    message = f"Task failure: {request.id}"
    failure_info = f"""
        Datastack: {datastack}
        Task ID: {request.id}
        task_args: {str(args)}
        task_kwargs: {kwargs}
        Exception: {exc}
        Traceback: {traceback}
    """

    session = sqlalchemy_cache.get(aligned_volume)
    try:
        version = (
            session.query(AnalysisVersion)
            .filter(AnalysisVersion.version == analysis_version)
            .one()
        )
        error_entry = VersionErrorTable(
            exception=str(exc), error=traceback, analysisversion_id=version.id
        )
        celery_logger.info(error_entry)
        session.add(error_entry)
        version.valid = False
        version.status = "FAILED"
        try:
            # flush() is what actually emits the INSERT/UPDATE, so it is
            # handled the same way as the commit.
            session.flush()
            session.commit()
        except Exception as e:
            session.rollback()
            celery_logger.error(f"Failed to insert error msg: {e}")
    finally:
        # Always release the session, including when the version lookup fails.
        session.close()

    celery_logger.error(f"{message}: {failure_info}")
    return failure_info


@celery.task(name="workflow:fin", acks_late=True, bind=True)
def fin(self, *args, **kwargs):
    """No-op task that ignores its arguments and always returns True."""
    return True


@celery.task(name="workflow:workflow_complete", acks_late=True, bind=True)
def workflow_complete(self, workflow_name):
    """Return a completion message for ``workflow_name``.

    Args:
        workflow_name: name interpolated into the message.

    Returns:
        str: ``"<workflow_name> completed successfully"``.
    """
    return f"{workflow_name} completed successfully"
def _int_or_none(value):
    """Return ``int(value)``, or None when the conversion raises TypeError."""
    try:
        return int(value)
    except TypeError:
        return None


def get_materialization_info(
    datastack_info: dict,
    analysis_version: int = None,
    materialization_time_stamp: datetime.datetime = None,
    skip_table: bool = False,
    row_size: int = 1_000_000,
    table_name: str = None,
    skip_row_count: bool = False,
) -> List[dict]:
    """Collect materialization metadata for the tables of an aligned volume.

    Iterates over the valid annotation tables of the aligned volume named in
    ``datastack_info`` and builds one metadata dict per table. The resulting
    list is what gets handed to the materialization workers.

    Args:
        datastack_info (dict): Datastack info. Reads
            ``["aligned_volume"]["name"]``, ``["segmentation_source"]``,
            ``["datastack"]`` and the optional ``"lookup_all_root_ids"``.
        analysis_version (int, optional): Analysis version to use for frozen
            materialization. When truthy, ``"analysis_version"`` and
            ``"analysis_database"`` are added to every entry.
            Defaults to None.
        materialization_time_stamp (datetime.datetime, optional): Timestamp
            of the materialization. Defaults to ``datetime.datetime.utcnow()``
            at call time when falsy.
        skip_table (bool, optional): When true, tables whose row count is
            greater than or equal to ``row_size`` are skipped.
            Defaults to False.
        row_size (int, optional): Row count threshold used with
            ``skip_table``. Defaults to 1_000_000.
        table_name (str, optional): Restrict the collection to this single
            table. Defaults to None (all valid tables).
        skip_row_count (bool, optional): When true, the row count and min id
            queries are not issued and the entry starts with ``"max_id"``
            only. Defaults to False.

    Returns:
        List[dict]: One metadata dict per table that has a truthy max id
        (and, unless ``skip_row_count`` is set, a non-zero row count).

    Raises:
        StopIteration: if ``table_name`` is given but is not among the valid
            table names.
    """
    celery_logger.info("Collecting materialization metadata")
    aligned_volume_name = datastack_info["aligned_volume"]["name"]
    pcg_table_name = datastack_info["segmentation_source"].split("/")[-1]
    segmentation_source = datastack_info.get("segmentation_source")

    if not materialization_time_stamp:
        materialization_time_stamp = datetime.datetime.utcnow()

    db = dynamic_annotation_cache.get_db(aligned_volume_name)
    metadata = []
    try:
        annotation_tables = db.database.get_valid_table_names()
        if table_name is not None:
            # NOTE(review): next() without a default raises a bare
            # StopIteration when table_name is unknown; kept as-is so callers
            # see the same exception type — consider a clearer error.
            annotation_tables = [
                next(a for a in annotation_tables if a == table_name)
            ]
        celery_logger.debug(f"Annotation tables: {annotation_tables}")

        for annotation_table in annotation_tables:
            max_id = _int_or_none(db.database.get_max_id_value(annotation_table))

            if not skip_row_count:
                row_count = db.database.get_table_row_count(
                    annotation_table,
                    filter_valid=True,
                    filter_timestamp=str(materialization_time_stamp),
                )
                min_id = _int_or_none(
                    db.database.get_min_id_value(annotation_table)
                )
                table_metadata = {
                    "max_id": max_id,
                    "min_id": min_id,
                    "row_count": row_count,
                }
                # Empty tables have nothing to materialize.
                if row_count == 0:
                    continue
                # Optionally leave out tables at or above the size threshold.
                if row_count >= row_size and skip_table:
                    continue
            else:
                table_metadata = {"max_id": max_id}

            md = db.database.get_table_metadata(annotation_table)
            # Missing or falsy resolutions fall back to 1.0 per axis.
            voxel_resolution = [
                md.get("voxel_resolution_x") or 1.0,
                md.get("voxel_resolution_y") or 1.0,
                md.get("voxel_resolution_z") or 1.0,
            ]
            reference_table = md.get("reference_table")
            schema = db.database.get_table_schema(annotation_table)

            # Tables without a (truthy) max id produce no metadata entry.
            if not max_id:
                continue

            table_metadata.update(
                {
                    "annotation_table_name": annotation_table,
                    "datastack": datastack_info["datastack"],
                    "aligned_volume": str(aligned_volume_name),
                    "schema": schema,
                    "add_indices": True,
                    "coord_resolution": voxel_resolution,
                    "reference_table": reference_table,
                    "materialization_time_stamp": str(materialization_time_stamp),
                    "table_count": len(annotation_tables),
                }
            )

            if db.schema.is_segmentation_table_required(schema):
                segmentation_table_name = build_segmentation_table_name(
                    annotation_table, pcg_table_name
                )
                segmentation_metadata = (
                    db.segmentation.get_segmentation_table_metadata(
                        annotation_table, pcg_table_name
                    )
                )
                if segmentation_metadata:
                    create_segmentation_table = False
                else:
                    celery_logger.warning(
                        f"SEGMENTATION TABLE DOES NOT EXIST: {segmentation_table_name}"
                    )
                    segmentation_metadata = {"last_updated": None}
                    create_segmentation_table = True

                last_updated_time_stamp = segmentation_metadata.get("last_updated")
                last_updated_time_stamp = (
                    str(last_updated_time_stamp) if last_updated_time_stamp else None
                )

                table_metadata.update(
                    {
                        "create_segmentation_table": create_segmentation_table,
                        "segmentation_table_name": segmentation_table_name,
                        "temp_mat_table_name": f"temp__{annotation_table}",
                        "pcg_table_name": pcg_table_name,
                        "segmentation_source": segmentation_source,
                        "last_updated_time_stamp": last_updated_time_stamp,
                        "chunk_size": get_config_param(
                            "MATERIALIZATION_ROW_CHUNK_SIZE"
                        ),
                        "queue_length_limit": get_config_param("QUEUE_LENGTH_LIMIT"),
                        "throttle_queues": get_config_param("THROTTLE_QUEUES"),
                        "lookup_all_root_ids": datastack_info.get(
                            "lookup_all_root_ids", False
                        ),
                        "merge_table": get_config_param("MERGE_TABLES"),
                    }
                )

            if analysis_version:
                table_metadata.update(
                    {
                        "analysis_version": analysis_version,
                        "analysis_database": f"{datastack_info['datastack']}__mat{analysis_version}",
                    }
                )

            metadata.append(table_metadata.copy())
            celery_logger.debug(metadata)
    finally:
        # Release the cached session even when one of the table queries fails.
        db.database.cached_session.close()

    celery_logger.info(f"Metadata collected for {len(metadata)} tables")
    return metadata
@celery.task(name="workflow:collect_data", acks_late=True)
def collect_data(*args, **kwargs):
    """Hand back the arguments the task was invoked with.

    Returns:
        tuple: ``(args, kwargs)`` — the positional arguments as a tuple and
        the keyword arguments as a dict.
    """
    collected = (args, kwargs)
    return collected
def query_id_range(column, start_id: int, end_id: int):
    """Build a filter expression selecting ``start_id <= column < end_id``.

    Args:
        column: object supporting ``>=`` and ``<`` comparisons (used with
            SQLAlchemy columns in this module).
        start_id (int): inclusive lower bound.
        end_id (int): exclusive upper bound, or None for an open-ended range.

    Returns:
        ``and_(column >= start_id, column < end_id)`` when ``end_id`` is not
        None, otherwise ``column >= start_id``.
    """
    # Compare against None explicitly so that an end_id of 0 is still
    # treated as a real upper bound instead of "no bound".
    if end_id is not None:
        return and_(column >= start_id, column < end_id)
    return column >= start_id
def chunk_ids(mat_metadata, model, chunk_size: int):
    """Yield ``[start, end]`` boundaries that split a column into chunks.

    The values of ``model`` are numbered with ``row_number()`` ordered by
    ``model``; when ``chunk_size`` is greater than 1 only every
    ``chunk_size``-th value (row numbers 1, 1 + chunk_size, ...) is kept as a
    chunk start. Each start is paired with the following start as its
    exclusive end.

    Args:
        mat_metadata: mapping providing ``"aligned_volume"``, used to look up
            the database session.
        model: the selectable to chunk; callers in this module pass a model's
            ``id`` column.
        chunk_size (int): number of rows per chunk.

    Yields:
        list: ``[chunk_start, chunk_end]``; ``chunk_end`` is None for the
        last chunk.
    """
    aligned_volume = mat_metadata.get("aligned_volume")
    session = sqlalchemy_cache.get(aligned_volume)
    try:
        q = session.query(
            model, func.row_number().over(order_by=model).label("row_count")
        ).from_self(model)
        if chunk_size > 1:
            q = q.filter(text("row_count %% %d=1" % chunk_size))
        # Materialize the boundaries so the session can be released before
        # the (lazy) consumer starts iterating.
        chunk_starts = [row_id for (row_id,) in q]
    finally:
        session.close()

    # Pair every start with its successor; the last chunk is open-ended.
    chunk_ends = chunk_starts[1:] + [None]
    for chunk_start, chunk_end in zip(chunk_starts, chunk_ends):
        yield [chunk_start, chunk_end]
# Layouts accepted for the materialization time stamp string, tried in order.
_TIME_STAMP_FORMATS = (
    "%Y-%m-%d %H:%M:%S.%f",
    "%Y-%m-%dT%H:%M:%S.%f",
    "%Y-%m-%d %H:%M:%S",
    "%Y-%m-%dT%H:%M:%S",
)


def _parse_materialization_time_stamp(time_stamp: str) -> datetime.datetime:
    """Parse a time stamp string, with or without fractional seconds."""
    for time_stamp_format in _TIME_STAMP_FORMATS:
        try:
            return datetime.datetime.strptime(time_stamp, time_stamp_format)
        except ValueError:
            continue
    raise ValueError(f"Unrecognized materialization time stamp: {time_stamp!r}")


@celery.task(
    name="workflow:update_metadata",
    bind=True,
    acks_late=True,
    autoretry_for=(Exception,),
    max_retries=3,
)
def update_metadata(self, mat_metadata: dict):
    """Update 'last_updated' column in the segmentation metadata
    table for a given segmentation table.

    Args:
        mat_metadata (dict): materialization metadata; reads
            ``"aligned_volume"``, ``"segmentation_table_name"`` and
            ``"materialization_time_stamp"``.

    Returns:
        dict: single entry mapping ``"Table: <segmentation table>"`` to
        ``"Time stamp <materialization time stamp>"``.

    Raises:
        ValueError: if the time stamp string matches none of the supported
            layouts.
    """
    aligned_volume = mat_metadata["aligned_volume"]
    segmentation_table_name = mat_metadata["segmentation_table_name"]
    materialization_time_stamp = mat_metadata["materialization_time_stamp"]

    # Parse before acquiring the session so a malformed time stamp cannot
    # leave an unclosed session behind. str(datetime) omits the fractional
    # part when microsecond == 0, hence the extra layouts.
    last_updated_time_stamp = _parse_materialization_time_stamp(
        materialization_time_stamp
    )

    session = sqlalchemy_cache.get(aligned_volume)
    try:
        seg_metadata = (
            session.query(SegmentationMetadata)
            .filter(SegmentationMetadata.table_name == segmentation_table_name)
            .one()
        )
        seg_metadata.last_updated = last_updated_time_stamp
        session.commit()
    except Exception as e:
        # NOTE(review): the error is logged and swallowed, so autoretry_for
        # never triggers for SQL failures and the task still reports the
        # update — confirm this best-effort behaviour is intended.
        celery_logger.error(f"SQL ERROR: {e}")
        session.rollback()
    finally:
        session.close()
    return {
        f"Table: {segmentation_table_name}": f"Time stamp {materialization_time_stamp}"
    }


@celery.task(
    name="workflow:add_index",
    bind=True,
    acks_late=True,
    task_reject_on_worker_lost=True,
    autoretry_for=(Exception,),
    max_retries=3,
)
def add_index(self, database: dict, command: str):
    """Add an index or a constraint to a table.

    Args:
        database: identifier passed to ``sqlalchemy_cache.get_engine`` to
            obtain the engine (annotated ``dict``; presumably a database
            name — TODO confirm).
        command (str): sql command to create an index or constraint

    Raises:
        self.retry: retries task when an error creating an index occurs

    Returns:
        str: ``"Index already exists"`` when a ProgrammingError is raised,
        otherwise a message containing the SQL command.
    """
    engine = sqlalchemy_cache.get_engine(database)

    # increase maintenance memory to improve index creation speeds,
    # reset to default after index is created
    ADD_INDEX_SQL = f"""
        SET maintenance_work_mem to '1GB';
        {command}
        SET maintenance_work_mem to '64MB';
    """

    try:
        with engine.begin() as conn:
            celery_logger.info(f"Adding index: {command}")
            conn.execute(ADD_INDEX_SQL)
    except ProgrammingError as index_error:
        # Any ProgrammingError is reported as "already exists".
        celery_logger.error(index_error)
        return "Index already exists"
    except Exception as e:
        celery_logger.error(f"Index creation failed: {e}")
        raise self.retry(exc=e, countdown=3)

    return f"Index {command} added to table"
def monitor_task_states(task_ids: List, polling_rate: float = 0.2):
    """Block until every task in ``task_ids`` has succeeded.

    Args:
        task_ids (List): ids of the Celery tasks to watch.
        polling_rate (float, optional): seconds to sleep between polls.
            Defaults to 0.2.

    Returns:
        bool: True once all tasks report ``"SUCCESS"`` (immediately for an
        empty ``task_ids``).

    Raises:
        Exception: carrying the task traceback as soon as any task reports
            ``"FAILURE"``.
    """
    # The result handles do not change between polls, so build them once;
    # reading .state on each pass is what reflects the current status.
    results = [AsyncResult(task_id, app=celery) for task_id in task_ids]
    while True:
        celery_logger.debug(f"Celery results {results}")
        result_status = []
        for result in results:
            state = result.state
            if state == "FAILURE":
                raise Exception(result.traceback)
            result_status.append(state)
        celery_logger.debug(f"Celery task status: {result_status}")
        if all(x == "SUCCESS" for x in result_status):
            return True
        time.sleep(polling_rate)
def monitor_workflow_state(workflow: AsyncResult, polling_rate: float = 0.2):
    """Block until ``workflow`` has finished and report whether it succeeded.

    Args:
        workflow (AsyncResult): result handle of the workflow to watch.
        polling_rate (float, optional): seconds to sleep between polls.
            Defaults to 0.2.

    Returns:
        bool: True if the workflow finished successfully, False if it
        finished in any other state.
    """
    while True:
        celery_logger.debug("WAITING FOR TASKS TO COMPLETE...")
        if workflow.ready():
            celery_logger.debug(f"WORKFLOW IDS: {workflow.id}, READY")
            if workflow.successful():
                celery_logger.debug("CHAIN COMPLETE")
                return True
            # Ready but not successful: failed, or finished in another
            # terminal state (e.g. revoked). Returning here avoids polling
            # forever on a workflow that will never succeed.
            return False
        time.sleep(polling_rate)
def check_if_task_is_running(task_name: str, worker_name_prefix: str) -> bool:
    """Check if a task is running under a worker with a specified prefix name.

    Every worker whose name contains ``worker_name_prefix`` is inspected. If
    the task is found among the active tasks of any of them, return True.

    Parameters
    ----------
    task_name : str
        name of task to check if it is running
    worker_name_prefix : str
        prefix of celery worker, used to check specific queue.

    Returns
    -------
    bool
        True if task_name is running else False. Also False when the
        inspection returns no data or no worker name matches the prefix.
    """
    inspector = celery.control.inspect()
    active_tasks_dict = inspector.active()
    if not active_tasks_dict:
        # No inspection data (None or empty), so nothing is known to be running.
        return False

    for worker_name, active_tasks in active_tasks_dict.items():
        if worker_name_prefix not in worker_name:
            continue
        for active_task in active_tasks:
            if task_name in active_task.values():
                celery_logger.info(f"Task {task_name} is running...")
                return True
    return False