Source code for vampire.coloring
import matplotlib.pyplot as plt
import numpy as np
def label_imgs(img_set, properties_df):
    """
    Apply cluster labels to every image of a set.

    Parameters
    ----------
    img_set : list[ndarray]
        Images whose objects are to be relabeled.
    properties_df : DataFrame
        Properties of the objects found in ``img_set``. The ``image_id``
        column is used here to select the rows of each image; the selected
        rows are passed on to ``label_img``.

    Returns
    -------
    labeled_imgs : list[ndarray]
        One cluster-labeled image per unique ``image_id``.

    Notes
    -----
    Unique ``image_id`` values are visited in sorted order and the k-th one
    is paired with ``img_set[k]``.
    """
    image_ids = properties_df['image_id']
    return [
        label_img(img_set[position], properties_df[image_ids == unique_id])
        for position, unique_id in enumerate(np.unique(image_ids))
    ]
def label_img(img, img_df):
    """
    Label objects in the image according to clusters.

    Every pixel whose value appears in the ``label`` column of ``img_df``
    receives ``cluster_id + 1`` of the matching row; every other pixel is 0.

    Parameters
    ----------
    img : ndarray
        Image to be labeled.
    img_df : DataFrame
        Properties of objects from ``img``.
        Must provide the ``label`` and ``cluster_id`` columns.

    Returns
    -------
    labeled_img : ndarray
        Float array with the same shape as ``img``. An all-zero image is
        returned when ``img_df`` has no rows.
    """
    # Start from an all-background float image so that an empty ``img_df``
    # still yields an array of the right shape (and not a bare scalar).
    labeled_img = np.zeros(np.shape(img), dtype=float)
    for cluster_id in np.unique(img_df['cluster_id']):
        cluster_df = img_df[img_df['cluster_id'] == cluster_id]
        mask = np.isin(img, cluster_df['label'])
        # 0 reserved for background, switch to 1 indexing.
        # Values are accumulated (not assigned) so a label listed under
        # several clusters sums up, as it would when adding per-cluster masks.
        labeled_img[mask] += cluster_id + 1
    return labeled_img
def color_img(img, background=0, cmap=None, background_color=None):
    """
    Plot cluster-labeled images.

    Parameters
    ----------
    img : ndarray
        Cluster-labeled image. It is not modified.
    background : int, optional
        Background value. Default 0.
    cmap : str, optional
        Matplotlib colormap name. Defaults to ``'twilight'``.
        https://matplotlib.org/stable/tutorials/colors/colormaps.html
    background_color : str, optional
        Color name for background. Defaults to ``'white'``.

    Returns
    -------
    fig : matplotlib.figure.Figure
    ax : matplotlib.axes.Axes
    colors : ndarray
        Colors used to color each cluster, ordered by ascending cluster value.
    """
    # work on a copy of the colormap so the registered one is left untouched
    cmap = plt.get_cmap('twilight' if cmap is None else cmap).copy()
    cmap.set_bad(color='white' if background_color is None else background_color)

    # assign each cluster a label that's normalized for cmap
    source = np.asarray(img)
    cluster_ids = np.unique(source)
    cluster_ids = cluster_ids[cluster_ids != background]
    replaced_labels = np.linspace(0.1, 0.9, len(cluster_ids))

    # Write into a separate float array instead of relabeling in place:
    # - integer input can hold neither the fractional labels nor NaN;
    # - comparing against the untouched input keeps an already-assigned
    #   normalized value from being mistaken for a cluster id or background.
    # Pixels never assigned stay NaN ("bad") and show background_color.
    normalized = np.full(source.shape, np.nan, dtype=float)
    for cluster_id, replaced_label in zip(cluster_ids, replaced_labels):
        normalized[source == cluster_id] = replaced_label

    # plot cluster-labeled img
    fig, ax = plt.subplots(figsize=(5, 5))
    ax.imshow(normalized, cmap=cmap, vmax=1, vmin=0)
    ax.tick_params(
        axis='both',
        which='both',
        bottom=False,
        top=False,
        left=False,
        labelbottom=False,
        labelleft=False,
    )
    # lay out this figure explicitly rather than the implicit current one
    fig.tight_layout(pad=0)

    # colors used for labeling
    colors = cmap(replaced_labels)
    return fig, ax, colors