diff --git a/.gitignore b/.gitignore index f8b73e7..aa68a72 100644 --- a/.gitignore +++ b/.gitignore @@ -138,3 +138,4 @@ dmypy.json # Cython debug symbols cython_debug/ +.idea diff --git a/main.py b/main.py index b078918..78c90e3 100644 --- a/main.py +++ b/main.py @@ -1,3 +1,6 @@ +import numpy as np +import pandas as pd + from pipeline import ( load_dataset, filter_data, @@ -9,6 +12,10 @@ from pipeline import ( display_warnings_for_scenarios ) +year_str = 'Year' +month_str = 'Month' +user_str = 'user' + # === Configurable Parameters === DATA_PATH = './Datasets/ALLUSERS32_15MIN_WITHOUTTHREHOLD.xlsx' OUTPUT_EXCEL_PATH = './working/evaluation_results.xlsx' @@ -27,6 +34,28 @@ predefined_validation_scenarios = { "Scenario A": {"years_months": [(2019, [10, 11, 12])]} } +def remove_covid_data(df): + df = df[~((df[year_str]==2020) & (df[month_str]>2))] + return df + +def split_data_by_month_percentage(df, percentages): + train_p, valid_p, test_p = percentages + ids = df[[year_str, month_str]].drop_duplicates().sort_values([year_str, month_str]) + tr, va, te = np.split(ids, [int((train_p/100) * len(ids)), int(((train_p + valid_p)/100) * len(ids))]) + return df.merge(tr, on=[year_str, month_str], how='inner'), df.merge(va, on=[year_str, month_str], how='inner'), df.merge(te, on=[year_str, month_str], how='inner') + +def split_data_by_userdata_percentage(df, percentages): + train_p, valid_p, test_p = percentages + tr, va, te = pd.DataFrame(), pd.DataFrame(), pd.DataFrame() + for user_id in df[user_str].unique(): + user_data = df[df[user_str]==user_id].sort_values([year_str, month_str]) + u_tr, u_va, u_te = np.split(user_data, [int((train_p/100)*len(user_data)), int(((train_p+valid_p)/100)*len(user_data))]) + tr = pd.concat([tr, u_tr], ignore_index=True) + va = pd.concat([va, u_va], ignore_index=True) + te = pd.concat([te, u_te], ignore_index=True) + return tr, va, te + + def main(): # print("=== Training Scenario Setup ===") # display_warning_about_2020_data() @@ -38,6 +67,9 @@ def main(): # === Load and preprocess === df = load_dataset(DATA_PATH) + removed = remove_covid_data(df) + tr,val,te = split_data_by_userdata_percentage(df, (80,10,10)) + tr_2, val_2, te_2 = split_data_by_month_percentage(df, (80, 10, 10)) ALLUSERS32_15MIN_WITHOUTTHREHOLD = False if('ALLUSERS32_15MIN_WITHOUTTHREHOLD.xlsx' in DATA_PATH): diff --git a/requirements.txt b/requirements.txt index b734340..833dbb3 100644 --- a/requirements.txt +++ b/requirements.txt @@ -38,7 +38,7 @@ six==1.17.0 tensorboard==2.19.0 tensorboard-data-server==0.7.2 tensorflow==2.19.0 -tensorflow-io-gcs-filesystem==0.37.1 +tensorflow-io-gcs-filesystem==0.31.0 termcolor==3.1.0 threadpoolctl==3.6.0 typing_extensions==4.14.1