Skip to content

Instantly share code, notes, and snippets.

@mwlon
Created December 9, 2021 15:05
Show Gist options
  • Select an option

  • Save mwlon/9f494874a873aacf7d2029c570f106ee to your computer and use it in GitHub Desktop.

Select an option

Save mwlon/9f494874a873aacf7d2029c570f106ee to your computer and use it in GitHub Desktop.
Fast runtime implementation of Cashflow Waterfall
# requirements: python 3.7+
from dataclasses import dataclass, field
from typing import List, Dict
# Over a pool's entire lifecycle, the total work done by all cashflow()
# calls is O(#cashflows + #shareholders); enter_pool() and claim_cash()
# are O(1) each (amortized / average-case).
@dataclass
class WaterfallPool:
    """First-in, first-out cashflow waterfall.

    Shareholders queue up in the order they call ``enter_pool``. Each
    ``cashflow(amount, n_shares)`` redeems ``n_shares`` shares from the
    front of the queue and credits ``amount / n_shares`` per redeemed share
    to whichever shareholder(s) those shares belonged to. Once a
    shareholder's shares are fully redeemed, later cashflows skip them.

    The per-shareholder lists are parallel: index ``i`` of ``sh_ids``,
    ``sh_remaining_shares`` and ``sh_cash`` all describe the same
    shareholder, and ``sh_idx_by_id`` maps an id back to that index.
    """

    # Shareholder ids, in the order they entered the pool.
    sh_ids: List[str] = field(default_factory=list)
    # Shares each shareholder still has outstanding (not yet redeemed).
    sh_remaining_shares: List[int] = field(default_factory=list)
    # Cash credited to each shareholder so far.
    sh_cash: List[float] = field(default_factory=list)
    # Index of the shareholder next in line for redemption; every
    # shareholder before it has been fully redeemed. May equal
    # len(sh_ids) when nobody is waiting.
    current_sh_idx: int = 0
    # Shareholder id -> index into the parallel lists above.
    sh_idx_by_id: Dict[str, int] = field(default_factory=dict)

    def enter_pool(self, sh_id: str, n_shares: int) -> None:
        """Add a shareholder holding ``n_shares`` shares to the back of the queue.

        Raises:
            ValueError: if ``sh_id`` has already entered the pool (the
                earlier entry's cash would otherwise become unclaimable),
                or if ``n_shares`` is negative (a negative balance would let
                a cashflow redeem more shares than it carries and credit
                negative cash).
        """
        # Validate before mutating so a rejected call leaves the pool intact.
        if sh_id in self.sh_idx_by_id:
            raise ValueError(f'shareholder {sh_id!r} has already entered the pool')
        if n_shares < 0:
            raise ValueError(f'n_shares must be non-negative, got {n_shares}')
        idx = len(self.sh_ids)
        self.sh_ids.append(sh_id)
        self.sh_remaining_shares.append(n_shares)
        self.sh_cash.append(0.0)
        self.sh_idx_by_id[sh_id] = idx

    def cashflow(self, amount: float, n_shares: int) -> None:
        """Pay ``amount`` to redeem ``n_shares`` shares from the front of the queue.

        Each redeemed share is credited ``amount / n_shares``. If the queue
        runs out of outstanding shares first, the unredeemed portion of
        ``amount`` is not credited to anyone (later entrants do not receive
        it). A non-positive ``n_shares`` is a no-op.
        """
        remaining_n_shares = n_shares
        # Amortized bound: each iteration either exhausts this cashflow
        # (remaining_n_shares hits 0, loop ends) or fully redeems the current
        # shareholder (current_sh_idx advances and is never revisited). So
        # across all calls there are at most #cashflows + #shareholders
        # iterations. This relies on shares being non-negative, which
        # enter_pool enforces.
        while remaining_n_shares > 0 and self.current_sh_idx < len(self.sh_ids):
            idx = self.current_sh_idx
            deduction = min(remaining_n_shares, self.sh_remaining_shares[idx])
            self.sh_remaining_shares[idx] -= deduction
            remaining_n_shares -= deduction
            self.sh_cash[idx] += amount * deduction / n_shares
            if self.sh_remaining_shares[idx] == 0:
                self.current_sh_idx += 1

    def claim_cash(self, sh_id: str) -> float:
        """Return the total cash credited so far to shareholder ``sh_id``.

        This is a read-only query; it does not reset or deduct the balance.

        Raises:
            KeyError: if ``sh_id`` never entered the pool.
        """
        return self.sh_cash[self.sh_idx_by_id[sh_id]]
if __name__ == "__main__":
    # Demo: two shareholders enter, two cashflows partially redeem them,
    # a third shareholder enters, and a final cashflow redeems the rest.
    # Guarded so importing this module does not run the demo or print.
    pool = WaterfallPool()
    pool.enter_pool('sh1', 100)
    pool.enter_pool('sh2', 50)
    pool.cashflow(500.0, 40)
    pool.cashflow(1000.0, 100)
    pool.enter_pool('sh3', 190)
    pool.cashflow(400.0, 200)
    for sh_id in ['sh1', 'sh2', 'sh3']:
        cash = pool.claim_cash(sh_id)
        print(f'shareholder {sh_id} got ${cash}')
    # Expected output:
    # shareholder sh1 got $1100.0
    # shareholder sh2 got $420.0
    # shareholder sh3 got $380.0
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment