Unpacking tables with a hill-climbing method in Python

One problem that came up while developing smallerize was trying to recreate the individual participants from a slightly complicated 3-way table like:

             Site A                      Site B
             Female  Male  Non-binary    Female  Male  Non-binary    Total
Control          16     8           4         6    12           0       46
Treatment        13    16           4        12     6           3       54
Total            29    24           8        18    18           3      100

A table like this comes from a set of individual participants with different combinations of Site, Gender and Treatment. I wanted to generate a matching set of participants so that I could test out calculations on the same data.

It's possible that some studies report their binary outcome data like this, in which case this approach could be used for re-analysis, but I suspect the uses are pretty niche. It's more useful as a demonstration of how well random iteration on a solution can work for some problems.

Hill-climbing, or educated guessing

The solution I came up with uses a simple shuffling procedure that randomly swaps elements and checks how different the numbers in each treatment group are from the true numbers, keeping the swap if the difference gets smaller. It probably qualifies as a hill climbing algorithm, although that might be a generous term.

I've found this kind of guessing and gradual improvement to be useful for a few different kinds of problems where it wasn't clear if there was a good existing solution. For each problem, it basically requires:

  • A way to measure the distance between the current solution and the desired state.
  • A way to randomly change part of the solution, so that it might move closer to the desired state.
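
As a rough sketch of those two ingredients, here is a generic loop applied to a toy problem (rearranging shuffled letters back into a target word). The names hill_climb, distance and propose are made up for this illustration and aren't taken from smallerize.

```python
import random


def hill_climb(state, distance, propose, max_attempts=100_000):
    """Apply random changes, keeping only those that reduce the distance."""
    best = distance(state)
    for _ in range(max_attempts):
        if best == 0:
            break
        candidate = propose(state)
        new = distance(candidate)
        if new < best:
            state, best = candidate, new
    return state, best


# Toy problem (an assumption for this sketch): unscramble a word
target = list("hillclimb")


def distance(letters):
    # Number of positions that don't match the target yet
    return sum(x != y for x, y in zip(letters, target))


def propose(letters):
    # Swap two random positions in a copy of the current guess
    i, j = random.sample(range(len(letters)), 2)
    changed = list(letters)
    changed[i], changed[j] = changed[j], changed[i]
    return changed


random.seed(0)
start = random.sample(target, len(target))  # shuffled copy of the target
found, remaining = hill_climb(start, distance, propose)
print("".join(found), remaining)
```

This toy problem has no dead ends: whenever a letter is out of place, some swap puts it right without making anything else worse. That isn't true of every problem, so in general a run can stall before reaching zero.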

Recreating the original participants

My solution in Python is shown below - we don't need anything outside of the Python standard library here.

import collections
import itertools
import random
import csv

I've stored the 'true' allocations that the table above is based on in a CSV, which you can download here, and for convenience we'll use it to grab the totals from the table. Note that we're not actually cheating by using the true allocations: 3-way tables are hard to work with, and this is just the easiest way to read in the counts.

Python makes it easy to swap between different representations of the table and counts, so expect heavy use of comprehensions and zip().

# The CSV is in table format but we'll also convert it to a dict
#   for easier processing
with open('table_assignments.csv', newline='') as f:
    true_table = list(csv.DictReader(f))
true = {}
for factor in ('gender', 'site', 'group'):
    true[factor] = [row[factor] for row in true_table]

# Checking we have the same totals as the table
totals = {}
for factor in ('gender', 'site', 'group'):
    totals[factor] = collections.Counter(
        row[factor] for row in true_table
    )
totals
{'gender': Counter({'Female': 47, 'Male': 42, 'Non-binary': 11}),
 'site': Counter({'B': 39, 'A': 61}),
 'group': Counter({'Control': 46, 'Treatment': 54})}
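
If you don't have the CSV to hand, the same counts can be typed in directly. This is only a sketch: the cells dict and the margins helper are names I've made up here, and the assertions simply check the hand-typed cells against the marginal totals printed above.

```python
import collections

# (gender, site, group) -> count, typed in by hand
cells = {
    ('Female', 'A', 'Control'): 16,
    ('Male', 'A', 'Control'): 8,
    ('Non-binary', 'A', 'Control'): 4,
    ('Female', 'B', 'Control'): 6,
    ('Male', 'B', 'Control'): 12,
    ('Non-binary', 'B', 'Control'): 0,
    ('Female', 'A', 'Treatment'): 13,
    ('Male', 'A', 'Treatment'): 16,
    ('Non-binary', 'A', 'Treatment'): 4,
    ('Female', 'B', 'Treatment'): 12,
    ('Male', 'B', 'Treatment'): 6,
    ('Non-binary', 'B', 'Treatment'): 3,
}
true_counts = collections.Counter(cells)


def margins(counts):
    """Sum the cell counts down to one Counter per factor."""
    out = {f: collections.Counter() for f in ('gender', 'site', 'group')}
    for (gender, site, group), n in counts.items():
        out['gender'][gender] += n
        out['site'][site] += n
        out['group'][group] += n
    return out


totals = margins(true_counts)

# The hand-typed cells should reproduce the marginal totals
assert dict(totals['gender']) == {'Female': 47, 'Male': 42, 'Non-binary': 11}
assert dict(totals['site']) == {'A': 61, 'B': 39}
assert dict(totals['group']) == {'Control': 46, 'Treatment': 54}
```

Building true_counts this way gives the same per-cell targets as counting rows from the CSV, apart from the explicit zero entry for the empty cell, which Counter subtraction ignores.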

Now we'll create an initial set of participants randomly, based on these totals:

participants = {}
for factor in ('gender', 'site', 'group'):
    participants[factor] = list(
        itertools.chain(
            *([level] * num 
            for level, num in totals[factor].items())
        )
    )
    random.shuffle(participants[factor])

# Some of the randomly generated participants:
list(zip(*participants.values()))[:5]
[('Non-binary', 'A', 'Treatment'),
 ('Male', 'A', 'Treatment'),
 ('Male', 'A', 'Treatment'),
 ('Female', 'A', 'Control'),
 ('Female', 'A', 'Treatment')]
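
The zip(*...) idiom used here transposes between the one-list-per-factor layout and a list of per-participant rows. Here is a tiny made-up example (these three participants are not from the table), which relies on dicts keeping their insertion order (Python 3.7+):

```python
# One list per factor, as in the participants dict
columns = {
    'gender': ['Female', 'Male', 'Male'],
    'site': ['A', 'B', 'A'],
    'group': ['Control', 'Control', 'Treatment'],
}

# Columns -> rows: one tuple per participant
rows = list(zip(*columns.values()))

# Rows -> columns: the same trick in reverse
back = {
    factor: list(col)
    for factor, col in zip(columns.keys(), zip(*rows))
}

print(rows[0])          # ('Female', 'A', 'Control')
print(back == columns)  # True
```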

We can count the current number of participants in each cell and compare against the true values:

def count_all(ps):
    counts = collections.Counter(zip(*ps.values()))
    return counts

current_counts = count_all(participants)
current_counts
Counter({('Non-binary', 'A', 'Treatment'): 3,
         ('Male', 'A', 'Treatment'): 15,
         ('Female', 'A', 'Control'): 14,
         ('Female', 'A', 'Treatment'): 14,
         ('Male', 'B', 'Treatment'): 11,
         ('Male', 'A', 'Control'): 11,
         ('Non-binary', 'A', 'Control'): 4,
         ('Female', 'B', 'Treatment'): 9,
         ('Female', 'B', 'Control'): 10,
         ('Non-binary', 'B', 'Treatment'): 2,
         ('Male', 'B', 'Control'): 5,
         ('Non-binary', 'B', 'Control'): 2})
true_counts = count_all(true)
# Note that when you subtract Counters, only the positive
#   values are kept. This saves us dealing with absolute
#   values later
diffs = current_counts - true_counts
diffs
Counter({('Female', 'A', 'Treatment'): 1,
         ('Male', 'B', 'Treatment'): 5,
         ('Male', 'A', 'Control'): 3,
         ('Female', 'B', 'Control'): 4,
         ('Non-binary', 'B', 'Control'): 2})

For the overall measure of fit against the true table, we'll just sum these values.

total_diff = sum(diffs.values())
total_diff
15
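
Summing only the positive differences is enough here because both sets contain the same 100 participants, so every surplus in one cell is matched by a shortfall somewhere else. A small made-up example shows how the - operator and Counter.subtract() differ:

```python
import collections

current = collections.Counter({'x': 5, 'y': 1, 'z': 2})
target = collections.Counter({'x': 3, 'y': 3, 'z': 2})

# The - operator drops zero and negative results
surplus = current - target
print(surplus)                  # Counter({'x': 2})

# subtract() works in place and keeps every key
signed = collections.Counter(current)
signed.subtract(target)
print(sorted(signed.items()))   # [('x', 2), ('y', -2), ('z', 0)]

# Both counters add up to 8, so surplus and shortfall are equal
shortfall = target - current
print(sum(surplus.values()) == sum(shortfall.values()))   # True
```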

Finding the solution

Now we can try to swap factor levels between two randomly selected participants and see if we get closer to the truth. If we do, we keep the swap; if not, we undo it. Only gender and site need to be swapped: the group column can stay as it is, because all that matters is which levels end up together. If we try out these swaps repeatedly, we can usually reach the true table in a thousand attempts or fewer (it is random though, so it varies and won't always find a solution on a given run):

max_attempts = 1e5
best_diff = total_diff

attempts = 0
solved = False
while not solved:
    a = random.randint(0, len(true_table) - 1)
    b = random.randint(0, len(true_table) - 1)
    
    # Pick a random factor to try permuting
    to_swap = random.choice(['gender', 'site'])
    # Swap their levels for that factor
    current_a = participants[to_swap][a]
    current_b = participants[to_swap][b]
    participants[to_swap][b] = current_a
    participants[to_swap][a] = current_b

    # Re-count the cells and measure the distance from the true table
    current_counts = count_all(participants)
    current_total = sum((current_counts - true_counts).values())

    if current_total < best_diff:
        best_diff = current_total
        print("Improved match. New total: ", best_diff)
    else:
        # Swap back
        participants[to_swap][a] = current_a
        participants[to_swap][b] = current_b

    if best_diff == 0:
        solved = True
        print("Solved in", attempts, "attempts")
    attempts += 1
    if attempts > max_attempts:
        print("Failed to solve.")
        break
Improved match. New total:  13
Improved match. New total:  11
Improved match. New total:  10
Improved match. New total:  9
Improved match. New total:  8
Improved match. New total:  6
Improved match. New total:  4
Improved match. New total:  2
Improved match. New total:  0
Solved in 958 attempts
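
Since a single run can stall, one option is to wrap the search in a function and restart from a fresh shuffle when it fails. The sketch below is a variation rather than the loop above: it also accepts swaps that leave the difference unchanged (<= instead of <), on the assumption that this helps it drift off plateaus. The names TRUE_COUNTS, surplus, search_once and unpack_table are made up for this example, and the cell counts are typed in by hand rather than read from the CSV.

```python
import collections
import random

# (gender, site, group) -> count, typed in by hand
TRUE_COUNTS = collections.Counter({
    ('Female', 'A', 'Control'): 16, ('Male', 'A', 'Control'): 8,
    ('Non-binary', 'A', 'Control'): 4, ('Female', 'B', 'Control'): 6,
    ('Male', 'B', 'Control'): 12, ('Non-binary', 'B', 'Control'): 0,
    ('Female', 'A', 'Treatment'): 13, ('Male', 'A', 'Treatment'): 16,
    ('Non-binary', 'A', 'Treatment'): 4, ('Female', 'B', 'Treatment'): 12,
    ('Male', 'B', 'Treatment'): 6, ('Non-binary', 'B', 'Treatment'): 3,
})


def surplus(columns, true_counts):
    """Sum of the positive differences between current and true cells."""
    current = collections.Counter(zip(*columns))
    return sum((current - true_counts).values())


def search_once(true_counts, rng, max_attempts=20_000):
    """One hill-climbing run. Returns sorted rows, or None if it stalls."""
    # Gender, site and group columns. Shuffling gender and site
    # independently throws away everything except the marginal totals.
    columns = [list(col) for col in zip(*true_counts.elements())]
    for col in columns[:2]:
        rng.shuffle(col)

    best = surplus(columns, true_counts)
    n = len(columns[0])
    for _ in range(max_attempts):
        if best == 0:
            break
        col = rng.choice(columns[:2])      # permute gender or site
        a, b = rng.sample(range(n), 2)     # two different participants
        col[a], col[b] = col[b], col[a]
        new = surplus(columns, true_counts)
        if new <= best:                    # variation: ties are accepted
            best = new
        else:
            col[a], col[b] = col[b], col[a]  # undo the swap
    return sorted(zip(*columns)) if best == 0 else None


def unpack_table(true_counts, restarts=10, seed=None):
    """Retry from a fresh shuffle until a run succeeds or restarts run out."""
    rng = random.Random(seed)
    for _ in range(restarts):
        rows = search_once(true_counts, rng)
        if rows is not None:
            return rows
    return None


rows = unpack_table(TRUE_COUNTS, seed=1)
```

Passing a seed makes a run repeatable, and returning None instead of raising leaves it to the caller to decide what to do when every restart fails.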

We can double check that our final counts match the originals:

final_counts = count_all(participants)
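# Unlike the - operator above, the Counter.subtract() method works
#   in place and keeps zero and negative differences as well, so
#   every remaining value should be exactly zero:
final_counts.subtract(true_counts)
assert all(v == 0 for v in final_counts.values())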

This means the list of individual participants should now match the true values that were used to generate the table:

sorted(zip(*participants.values())) == sorted(zip(*true.values()))
True

And we're done! While writing this post, I remembered a cool application of a similar hill-climbing method for breaking simple codes, so stay tuned for another post coming soon.