|
|
@ -1,3 +1,6 @@ |
|
|
|
import numpy as np |
|
|
|
import pandas as pd |
|
|
|
|
|
|
|
from pipeline import ( |
|
|
|
load_dataset, |
|
|
|
filter_data, |
|
|
@ -9,6 +12,10 @@ from pipeline import ( |
|
|
|
display_warnings_for_scenarios |
|
|
|
) |
|
|
|
|
|
|
|
year_str = 'Year' |
|
|
|
month_str = 'Month' |
|
|
|
user_str = 'user' |
|
|
|
|
|
|
|
# === Configurable Parameters === |
|
|
|
DATA_PATH = './Datasets/ALLUSERS32_15MIN_WITHOUTTHREHOLD.xlsx' |
|
|
|
OUTPUT_EXCEL_PATH = './working/evaluation_results.xlsx' |
|
|
@ -27,6 +34,28 @@ predefined_validation_scenarios = { |
|
|
|
"Scenario A": {"years_months": [(2019, [10, 11, 12])]} |
|
|
|
} |
|
|
|
|
|
|
|
def remove_covid_data(df): |
|
|
|
df = df[~((df[year_str]==2020) & (df[month_str]>2))] |
|
|
|
return df |
|
|
|
|
|
|
|
def split_data_by_month_percentage(df, percentages): |
|
|
|
train_p, valid_p, test_p = percentages |
|
|
|
ids = df[[year_str, month_str]].drop_duplicates().sort_values([year_str, month_str]) |
|
|
|
tr, va, te = np.split(ids, [int((train_p/100) * len(ids)), int(((train_p + valid_p)/100) * len(ids))]) |
|
|
|
return df.merge(tr, on=[year_str, month_str], how='inner'), df.merge(va, on=[year_str, month_str], how='inner'), df.merge(te, on=[year_str, month_str], how='inner') |
|
|
|
|
|
|
|
def split_data_by_userdata_percentage(df, percentages): |
|
|
|
train_p, valid_p, test_p = percentages |
|
|
|
tr, va, te = pd.DataFrame(), pd.DataFrame(), pd.DataFrame() |
|
|
|
for user_id in df[user_str].unique(): |
|
|
|
user_data = df[df[user_str]==user_id].sort_values([year_str, month_str]) |
|
|
|
u_tr, u_va, u_te = np.split(user_data, [int((train_p/100)*len(user_data)), int(((train_p+valid_p)/100)*len(user_data))]) |
|
|
|
tr = pd.concat([tr, u_tr], ignore_index=True) |
|
|
|
va = pd.concat([va, u_va], ignore_index=True) |
|
|
|
te = pd.concat([te, u_te], ignore_index=True) |
|
|
|
return tr, va, te |
|
|
|
|
|
|
|
|
|
|
|
def main(): |
|
|
|
# print("=== Training Scenario Setup ===") |
|
|
|
# display_warning_about_2020_data() |
|
|
@ -38,6 +67,9 @@ def main(): |
|
|
|
|
|
|
|
# === Load and preprocess === |
|
|
|
df = load_dataset(DATA_PATH) |
|
|
|
removed = remove_covid_data(df) |
|
|
|
tr,val,te = split_data_by_userdata_percentage(df, (80,10,10)) |
|
|
|
tr_2, val_2, te_2 = split_data_by_month_percentage(df, (80, 10, 10)) |
|
|
|
|
|
|
|
ALLUSERS32_15MIN_WITHOUTTHREHOLD = False |
|
|
|
if('ALLUSERS32_15MIN_WITHOUTTHREHOLD.xlsx' in DATA_PATH): |
|
|
|