Skip to content
94 changes: 57 additions & 37 deletions pathml/core/slide_backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from io import BytesIO

import dask
import numpy as np
import openslide
from javabridge.jutil import JavaException
Expand Down Expand Up @@ -62,8 +63,14 @@ class OpenSlideBackend(SlideBackend):
def __init__(self, filename):
logger.info(f"OpenSlideBackend loading file at: {filename}")
self.filename = filename
self.slide = openslide.open_slide(filename=filename)
self.level_count = self.slide.level_count

@property
def slide(self):
    """Open the slide file and return the resulting handle.

    ``openslide.open_slide`` is called with ``self.filename`` on every
    access; the returned object is not stored on the instance.
    """
    handle = openslide.open_slide(filename=self.filename)
    return handle

@property
def level_count(self):
    """Return the ``level_count`` attribute of the object exposed by ``self.slide``."""
    opened = self.slide
    return opened.level_count

def __repr__(self):
    """Return a short description of the backend that includes the source filename."""
    text = f"OpenSlideBackend('{self.filename}')"
    return text
Expand Down Expand Up @@ -211,9 +218,10 @@ def generate_tiles(self, shape=3000, stride=None, pad=False, level=0):
for ix_i in range(n_tiles_i):
for ix_j in range(n_tiles_j):
coords = (int(ix_i * stride_i), int(ix_j * stride_j))
# get image for tile
tile_im = self.extract_region(location=coords, size=shape, level=level)
yield pathml.core.tile.Tile(image=tile_im, coords=coords)
image = dask.delayed(self.extract_region)(
location=coords, size=shape, level=level
)
yield pathml.core.tile.Tile(image, coords=coords)


def _init_logger():
Expand Down Expand Up @@ -421,7 +429,10 @@ def extract_region(
f"Multi-level images not supported with series_as_channels=True. Input 'level={level}' invalid. Use 'level=0'."
)

javabridge.start_vm(class_path=bioformats.JARS, max_heap_size="100G")
javabridge.start_vm(
class_path=bioformats.JARS, max_heap_size="100G", run_headless=True
)

with bioformats.ImageReader(str(self.filename), perform_init=True) as reader:
# expand size
logger.info(f"extracting region with input size = {size}")
Expand Down Expand Up @@ -593,27 +604,28 @@ def generate_tiles(self, shape=3000, stride=None, pad=False, level=0, **kwargs):
for ix_j in range(n_tiles_j):
coords = (int(ix_i * stride_i), int(ix_j * stride_j))
if coords[0] + shape[0] < i and coords[1] + shape[1] < j:
# get image for tile
tile_im = self.extract_region(
image = dask.delayed(self.extract_region)(
location=coords, size=shape, level=level, **kwargs
)
yield pathml.core.tile.Tile(image=tile_im, coords=coords)
# Image on edge and needs to be padded with 0s
else:
unpaddedshape = (
unpadded_shape = (
i - coords[0] if coords[0] + shape[0] > i else shape[0],
j - coords[1] if coords[1] + shape[1] > j else shape[1],
)
tile_im = self.extract_region(
location=coords, size=unpaddedshape, level=level, **kwargs
edge_image = dask.delayed(self.extract_region)(
location=coords, size=unpadded_shape, level=level, **kwargs
)
zeroarrayshape = list(tile_im.shape)
zeroarrayshape[0], zeroarrayshape[1] = (
list(shape)[0],
list(shape)[1],
)
padded_im = np.zeros(zeroarrayshape)
padded_im[: tile_im.shape[0], : tile_im.shape[1], ...] = tile_im
yield pathml.core.tile.Tile(image=padded_im, coords=coords)

def pad(image):
    """Pad an edge tile with zeros up to the requested tile shape.

    Args:
        image (np.ndarray): Edge tile whose first two axes may be smaller
            than ``shape`` (taken from the enclosing scope).

    Returns:
        np.ndarray: Zero-filled array whose first two axes are ``shape[0]``
            and ``shape[1]`` and whose remaining axes match ``image``, with
            ``image`` copied into the top-left corner.
    """
    # Only the two leading axes are enlarged; every trailing axis must keep
    # the size it has in ``image``, hence ``image.shape[2:]`` (not ``[:-2]``,
    # which would select the leading axes instead).
    padded = np.zeros((shape[0], shape[1], *image.shape[2:]))
    padded[: image.shape[0], : image.shape[1]] = image
    return padded

# Need to delay to use shape of edge_image
image = dask.delayed(pad)(edge_image)
yield pathml.core.tile.Tile(image=image, coords=coords)


class DICOMBackend(SlideBackend):
Expand Down Expand Up @@ -653,19 +665,25 @@ def __init__(self, filename):
f"DICOM metadata: frame_shape={self.frame_shape}, nrows = {self.n_rows}, ncols = {self.n_cols}"
)

# actual file
self.fp = DicomFile(self.filename, mode="rb")
self.fp.is_little_endian = self.transfer_syntax_uid.is_little_endian
self.fp.is_implicit_VR = self.transfer_syntax_uid.is_implicit_VR
fp = self.fp

# need to do this to advance the file to the correct point, at the beginning of the pixels
self.metadata = dcmread(self.fp, stop_before_pixels=True)
self.pixel_data_offset = self.fp.tell()
self.fp.seek(self.pixel_data_offset, 0)
self.metadata = dcmread(fp, stop_before_pixels=True)
pixel_data_offset = fp.tell()
fp.seek(pixel_data_offset, 0)
# note that reading this tag is necessary to advance the file to correct position
_ = TupleTag(self.fp.read_tag())
_ = TupleTag(fp.read_tag())
# get basic offset table, to enable reading individual frames without loading entire image
self.bot = self.get_bot(self.fp)
self.first_frame = self.fp.tell()
self.bot = self.get_bot(fp)
self.first_frame = fp.tell()

@property
def fp(self):
    """Return a newly constructed ``DicomFile`` for ``self.filename``.

    Every access builds a separate file object in binary read mode and copies
    the ``is_little_endian`` and ``is_implicit_VR`` flags onto it from
    ``self.transfer_syntax_uid``.
    """
    handle = DicomFile(self.filename, mode="rb")
    for flag in ("is_little_endian", "is_implicit_VR"):
        setattr(handle, flag, getattr(self.transfer_syntax_uid, flag))
    return handle

def __repr__(self):
out = f"DICOMBackend('{self.filename}')\n"
Expand Down Expand Up @@ -807,7 +825,9 @@ def _read_frame(self, frame_ix):
np.ndarray: pixel data of that frame
"""
frame_offset = self.bot[int(frame_ix)]
self.fp.seek(self.first_frame + frame_offset, 0)
# self.fp refers to a different filelike object each time it is accessed
fp = self.fp
fp.seek(self.first_frame + frame_offset, 0)
try:
stop_at = self.bot[frame_ix + 1] - frame_offset
except IndexError:
Expand All @@ -816,11 +836,11 @@ def _read_frame(self, frame_ix):
# A frame may be comprised of multiple chunks
chunks = []
while True:
tag = TupleTag(self.fp.read_tag())
tag = TupleTag(fp.read_tag())
if n == stop_at or int(tag) == SequenceDelimiterTag:
break
length = self.fp.read_UL()
chunks.append(self.fp.read(length))
length = fp.read_UL()
chunks.append(fp.read(length))
n += 8 + length

frame_bytes = b"".join(chunks)
Expand Down Expand Up @@ -899,7 +919,7 @@ def generate_tiles(self, shape, stride, pad, level=0, **kwargs):
if i >= (self.n_frames - self.n_cols):
continue

frame_im = self.extract_region(location=i)
im = dask.delayed(self.extract_region)(location=i)
coords = self._index_to_coords(i)
frame_tile = pathml.core.tile.Tile(image=frame_im, coords=coords)
yield frame_tile
tile = pathml.core.tile.Tile(image=im, coords=coords)
yield tile
68 changes: 45 additions & 23 deletions pathml/core/slide_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import pathml.core
import pathml.preprocessing.pipeline
from pathml.core.slide_types import SlideType
from pathml.preprocessing.transforms import DropTileException


def infer_backend(path):
Expand Down Expand Up @@ -309,31 +310,47 @@ def run(
)

# map pipeline application onto each tile
processed_tile_futures = []
futures = [
client.submit(pipeline.apply, tile)
for tile in self.generate_tiles(
level=level,
shape=tile_size,
stride=tile_stride,
pad=tile_pad,
**kwargs,
)
]

for tile in self.generate_tiles(
level=level,
shape=tile_size,
stride=tile_stride,
pad=tile_pad,
**kwargs,
):
if not tile.slide_type:
tile.slide_type = self.slide_type
# explicitly scatter data, i.e. send the tile data out to the cluster before applying the pipeline
# according to dask, this can reduce scheduler burden and keep data on workers
big_future = client.scatter(tile)
f = client.submit(pipeline.apply, big_future)
processed_tile_futures.append(f)

# as tiles are processed, add them to h5
for future, tile in dask.distributed.as_completed(
processed_tile_futures, with_results=True
# After a worker processes a tile, add the tile to h5
for future, result in dask.distributed.as_completed(
futures, with_results=True, raise_errors=False
):
self.tiles.add(tile)
if future.status == "finished":
self.tiles.add(result)
if future.status == "error":
typ, exc, tb = result
if typ is DropTileException:
pass
else:
raise exc.with_traceback(tb)
# TODO: Free memory used for tile
# Each in-memory future holding a Tile shows a size of 48 bytes on the Dask dashboard
# which clearly does not include image data.
# Could it be that loaded image data is somehow not being garbage collected with Tiles?

# # all of these still leave unmanaged memory on each worker
# future.release()
# future.cancel()
# del result
# del future
# del futures

if shutdown_after:
client.shutdown()
else:
pass
# Stopgap to free unmanaged memory on client before processing another slide
client.restart()

else:
for tile in self.generate_tiles(
Expand All @@ -343,8 +360,6 @@ def run(
pad=tile_pad,
**kwargs,
):
if not tile.slide_type:
tile.slide_type = self.slide_type
pipeline.apply(tile)
self.tiles.add(tile)

Expand Down Expand Up @@ -410,14 +425,19 @@ def generate_tiles(self, shape=3000, stride=None, pad=False, **kwargs):
pathml.core.tile.Tile: Extracted Tile object
"""
for tile in self.slide.generate_tiles(shape, stride, pad, **kwargs):
# TODO: move to worker!! (forces loading data on main thread)

# add masks for tile, if possible
# i.e. if the SlideData has a Masks object, and the tile has coordinates
if self.masks is not None and tile.coords is not None:
# masks not supported if pad=True
# to implement, need to update Mask.slice to support slices that go beyond the full mask
if not pad:
i, j = tile.coords
di, dj = tile.image.shape[0:2]
# Accessing image loads data on main thread
# dask.delayed waits until compute is called on worker
shape = dask.delayed(tile).image.shape[0:2]
di, dj = shape[0], shape[1]
# add the Masks object for the masks corresponding to the tile
# this assumes that the tile didn't already have any masks
# this should work since the backend reads from image only
Expand All @@ -430,6 +450,8 @@ def generate_tiles(self, shape=3000, stride=None, pad=False, **kwargs):
tile_slices = [slice(i, i + di), slice(j, j + dj)]
tile.masks = self.masks.slice(tile_slices)

# TODO: end move to worker

# add slide-level labels to each tile, if possible
if self.labels is not None:
tile.labels = self.labels
Expand Down
45 changes: 31 additions & 14 deletions pathml/core/tile.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,11 @@
from collections import OrderedDict

import anndata
import dask
import h5py
import matplotlib.pyplot as plt
import numpy as np
from dask.delayed import Delayed

import pathml.core.masks

Expand All @@ -21,7 +23,7 @@ class Tile:
on labelling the top-leftmost pixel as (0, 0)

Args:
image (np.ndarray): Image array of tile
image (np.ndarray or dask.delayed.Delayed): Tile image or dask.delayed.Delayed object to load image
coords (tuple): Coordinates of tile relative to the whole-slide image.
The (i,j) coordinate system is based on labelling the top-leftmost pixel of the WSI as (0, 0).
name (str, optional): Name of tile
Expand Down Expand Up @@ -60,9 +62,9 @@ def __init__(
time_series=None,
):
# check inputs
assert isinstance(
assert isinstance(image, Delayed) or isinstance(
image, np.ndarray
), f"image of type {type(image)} must be a np.ndarray"
), f"image of type {type(image)} must be a np.ndarray or a dask.delayed.Delayed object"
assert masks is None or isinstance(
masks, dict
), f"masks is of type {type(masks)} but must be of type dict"
Expand Down Expand Up @@ -115,23 +117,38 @@ def __init__(
counts, anndata.AnnData
), f"counts is of type {type(counts)} but must be of type anndata.AnnData or None"

if masks:
for val in masks.values():
if val.shape[:2] != image.shape[:2]:
raise ValueError(
f"mask is of shape {val.shape} but must match tile shape {image.shape}"
)
self.masks = masks
else:
self.masks = OrderedDict()

self.image = image
self._image = image
self.masks = masks if masks else OrderedDict()
self.name = name
self.coords = coords
self.slide_type = slide_type
self.labels = labels
self.counts = counts

@property
def image(self):
    """Return the tile image, computing it first if it is still a ``Delayed``.

    When ``self._image`` is a ``Delayed`` object it is computed with the
    single-threaded scheduler, checked to be a ``np.ndarray``, compared
    against the first two dimensions of every entry in ``self.masks``, and
    then stored back on ``self._image`` so later accesses skip the compute.

    Raises:
        ValueError: If a mask's first two dimensions differ from the image's.
    """
    if not isinstance(self._image, Delayed):
        return self._image

    computed = dask.compute(self._image, scheduler="single-threaded")
    if isinstance(computed, tuple):
        computed = computed[0]
    assert isinstance(
        computed, np.ndarray
    ), f"image of type {type(computed)} must be a np.ndarray"

    for mask in self.masks.values():
        if mask.shape[:2] != computed.shape[:2]:
            raise ValueError(
                f"mask is of shape {mask.shape} but must match tile shape {computed.shape}"
            )

    self._image = computed
    return self._image


@image.setter
def image(self, image):
    """Store ``image`` on the tile after asserting that it is a ``np.ndarray``."""
    is_array = isinstance(image, np.ndarray)
    assert is_array, f"image of type {type(image)} must be a np.ndarray"
    self._image = image

def __repr__(self):
out = []
out.append(f"Tile(coords={self.coords}")
Expand Down
Loading