# main.py: monthly cohort revenue LTV analysis (plots cumulative revenue per customer for each cohort).
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
import seaborn as sns
# Apply seaborn's default theme in one call, using the white-grid axes style.
sns.set(style='whitegrid')

# Location of the raw transaction data. read_csv accepts either a URL or a local file
# path, so point this at your own data. The rest of the script reads the customer_id,
# txn_month and txn_total columns.
DATA_SOURCE = 'https://storage.googleapis.com/fake-cohort-data/results.gz'
df = pd.read_csv(DATA_SOURCE)

# Map each customer to a cohort: the smallest txn_month among that customer's rows.
# Only the txn_month column is aggregated. Taking .min() over the whole frame and then
# selecting one column wastes work and can raise on object columns whose values
# can't be compared with each other.
# NOTE(review): if txn_month is a string, "smallest" is lexicographic. That is only
# chronological for zero-padded ISO-style values such as "YYYY-MM-DD". Confirm the format.
customer_to_cohort = df.groupby('customer_id')['txn_month'].min().to_dict()
df['customer_cohort'] = df['customer_id'].map(customer_to_cohort)

# Build the cohort summary tables:
#   cohort_revenue          - one row per cohort, one column per txn_month, with the
#                             summed txn_total (0 where the cohort had no revenue).
#   shifted_cohort_revenue  - each cohort's cumulative revenue, left-aligned so that
#                             column 0 is the cohort's first month.
#   cohort_sizes            - number of distinct customers in each cohort.
cohort_revenue = pd.pivot_table(df, index='customer_cohort', columns='txn_month', values='txn_total', aggfunc='sum').fillna(0)

# Set to NaN every cell whose transaction month is before the cohort's start month.
# Compare the month labels directly instead of using an np.triu mask. A triangular mask
# assumes the table is square with identically ordered rows and columns. That breaks as
# soon as some month has transactions but no newly formed cohort, because that month has
# a column but no row of its own.
months = cohort_revenue.columns.to_numpy()
cohorts = cohort_revenue.index.to_numpy()
cohort_revenue = cohort_revenue.where(months[np.newaxis, :] >= cohorts[:, np.newaxis])

# Drop each row's leading NaNs so position k is the cohort's k-th month present in the
# data, then take a running total along each row. Trailing months not yet observed for a
# cohort stay NaN.
# NOTE(review): position k equals "k months since formation" only if every calendar
# month appears as a txn_month column. Confirm the data has no gap months.
shifted_cohort_revenue = cohort_revenue.apply(lambda x: pd.Series(x.dropna().values), axis=1).cumsum(axis=1)
cohort_sizes = df.groupby('customer_cohort')['customer_id'].nunique()

# Plotting: cumulative revenue per customer (LTV) for each cohort, one line per cohort,
# against position since the cohort formed. Saved to graph.png.
ltv = shifted_cohort_revenue.div(cohort_sizes, axis=0)
ax = ltv.transpose().plot(lw=3)
ax.set_title("Monthly Cohort Revenue LTVs", fontsize=24)
ax.set_ylabel("USD", fontsize=24)
ax.set_xlabel("Months Since Cohort Formation", fontsize=24)
ax.figure.set_size_inches(16, 9)
ax.tick_params(labelsize=20)

# Build each legend label from the cohort whose line was actually drawn (ltv.index), not
# by zipping a separately ordered Series against the lines. str() keeps the 10-character
# trim working even when the cohort labels are not strings.
labels = [f"{str(cohort)[:10]}, size={cohort_sizes.loc[cohort]}" for cohort in ltv.index]
ax.legend(labels, fontsize=16, ncol=2)
plt.savefig('graph.png')