I used to hate Python — a hatred mainly rooted in fear and envy. But it is a great language for getting things to just work, and I now write hundreds of lines of Python per day as a machine learning research student. However, it is not enough to write code that just works; anyone who has had to reproduce someone else’s work will tell you it is also important for code to communicate its own logical intent to the poor humans who have to read it.
This series, Defensive Python, will provide little tricks, hacks, and gotchas that will help you write Python code that explains itself by verifying itself.
0: the best defense is a strong offense
This isn’t just a saying — offensive programming is an actual branch of defensive programming. Its core principle is that your code is an expression of your assumptions. If your assumptions are violated, then your code needs to do something to get your attention. Especially for researchers, it is better for your code to crash than for it to run down the wrong path and give the wrong results.
The most offensive tool in your arsenal is assert.
An assert statement has two parts: a condition, and an optional message that is attached to the AssertionError raised if that condition is false.
>>> assert 0 <= i < len(data), i
>>> assert x is not None
>>> assert False, "this part of the code should never be reached"
I’m really bad at remembering to add the second part. You don’t really need to do it, since the condition is the statement of your assumptions. But if (when) your assertion fails, you will need that output to figure out what went wrong.
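To see what the second part buys you, here is a small sketch built around the first assertion above. The two get functions are hypothetical and differ only in whether the assert carries a message. The loop records what each failure would report:

```python
def get(data: list, i: int):
    # The condition states the assumption; the message reports the offending index.
    assert 0 <= i < len(data), i
    return data[i]


def get_no_message(data: list, i: int):
    # Same assumption, but a failure won't say which index was bad.
    assert 0 <= i < len(data)
    return data[i]


failures = {}
for fn in (get, get_no_message):
    try:
        fn([10, 20, 30], 5)
    except AssertionError as e:
        failures[fn.__name__] = e.args  # what the traceback would show

print(failures)  # {'get': (5,), 'get_no_message': ()}
```

One caveat is worth knowing. Python skips assert statements entirely when it runs with the -O flag, so treat them as checks on your own assumptions, not as validation that must always run.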
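Another benefit of asserts is that they give information to your type checker (and to the linters and editors built on one): after assert isinstance(x, int), tools such as mypy and Pyright treat x as an int.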
def successor(x: object) -> int:
    assert isinstance(x, int), type(x)
    return x + 1  # x is narrowed to int here, so the type checker won't complain
Comments are useful too, but just remember that comments don’t crash when they are violated. I’m lazy, so that’s my reason for not commenting my code.
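To make the point about comments concrete, here is a sketch with two hypothetical variants of successor. One documents its assumption in a comment and the other asserts it:

```python
def successor_commented(x):
    # x must be an int  <- a comment nobody enforces
    return x + 1


def successor_asserted(x):
    # The same assumption, stated as code
    assert isinstance(x, int), type(x)
    return x + 1


# A float slips straight past the comment and returns a non-int result.
print(successor_commented(2.5))  # 3.5

# The assert turns the same mistake into an immediate, informative crash.
caught = None
try:
    successor_asserted(2.5)
except AssertionError as e:
    caught = e.args[0]
print(caught)  # <class 'float'>
```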
1: strict zip
One simple example of offensive Python is a little-known feature of the built-in zip function that was added in Python 3.10.
You probably use zip to do things like looping over multiple lists at once:
for x, y, z in zip(X, Y, Z):
    do_something(x, y, z)
X, Y, and Z should probably have the same length.
If they don’t, zip stops the moment any one of them runs out. On each step it pulls one item from each iterator in turn, left to right, and quits at the first one that comes up empty.
This can have surprising side effects:
def X():
    print("x1"); yield "x1"
    print("x2"); yield "x2"
    print("X ran out")
def Y():
    print("y1"); yield "y1"
    print("Y ran out")
def Z():
    print("z1"); yield "z1"
    print("Z ran out")
>>> assert list(zip(X(), Y(), Z())) == [("x1", "y1", "z1")]
x1
y1
z1
x2
Y ran out
>>> assert list(zip(Z(), Y(), X())) == [("z1", "y1", "x1")]
z1
y1
x1
Z ran out
Notice how the first iterable to run out causes everything else to get truncated, even if the other iterables were going to run out on the same step. If you are counting on the code after an iterable’s last item (like the "ran out" prints above) to do some kind of cleanup, then you are in for a nasty surprise.
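Truncation can also cost you data. In the first example, X had already yielded "x2" by the time zip discovered that Y was empty, so that value was pulled and thrown away. A minimal sketch with a plain list iterator (the names are just for illustration) makes the loss visible:

```python
# A shared iterator lets us see what zip consumed behind the scenes.
numbers = iter([1, 2, 3, 4])
letters = ["a"]

pairs = list(zip(numbers, letters))
# On step 2, zip pulled 2 from `numbers`, then found `letters` empty and stopped.
leftover = list(numbers)

print(pairs)     # [(1, 'a')]
print(leftover)  # [3, 4]  (2 was consumed and silently dropped)
```

Putting the shortest iterable first avoids the loss in this particular case, but only if you already know which one is shortest.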
Since Python 3.10, zip accepts an optional strict=True keyword argument to tell it to validate that all of the iterables run out at the same time.
If they aren’t actually the same length, zip(..., strict=True) raises a ValueError.
>>> list(zip(X(), Y(), Z(), strict=True))
x1
y1
z1
x2
Y ran out
ValueError: zip() argument 2 is shorter than argument 1
If they are the same length, then every iterator will be exhausted regardless of their order.
>>> assert list(zip(Y(), Z(), strict=True)) == [("y1", "z1")]
y1
z1
Y ran out
Z ran out
As you see, zip(..., strict=True) implements and offensively validates the behaviour that most people expect from the zip function.
You should always use it when you expect all of the iterators to have the same length.
If your iterators have different lengths but you want to keep going until all of them run out, use itertools.zip_longest, which fills in the missing values with None (or with whatever you pass as fillvalue).
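For comparison, here is a quick sketch of zip_longest padding the shorter input, first with the default None and then with an explicit fillvalue:

```python
from itertools import zip_longest

shorter = [1, 2]
longer = ["a", "b", "c"]

# Plain zip drops the unmatched "c"; zip_longest pads the shorter input instead.
print(list(zip(shorter, longer)))  # [(1, 'a'), (2, 'b')]
print(list(zip_longest(shorter, longer)))  # [(1, 'a'), (2, 'b'), (None, 'c')]
print(list(zip_longest(shorter, longer, fillvalue=0)))  # [(1, 'a'), (2, 'b'), (0, 'c')]
```

If None can legitimately appear in your data, use a dedicated sentinel object as the fillvalue. That way padding can't be confused with real entries.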
2: tuple indexing is a code smell
One common pattern when working with multidimensional arrays (NumPy, PyTorch, etc.) is accessing the dimensions of an array:
import numpy as np
from numpy.typing import NDArray
def foo(a: NDArray, b: NDArray):
    """
    Args:
        a (NDArray): B x M x M
        b (NDArray): B x N x N
    """
    B = a.shape[0]
    M = a.shape[1]
    N = b.shape[1]
    ...
The problem with docstrings is that they have no power of enforcement on their own.
>>> a = np.random.rand(2, 5, 3, 9)
>>> a.shape
(2, 5, 3, 9)
>>> b = np.random.rand(2, 3)
>>> b.shape
(2, 3)
>>> foo(a, b)  # inputs get incorrectly accepted
We could obviously add a bunch of assertions to make this function more offensive. But before we do that, here’s an alternative.
def foo(a: NDArray, b: NDArray):
    # same docstring
    B, M, _ = a.shape
    _, N, _ = b.shape
Now you will get a ValueError if someone gives you anything with the wrong number of dimensions.
>>> foo(a, b)
ValueError: too many values to unpack (expected 3)
>>> foo(b, a)
ValueError: not enough values to unpack (expected 3, got 2)
To make things fully offensive, we add assertions to take care of the stuff that destructuring can’t do:
def foo(a: NDArray, b: NDArray):
    # same docstring
    B, M, _ = a.shape
    _, N, _ = b.shape
    assert a.shape == (B, M, M), (a.shape, B, M)
    assert b.shape == (B, N, N), (b.shape, B, N)
Now, anyone reading the code knows exactly what shapes they should pass in.
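If you want to try the pattern without creating any arrays, here is a stdlib-only sketch that runs the same destructuring and assertions on bare shape tuples. The name batch_dims is hypothetical. With NumPy you would call it as batch_dims(a.shape, b.shape).

```python
def batch_dims(a_shape: tuple[int, ...], b_shape: tuple[int, ...]) -> tuple[int, int, int]:
    """Return (B, M, N) for shapes B x M x M and B x N x N, or fail loudly."""
    B, M, _ = a_shape  # ValueError unless a_shape has exactly 3 entries
    _, N, _ = b_shape  # ValueError unless b_shape has exactly 3 entries
    assert a_shape == (B, M, M), (a_shape, B, M)
    assert b_shape == (B, N, N), (b_shape, B, N)
    return B, M, N


print(batch_dims((2, 5, 5), (2, 3, 3)))  # (2, 5, 3)

# Each of these violates a different assumption.
bad_cases = [
    ((2, 5, 3, 9), (2, 3)),  # wrong number of dims: caught by destructuring
    ((2, 5, 3), (2, 3, 3)),  # a is not square: caught by the first assert
    ((2, 5, 5), (4, 3, 3)),  # batch sizes differ: caught by the second assert
]
errors = []
for a_shape, b_shape in bad_cases:
    try:
        batch_dims(a_shape, b_shape)
    except (ValueError, AssertionError) as e:
        errors.append(type(e).__name__)
print(errors)  # ['ValueError', 'AssertionError', 'AssertionError']
```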
As you just saw, shape tuples are a great place to apply destructuring. More generally, whenever you find yourself pulling an element out of a tuple by index, your code-smell sense should start tingling: there’s probably a more offensive way to do it.
For example, when you call a function that returns multiple values but you only care about some of them:
>>> loss = model(x)[0] # smelly
>>> loss, _ = model(x)  # offensive
In the second case, the reader learns that model(x) returns exactly two things, and the code crashes if that ever stops being true.
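Here is a sketch of what that buys you, using a hypothetical model function. If someone later prepends an extra return value, indexing silently picks up the wrong number, while destructuring fails immediately:

```python
def model(x: float) -> tuple[float, float]:
    # Hypothetical stand-in for a model that returns (loss, prediction).
    prediction = 2 * x
    loss = (prediction - 1) ** 2
    return loss, prediction


def model_v2(x: float) -> tuple[float, float, float]:
    # The same model after someone prepends a regularization term to its outputs.
    loss, prediction = model(x)
    regularization = 0.01 * x ** 2
    return regularization, loss, prediction


# Indexing keeps running and quietly returns the regularization term as the "loss".
print(model_v2(1.0)[0])  # 0.01

# Destructuring notices that the return signature changed.
caught = False
try:
    loss, _ = model_v2(1.0)
except ValueError as e:
    caught = True
    print("caught:", e)
```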
You can even use it in a comprehension:
>>> losses = [model(x)[0] for x in inputs] # smelly
>>> losses = [loss for loss, _ in (model(x) for x in inputs)]  # offensive
… although it is a bit offensive to the eyes as well.
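If model takes a single argument, map can stand in for the inner generator expression. That makes the line a little easier on the eyes and keeps the destructuring check. A sketch with a hypothetical stand-in model:

```python
def model(x: float) -> tuple[float, float]:
    # Hypothetical stand-in returning (loss, prediction).
    prediction = 2 * x
    return (prediction - 1) ** 2, prediction


inputs = [0.0, 0.5, 1.0]

# Same destructuring as the nested generator version, with map doing the calls.
losses = [loss for loss, _ in map(model, inputs)]
print(losses)  # [1.0, 0.0, 1.0]
```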
3: only
We will wrap up for now by introducing our first prescribed helper function. I carry it around everywhere, and I’m surprised it’s not part of the Python standard library.
Suppose you need to pull the single element out of a collection that should contain exactly one.
>>> y = find_all(x)[0] # smelly
>>> y = next(iter(find_all(x))) # even smellier
>>> y, = find_all(x)  # offensive
The first two quietly ignore any extra elements; only the third crashes unless there is exactly one. However, if you go for the offensive solution, then you can’t add a type hint to y, and a single comma will be the difference between your offensive code and an extremely hard-to-spot bug.
Helper function¹ to the rescue:
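from collections.abc import Iterable
from typing import TypeVar
T = TypeVar("T")
def only(it: Iterable[T]) -> T:
    x, = it  # raises ValueError unless `it` yields exactly one element
    return x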
>>> y = only(find_all(x))  # nice
Not only does it make your destructuring less bug-prone, it also reads cleanly inside a comprehension:
>>> ys = [only(find_all(x)) for x in xs]
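Here is a quick, self-contained sketch of how only behaves. Unlike indexing with [0], it also accepts iterables that can’t be indexed, such as sets and generators. It also rejects both empty and overfull inputs:

```python
from collections.abc import Iterable
from typing import TypeVar

T = TypeVar("T")


def only(it: Iterable[T]) -> T:
    x, = it  # ValueError unless `it` yields exactly one element
    return x


print(only([42]))            # 42
print(only({"a"}))           # a  (sets can't be indexed with [0])
print(only(c for c in "z"))  # z  (neither can generators)

rejected = []
for bad in ([], [1, 2]):
    try:
        only(bad)
    except ValueError:
        rejected.append(bad)
print(rejected)  # [[], [1, 2]]
```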