Created
December 9, 2021 15:05
-
-
Save mwlon/9f494874a873aacf7d2029c570f106ee to your computer and use it in GitHub Desktop.
Fast runtime implementation of Cashflow Waterfall
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| # requirements: python 3.7+ | |
| from dataclasses import dataclass, field | |
| from typing import List, Dict | |
| # Summed over the pool's entire lifecycle, this algorithm runs in | |
| # O(#cashflows + #shareholders) total time. | |
@dataclass
class WaterfallPool:
    """First-in, first-out cashflow waterfall over a queue of shareholders.

    Shareholders join with ``enter_pool`` and are paid strictly in entry
    order: every cashflow redeems shares from the earliest shareholder that
    still holds any, and only moves on to the next shareholder once the
    current one has been fully redeemed.

    The per-shareholder state is stored in parallel lists; position ``i`` in
    ``sh_ids``, ``sh_remaining_shares`` and ``sh_cash`` all describe the same
    pool entry.
    """

    # Shareholder ids, in order of entry.
    sh_ids: List[str] = field(default_factory=list)
    # Shares each entry still has to redeem.
    sh_remaining_shares: List[int] = field(default_factory=list)
    # Cumulative cash allocated to each entry so far.
    sh_cash: List[float] = field(default_factory=list)
    # Index of the earliest entry that may still hold unredeemed shares.
    current_sh_idx: int = 0
    # Reverse lookup from shareholder id to its position in the lists above.
    sh_idx_by_id: Dict[str, int] = field(default_factory=dict)

    def enter_pool(self, sh_id: str, n_shares: int) -> None:
        """Append a shareholder holding ``n_shares`` to the end of the queue.

        Args:
            sh_id: Unique identifier of the shareholder.
            n_shares: Number of shares the shareholder brings; must be >= 0.

        Raises:
            ValueError: If ``sh_id`` is already in the pool, or ``n_shares``
                is negative. The pool is left unmodified in both cases.
        """
        # Re-registering an id would overwrite its slot in sh_idx_by_id and
        # make the cash of the earlier entry unreachable via claim_cash.
        if sh_id in self.sh_idx_by_id:
            raise ValueError(f"shareholder {sh_id!r} is already in the pool")
        # A negative stake would produce negative deductions in cashflow,
        # crediting negative cash and growing the shares left to redeem.
        if n_shares < 0:
            raise ValueError(f"n_shares must be non-negative, got {n_shares!r}")

        self.sh_idx_by_id[sh_id] = len(self.sh_ids)
        self.sh_ids.append(sh_id)
        self.sh_remaining_shares.append(n_shares)
        self.sh_cash.append(0.0)

    def cashflow(self, amount: float, n_shares: int) -> None:
        """Distribute ``amount`` by redeeming ``n_shares`` in entry order.

        Each redeemed share earns ``amount / n_shares``. If the pool holds
        fewer than ``n_shares`` unredeemed shares, the part of ``amount``
        belonging to the excess shares is not allocated to anyone. Nothing
        happens when ``n_shares`` is not positive.

        Args:
            amount: Total cash of this cashflow.
            n_shares: Number of shares this cashflow redeems.
        """
        remaining_n_shares = n_shares
        # Every pass either exhausts the current shareholder (advancing
        # current_sh_idx) or exhausts this cashflow (ending the loop), so the
        # total number of passes over the pool's lifetime is bounded by
        # #cashflows + #shareholders.
        while remaining_n_shares > 0 and self.current_sh_idx < len(self.sh_ids):
            idx = self.current_sh_idx
            deduction = min(remaining_n_shares, self.sh_remaining_shares[idx])
            self.sh_remaining_shares[idx] -= deduction
            remaining_n_shares -= deduction
            self.sh_cash[idx] += amount * deduction / n_shares
            if self.sh_remaining_shares[idx] == 0:
                # Fully redeemed: never visit this shareholder again.
                self.current_sh_idx += 1

    def claim_cash(self, sh_id: str) -> float:
        """Return the cumulative cash allocated to ``sh_id`` so far.

        The balance is only read, not reset, so repeated calls return the
        same total until another cashflow pays this shareholder.

        Raises:
            KeyError: If ``sh_id`` never entered the pool.
        """
        return self.sh_cash[self.sh_idx_by_id[sh_id]]
# Demo: replay a short sequence of pool events, then report every payout.
pool = WaterfallPool()

# Two shareholders are queued before the first cash arrives.
pool.enter_pool(sh_id='sh1', n_shares=100)
pool.enter_pool(sh_id='sh2', n_shares=50)
pool.cashflow(amount=500.0, n_shares=40)
pool.cashflow(amount=1000.0, n_shares=100)

# A third shareholder joins later, behind the first two.
pool.enter_pool(sh_id='sh3', n_shares=190)
pool.cashflow(amount=400.0, n_shares=200)

for sh_id in ('sh1', 'sh2', 'sh3'):
    cash = pool.claim_cash(sh_id=sh_id)
    print(f'shareholder {sh_id} got ${cash}')

# Expected output:
#   shareholder sh1 got $1100.0
#   shareholder sh2 got $420.0
#   shareholder sh3 got $380.0
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment