Patrick Gray (patrick.c.gray at duke) - https://github.com/patrickcgray
matplotlib
matplotlib is a very powerful plotting library for making amazing visualizations for publications, personal use, or even web and desktop applications. matplotlib can create almost any two-dimensional visualization you can think of, including histograms, scatter plots, bivariate plots, and image displays. For some inspiration, check out the matplotlib example gallery, which includes the source code required to generate each example.
One part of matplotlib that may be initially confusing is that matplotlib contains two main methods of making plots: the object-oriented method and the state-machine method.
A very good overview of the difference between the two usages is provided by Jake Vanderplas. Specifically,
If you need a primer on matplotlib beyond what is here, I suggest Python Like You Mean It or the matplotlib user's guide.
In general, I think you should use the object-oriented (OO) API. While it is more complicated, it is a much more powerful way of creating plots and should be used when developing more complicated visualizations.
import rasterio
import numpy as np
# Open our raster dataset
dataset = rasterio.open('../data/LE70220491999322EDC01_stack.gtif')
image = dataset.read()
Let's check out some of this dataset's characteristics:
# How many bands does this image have?
num_bands = dataset.count
print('Number of bands in image: {n}\n'.format(n=num_bands))
# How many rows and columns?
rows, cols = dataset.shape
print('Image size is: {r} rows x {c} columns\n'.format(r=rows, c=cols))
# What driver was used to open the raster?
driver = dataset.driver
print('Raster driver: {d}\n'.format(d=driver))
# What is the raster's projection?
proj = dataset.crs
print('Image projection:')
print(proj)
Number of bands in image: 8
Image size is: 250 rows x 250 columns
Raster driver: GTiff
Image projection:
EPSG:32615
Note that if you have the full dataset read in with dataset.read(), you can do:
red_band = image[2, :, :] # this pulls out the 3rd band
print(red_band.shape)
(250, 250)
This pulls out the band at index 2, which is the 3rd band, because Python indexing starts at 0.
This is equivalent to simply doing:
red_band_read = dataset.read(3) # this pulls out the 3rd band
print(red_band_read.shape)
if np.array_equal(red_band_read, red_band): # are they equal?
    print('They are the same.')
(250, 250)
They are the same.
If you have a large dataset, the second approach, dataset.read(3), may be preferable because it reads only that one band into memory instead of the whole stack.
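To make the indexing and memory points concrete, here is a minimal sketch that uses a synthetic NumPy array as a stand-in for image = dataset.read(). The 8 x 250 x 250 shape matches the output above, but the int16 data type and the random values are assumptions for illustration only.

```python
import numpy as np

# Hypothetical stand-in for image = dataset.read(): 8 bands, 250 x 250 pixels,
# int16 values. Random numbers, not real reflectance.
rng = np.random.default_rng(0)
image = rng.integers(0, 10000, size=(8, 250, 250), dtype=np.int16)

# rasterio counts bands from 1 while NumPy indexes from 0,
# so dataset.read(3) corresponds to image[3 - 1]
band_number = 3
red_band = image[band_number - 1]

# Memory needed for the full stack vs. a single band
print(image.nbytes)     # 1000000 bytes (8 * 250 * 250 * 2)
print(red_band.nbytes)  # 125000 bytes (250 * 250 * 2)
```

Reading a single band therefore needs roughly one-eighth of the memory of the full 8-band stack in this example.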
The first thing to do is to import matplotlib into our namespace. I will be using a special feature of the IPython utility which allows me to "inline" matplotlib figures by entering the %matplotlib inline command.
import matplotlib.pyplot as plt
%matplotlib inline
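The %matplotlib inline magic only works inside IPython/Jupyter. If you run the same code as a plain Python script, a common alternative is to save each figure to a file (or call plt.show() when a display is available). A minimal sketch, assuming a headless environment (hence the non-interactive Agg backend); the output file name is hypothetical:

```python
import os
import tempfile

import matplotlib
matplotlib.use("Agg")  # non-interactive backend: renders to files, no window needed
import matplotlib.pyplot as plt
import numpy as np

x = np.arange(10)
y = np.random.randint(0, 10, size=10)

fig, ax = plt.subplots()
ax.plot(x, y)

# Save the figure instead of relying on inline display (hypothetical path)
out_path = os.path.join(tempfile.gettempdir(), "first_plot.png")
fig.savefig(out_path, dpi=100)
plt.close(fig)  # release the figure's memory once it is saved
```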
With matplotlib imported, we can summon up a figure and make our first plot:
# Array of 0 - 9
x = np.arange(10)
# 10 random numbers, between 0 and 10
y = np.random.randint(0, 10, size=10)
# plot them as lines
plt.plot(x, y)
[<matplotlib.lines.Line2D at 0x7f35a2366460>]
# plot them as just points -- specify "ls" ("linestyle") as a null string
plt.plot(x, y, 'ro', ls='')
[<matplotlib.lines.Line2D at 0x7f35a225d100>]
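The format string 'ro' means red ('r') circle markers ('o'). Because it contains no line style, matplotlib already draws markers only, so the extra ls='' above is harmless but redundant. A short sketch that inspects the returned Line2D objects to confirm this (the Agg backend is used only so it runs without a display):

```python
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np

x = np.arange(10)
y = np.random.randint(0, 10, size=10)

fig, ax = plt.subplots()
# 'ro' = red circle markers; no line style in the format string
(points,) = ax.plot(x, y, 'ro')
# 'b--o' adds a dashed line style, so this draws both markers and a line
(dashed,) = ax.plot(x, y, 'b--o')

print(points.get_linestyle())  # 'None' -> markers only
print(dashed.get_linestyle())  # '--'
plt.close(fig)
```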
As I mentioned earlier, we'll be using matplotlib's object-oriented API instead of its function-based API. Let's briefly draw a distinction between these two APIs:
# prepare 50 x-coordinates and 50 y-coordinates
x = np.linspace(-np.pi, np.pi, 50)
y = np.sin(x)
# Plot using matplotlib's functional API:
# a single function call produces a plot; convenient but less flexible
plt.plot(x, y)
[<matplotlib.lines.Line2D at 0x7f35a22444c0>]
# Plot using matplotlib's object-oriented API:
# we generate a figure and axis object: `fig` and `ax`
fig, ax = plt.subplots()
# we then use these objects to draw-on and manipulate our plot
ax.plot(x, y)
[<matplotlib.lines.Line2D at 0x7f35a21a94f0>]
Although the code that invokes the functional API is simpler, it is far less powerful and flexible than the object-oriented API, which produces figure (fig) and axes (ax) objects that we can leverage to customize our plot. You will likely see tutorials use the functional API in their examples, so it is useful to understand the distinction here. Shortly, you will learn how to leverage matplotlib's object-oriented API in powerful ways.
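As a small illustration of that flexibility, the sketch below keeps a handle on each Axes and customizes them independently, while figure-wide settings go through fig. The specific styling choices (shared y-axis, reference line, grid) are arbitrary examples, not part of the original workflow.

```python
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(-np.pi, np.pi, 50)

# Two side-by-side axes on one figure; each Axes object is customized independently
fig, (ax_sin, ax_cos) = plt.subplots(1, 2, figsize=(8, 3), sharey=True)

ax_sin.plot(x, np.sin(x), color='tab:blue')
ax_sin.set_title('sin(x)')
ax_sin.axhline(0, color='gray', lw=0.5)  # horizontal reference line at y = 0

ax_cos.plot(x, np.cos(x), color='tab:orange')
ax_cos.set_title('cos(x)')
ax_cos.grid(True)

# Figure-level settings go through `fig`
fig.suptitle('Object-oriented API: one figure, two axes')
fig.tight_layout()
```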
One typical thing that we might want to do would be to plot one band against another. In order to do this, we will need to transform, or flatten, our 2-dimensional arrays of each band's values into 1-dimensional arrays:
red = image[2, :, :]  # band 3 (Red)
nir = image[3, :, :]  # band 4 (NIR)
print('Array shape before: {shp} (size is {sz})'.format(shp=red.shape, sz=red.size))
red_flat = red.flatten()
nir_flat = nir.flatten()
print('Array shape after: {shp} (size is {sz})'.format(shp=red_flat.shape, sz=red_flat.size))
Array shape before: (250, 250) (size is 62500)
Array shape after: (62500,) (size is 62500)
We have retained the number of entries in each of these raster bands, but we have flattened them from 2 dimensions into 1.
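NumPy offers two common ways to flatten: flatten() always returns a new copy, while ravel() returns a view of the original data when it can, which saves memory on large rasters. Both use row-major order, so pixel [r, c] lands at position r * n_cols + c. A sketch on a hypothetical 250 x 250 band of random values:

```python
import numpy as np

# Hypothetical 250 x 250 band (random values, not real data)
band = np.random.default_rng(1).integers(0, 10000, size=(250, 250))

flat_copy = band.flatten()   # always a new 1-D copy
flat_view = band.ravel()     # a view when possible (no copy for this contiguous array)

print(flat_copy.shape, flat_view.shape)   # (62500,) (62500,)
print(np.shares_memory(band, flat_copy))  # False
print(np.shares_memory(band, flat_view))  # True

# Row-major ("C") order: element [r, c] ends up at position r * n_cols + c
r, c = 10, 20
assert flat_copy[r * band.shape[1] + c] == band[r, c]
```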
Now we can plot them. Since we just want points, we can use scatter for a scatterplot. Since there are no lines in a scatterplot, it has a slightly different syntax.
fig, ax = plt.subplots()
# Make the plot
ax.scatter(red_flat, nir_flat, color='r', marker='o')
# Add some axis labels
ax.set_xlabel('Red Reflectance')
ax.set_ylabel('NIR Reflectance')
# Add a title
ax.set_title('Tasseled Cap, eh?')
Text(0.5, 1.0, 'Tasseled Cap, eh?')
If we want the two axes to have the same limits, we can calculate the limits and apply them:
fig, ax = plt.subplots()
# Make the plot
ax.scatter(red_flat, nir_flat, color='r', marker='o')
# Calculate min and max
plot_min = min(red.min(), nir.min())
plot_max = max(red.max(), nir.max())
ax.set_xlim((plot_min, plot_max))
ax.set_ylim((plot_min, plot_max))
# Add some axis labels
ax.set_xlabel('Red Reflectance')
ax.set_ylabel('NIR Reflectance')
# Add a title
ax.set_title('Tasseled Cap, eh?')
Text(0.5, 1.0, 'Tasseled Cap, eh?')
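With 62,500 points, many markers sit on top of each other, so dense regions are hard to read. Two common remedies are small, semi-transparent markers or a 2-D histogram that colors each bin by its pixel count (ax.hist2d). A sketch using hypothetical, randomly generated stand-ins for red_flat and nir_flat:

```python
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np

# Hypothetical stand-ins for red_flat and nir_flat (62,500 random values each)
rng = np.random.default_rng(2)
red_flat = rng.normal(800, 200, size=62500)
nir_flat = 1.5 * red_flat + rng.normal(0, 300, size=62500)

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 4))

# Option 1: tiny, semi-transparent markers so dense regions look darker
ax1.scatter(red_flat, nir_flat, s=1, alpha=0.1, color='r')
ax1.set_title('Scatter, alpha=0.1')

# Option 2: count points per 2-D bin and color each bin by its count
counts, xedges, yedges, mesh = ax2.hist2d(red_flat, nir_flat, bins=100, cmap='viridis')
fig.colorbar(mesh, ax=ax2, label='Pixel count')
ax2.set_title('2-D histogram')

for ax in (ax1, ax2):
    ax.set_xlabel('Red Reflectance')
    ax.set_ylabel('NIR Reflectance')
```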
What if we want to view the average intensity of each band?
# band numbers 1 through 8 (one per band in the image)
x = np.arange(1, image.shape[0] + 1)
# let's get the average value of each band
y = np.mean(image, axis=(1,2))
fig, ax = plt.subplots()
# plot them as lines
ax.plot(x, y)
# Add some axis labels
ax.set_xlabel('Band #')
ax.set_ylabel('Reflectance Value')
# Add a title
ax.set_title('Band Intensities')
Text(0.5, 1.0, 'Band Intensities')
It is common to see this high reflectance in the NIR band (band 4 here), since healthy vegetation reflects strongly in the near infrared.
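The call np.mean(image, axis=(1, 2)) averages over the row and column axes of a [bands, rows, cols] array, leaving one value per band. A tiny hand-checkable sketch with a hypothetical 3-band, 2 x 2 stack:

```python
import numpy as np

# Tiny hypothetical stack: 3 bands x 2 rows x 2 cols
stack = np.array([
    [[1, 3], [5, 7]],      # band 1
    [[10, 10], [10, 10]],  # band 2
    [[0, 0], [0, 8]],      # band 3
])

# axis=(1, 2) averages over rows and columns, leaving one value per band
band_means = np.mean(stack, axis=(1, 2))
print(band_means)  # one value per band: 4.0, 10.0, 2.0

# Band numbers for the x-axis can be derived from the array itself
band_numbers = np.arange(1, stack.shape[0] + 1)
print(band_numbers)  # 1, 2, 3
```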
With so much data available to look at, it can be hard to understand what is going on with the mess of points shown above. Luckily, our datasets aren't just a mess of points: they have a spatial structure.
To show the spatial structure of our images, we could make an image plot of one of our bands, using imshow to display an image on the axes:
# use the Axes method "imshow" to display an image -- NIR first (image[3] is band 4)
fig, ax = plt.subplots()
ax.imshow(image[3, :, :])
<matplotlib.image.AxesImage at 0x7f35a21d7850>
Well, it looks like there is something going on: maybe a river in the center and some bright vegetation to the bottom left of the image. What's lacking is any knowledge of what the colors mean.
Luckily, matplotlib can provide us with a colorbar.
# use "imshow" for an image -- nir at first
fig, ax = plt.subplots()
img = ax.imshow(image[3, :, :])
fig.colorbar(img, ax=ax) # colorbar needs the image object returned by imshow, so we store it in a variable
<matplotlib.colorbar.Colorbar at 0x7f35a077d940>
If we want a greyscale image, we can manually specify a colormap:
We'll also begin showing some more advanced plotting capabilities in matplotlib such as putting multiple plots in a single output. Here we'll compare the NIR to Red in greyscale.
fig, (ax1, ax2) = plt.subplots(1,2, figsize=(10,3)) # 2 axes on a 1x2 grid
# note that we could also use indexing for our axes:
# fig, ax = plt.subplots(1, 2)
# ax1 = ax[0]
# find max reflectance to put them on the same colorbar scale
max_ref = max(image[3, :, :].max(), image[2, :, :].max())
# nir in first subplot
nir_img = ax1.imshow(image[3, :, :], cmap=plt.cm.Greys)
ax1.set_title("NIR Band")
nir_img.set_clim(vmin=0, vmax=max_ref)
fig.colorbar(nir_img, ax=ax1)
# Now red band in the second subplot
red_img = ax2.imshow(image[2, :, :], cmap=plt.cm.Greys)
ax2.set_title("Red Band")
red_img.set_clim(vmin=0, vmax=max_ref)
fig.colorbar(red_img, ax=ax2)
<matplotlib.colorbar.Colorbar at 0x7f35a063ee20>
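The same shared scale can also be set directly when calling imshow through its vmin and vmax arguments, and since both panels then use identical limits, a single colorbar can serve both axes. A sketch with hypothetical random bands standing in for NIR and Red:

```python
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np

# Hypothetical NIR and Red bands (random values, not real data)
rng = np.random.default_rng(3)
nir_band = rng.integers(0, 6000, size=(250, 250))
red_band = rng.integers(0, 3000, size=(250, 250))
max_ref = max(nir_band.max(), red_band.max())

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 3))
# Passing vmin/vmax to imshow has the same effect as calling set_clim afterwards
nir_img = ax1.imshow(nir_band, cmap='Greys', vmin=0, vmax=max_ref)
red_img = ax2.imshow(red_band, cmap='Greys', vmin=0, vmax=max_ref)
ax1.set_title('NIR Band')
ax2.set_title('Red Band')

# Both images share the same limits, so one colorbar can cover both axes
fig.colorbar(nir_img, ax=[ax1, ax2])
```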
Greyscale images are nice, but the most information we can receive comes from looking at the interplay among different bands. To accomplish this, we can map different spectral bands to the Red, Green, and Blue channels on our monitors.
Before we can do this, the matplotlib imshow help tells us that we need to normalize our bands into a 0 - 1 range. To do so, we will perform a simple linear scale fitting 0 to 0 and 5000 (50% reflectance, with reflectance scaled by 10,000x) to 1, clipping anything larger or smaller.
Remember: if we are going from an Int16 datatype (e.g., reflectance scaled by 10,000x) to a decimal between 0 and 1, we will need to use a float!
from rasterio.plot import reshape_as_raster, reshape_as_image
# Extract reference to SWIR1, NIR, and Red bands
index = np.array([4, 3, 2])
colors = image[index, :, :].astype(np.float64)
max_val = 5000
min_val = 0
# Enforce maximum and minimum values
colors = np.clip(colors, min_val, max_val)
# Linearly rescale so that min_val maps to 0 and max_val maps to 1
colors = (colors - min_val) / (max_val - min_val)
# rasters are in the format [bands, rows, cols] whereas images are typically [rows, cols, bands]
# and so our array needs to be reshaped
print(colors.shape)
colors_reshaped = reshape_as_image(colors)
print(colors_reshaped.shape)
(3, 250, 250)
(250, 250, 3)
We've got a correctly shaped and normalized image now.
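A fixed upper limit such as 5000 works well when you know the data range. A common alternative (not the method used above) is a per-band percentile stretch, which clips each band at, say, its 2nd and 98th percentiles before rescaling to 0 - 1. A sketch with a hypothetical helper and random data:

```python
import numpy as np

def percentile_stretch(bands, lower=2, upper=98):
    """Hypothetical helper: rescale each band of a [bands, rows, cols] array to 0-1
    using that band's own lower/upper percentiles."""
    bands = bands.astype(np.float64)
    out = np.empty_like(bands)
    for b in range(bands.shape[0]):
        lo, hi = np.percentile(bands[b], (lower, upper))
        # Clip to [lo, hi], then map lo -> 0 and hi -> 1 (guard against a constant band)
        scale = hi - lo if hi > lo else 1.0
        out[b] = (np.clip(bands[b], lo, hi) - lo) / scale
    return out

# Hypothetical 3-band stack (random values, not real reflectance)
stack = np.random.default_rng(4).integers(0, 10000, size=(3, 250, 250))
stretched = percentile_stretch(stack)
print(stretched.shape, stretched.min(), stretched.max())  # (3, 250, 250) 0.0 1.0
```

The result is still in [bands, rows, cols] order, so it would go through reshape_as_image before imshow, just like colors above.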
Let's calculate NDVI on this dataset to compare it to the color image.
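# silence warnings where NIR + Red == 0 (those pixels become NaN or inf instead)
np.seterr(divide='ignore', invalid='ignore')
# cast to float so the subtraction and division are not done in integer arithmetic
bandNIR = image[3, :, :].astype(float)
bandRed = image[2, :, :].astype(float)
# NDVI = (NIR - Red) / (NIR + Red)
ndvi = (bandNIR - bandRed) / (bandNIR + bandRed)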
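np.seterr changes NumPy's warning behavior for the rest of the session. If you would rather keep warnings on elsewhere, one alternative is to divide only where the denominator is non-zero using np.divide with where=, leaving NaN in the remaining pixels. A sketch with a hypothetical helper ndvi_safe and hand-made 2 x 2 bands:

```python
import numpy as np

def ndvi_safe(nir, red):
    """Hypothetical helper: NDVI = (NIR - Red) / (NIR + Red), NaN where NIR + Red == 0."""
    nir = nir.astype(np.float64)
    red = red.astype(np.float64)
    denom = nir + red
    out = np.full(nir.shape, np.nan)
    # Divide only where the denominator is non-zero, so no warnings are raised
    np.divide(nir - red, denom, out=out, where=denom != 0)
    return out

nir_demo = np.array([[4000, 0], [3000, 1000]])
red_demo = np.array([[1000, 0], [3000, 3000]])
result = ndvi_safe(nir_demo, red_demo)
print(result)  # pixel values: 0.6, NaN, 0.0, -0.5
```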
Now let's plot it all:
fig, axs = plt.subplots(1, 2, figsize=(10, 5))
# Show the color image
axs[0].imshow(colors_reshaped)
axs[0].set_title('Color Image')
# Show NDVI
axs[1].imshow(ndvi, cmap='RdYlGn')
axs[1].set_title('NDVI')
Text(0.5, 1.0, 'NDVI')
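By default imshow stretches the colormap to the minimum and maximum of the data, so the same color can mean different NDVI values in different images. Pinning vmin=-1 and vmax=1 (NDVI's theoretical range) and adding a colorbar keeps the colors comparable. A sketch with a hypothetical random NDVI array:

```python
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np

# Hypothetical NDVI array with values in [-0.2, 0.8] (random, not real data)
ndvi = np.random.default_rng(5).uniform(-0.2, 0.8, size=(250, 250))

fig, ax = plt.subplots(figsize=(5, 5))
# Pin the color limits to NDVI's full range so colors mean the same thing across images
ndvi_img = ax.imshow(ndvi, cmap='RdYlGn', vmin=-1, vmax=1)
fig.colorbar(ndvi_img, ax=ax, label='NDVI')
ax.set_title('NDVI')
```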
rasterio has its own show function that is built to handle rasters. We saw this briefly in chapter 1.
# these functions build on matplotlib and are customized for rasterio datasets and arrays
from rasterio.plot import show
fig, ax = plt.subplots(figsize=(6,6))
# display just band 4 (NIR)
show(dataset.read(4), ax=ax)
<AxesSubplot:>
from rasterio.plot import adjust_band
rgb = image[[2, 1, 0]] # red, green, blue (bands 3, 2, 1 of this stack)
rgb_norm = adjust_band(rgb) # rescale values to the 0.0 - 1.0 range
rgb_reshaped = reshape_as_image(rgb_norm) # reshape to [rows, cols, bands]
fig, axs = plt.subplots(1, 2, figsize=(10, 5))
# the rasterio show function takes in [bands, rows, cols] so we don't need to reshape
show(rgb_norm, ax=axs[0])
axs[0].set_title("RGB from rasterio show")
# plot with normal matplotlib functions
axs[1].imshow(rgb_reshaped)
axs[1].set_title("RGB in matplotlib imshow")
Text(0.5, 1.0, 'RGB in matplotlib imshow')
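adjust_band rescales the array it is given into the 0 - 1 range. If your rasterio version does this with one overall minimum and maximum (check the rasterio.plot API reference), a band with a narrow range stays compressed relative to the others. The sketch below contrasts a single global min-max rescale with a per-band one; both functions are hypothetical plain-NumPy helpers, not rasterio's implementation.

```python
import numpy as np

def minmax_global(bands):
    """Hypothetical helper: rescale the whole array to 0-1 using one overall min and max."""
    bands = bands.astype(np.float64)
    return (bands - bands.min()) / (bands.max() - bands.min())

def minmax_per_band(bands):
    """Hypothetical helper: rescale each band of a [bands, rows, cols] array
    to 0-1 using that band's own min and max."""
    bands = bands.astype(np.float64)
    mins = bands.min(axis=(1, 2), keepdims=True)
    maxs = bands.max(axis=(1, 2), keepdims=True)
    ranges = np.where(maxs > mins, maxs - mins, 1.0)  # avoid dividing by zero for constant bands
    return (bands - mins) / ranges

stack = np.array([[[0, 100]], [[50, 60]]])  # 2 bands, 1 row, 2 cols
print(minmax_global(stack)[1])    # band 2 stays squeezed: 0.5, 0.6
print(minmax_per_band(stack)[1])  # band 2 stretched: 0.0, 1.0
```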
fig, (axr, axg, axb) = plt.subplots(1,3, figsize=(21,7))
show((dataset, 3), ax=axr, cmap='Reds', title='red channel')
show((dataset, 2), ax=axg, cmap='Greens', title='green channel')
show((dataset, 1), ax=axb, cmap='Blues', title='blue channel')
<AxesSubplot:title={'center':'blue channel'}>
from rasterio.plot import show_hist
fig, ax = plt.subplots(figsize=(10,5))
show_hist(dataset, ax=ax)
Let's look at an overlapping histogram, maybe that'll be more informative:
fig, ax = plt.subplots(figsize=(10,5))
show_hist(dataset, ax=ax, bins=50, lw=0.0, stacked=False, alpha=0.3,
          histtype='stepfilled', title="Band histograms overlaid")
world = rasterio.open("../data/world.rgb.tif")
print(world.shape)
fig, ax = plt.subplots(figsize=(10,5))
show((world), ax=ax, cmap='viridis')
(256, 512)
<AxesSubplot:>
Check out more documentation for the rasterio plotting functions here: https://rasterio.readthedocs.io/en/latest/api/rasterio.plot.html#rasterio.plot.show