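"""Train and evaluate models on the ALLUSERS32 15-minute dataset.

Training, validation and test rows are selected with the year/month scenarios
configured in this module (passed to ``filter_data`` / ``filter_test_data``),
and the evaluation results are saved to ``OUTPUT_EXCEL_PATH``.
"""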
import numpy as np
|
|
import pandas as pd
|
|
|
|
from pipeline import (
|
|
load_dataset,
|
|
filter_data,
|
|
filter_test_data,
|
|
prepare_user_data,
|
|
train_models,
|
|
evaluate_models,
|
|
display_warning_about_2020_data,
|
|
display_warnings_for_scenarios
|
|
)
|
|
|
|
# Column names read by the filtering/splitting helpers in this module.
year_str, month_str, user_str = 'Year', 'Month', 'user'
|
|
|
|
# === Configurable Parameters ===
# Input workbook and results workbook; './' paths resolve against the
# current working directory.
DATA_PATH = './Datasets/ALLUSERS32_15MIN_WITHOUTTHREHOLD.xlsx'
OUTPUT_EXCEL_PATH = './working/evaluation_results.xlsx'

# Sequence lengths used for training and evaluation, e.g. [20, 25, 30].
SEQUENCE_LENGTHS = [20]

# (year, [months]) pairs handed to filter_data / filter_test_data in main().
TRAINING_SCENARIO = [
    (2018, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]),  # all of 2018
    (2019, [1, 2, 3, 4, 5, 6, 7, 8, 9]),  # Jan–Sep 2019
]
VALIDATION_SCENARIO = [(2019, [10, 11, 12])]  # Oct–Dec 2019
TEST_SCENARIO = [(2020, [1, 2])]  # Jan–Feb 2020 only
|
|
|
|
# === Optional display only ===
# Named scenarios used only by the optional scenario-warning display in
# main(); the rows actually trained on come from TRAINING_SCENARIO and
# VALIDATION_SCENARIO.
predefined_training_scenarios = {
    "Scenario 1": {
        "years_months": [
            (2018, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]),
            (2019, [1, 2, 3, 4, 5, 6, 7, 8, 9]),
        ],
    },
    "Scenario 2": {
        "years_months": [
            (2017, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]),
            (2018, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]),
            (2019, [1, 2, 3, 4, 5, 6, 7, 8, 9]),
        ],
    },
}

predefined_validation_scenarios = {
    "Scenario A": {
        "years_months": [(2019, [10, 11, 12])],
    },
}
|
|
|
|
def remove_covid_data(df):
    """Drop 2020 rows whose month number is above 2 (March 2020 onwards).

    Only rows with year == 2020 AND month > 2 are removed; rows from every
    other year are kept unchanged.

    Args:
        df: DataFrame with the ``year_str`` and ``month_str`` columns.

    Returns:
        A new DataFrame with the remaining rows (original index labels kept).
    """
    in_2020 = df[year_str] == 2020
    after_february = df[month_str] > 2
    # Negate the combined mask (not De Morgan) so rows whose month compares
    # False (e.g. NaN) are kept, exactly as before.
    keep = ~(in_2020 & after_february)
    return df[keep]
|
|
|
|
def split_data_by_month_percentage(df, percentages, *, year_col=None, month_col=None):
    """Split ``df`` chronologically into train/validation/test by whole months.

    The distinct (year, month) pairs are sorted chronologically and cut into
    three consecutive groups, so every row of a given month ends up in the
    same split.

    Args:
        df: DataFrame containing the year and month columns.
        percentages: ``(train_p, valid_p, test_p)`` in percent. Only
            ``train_p`` and ``valid_p`` set the cut points; the test split is
            whatever months remain, so ``test_p`` is not read.
        year_col: name of the year column; defaults to ``year_str``.
        month_col: name of the month column; defaults to ``month_str``.

    Returns:
        ``(train, valid, test)`` DataFrames produced by ``merge``, each with a
        fresh default index.
    """
    year_col = year_str if year_col is None else year_col
    month_col = month_str if month_col is None else month_col
    keys = [year_col, month_col]
    train_p, valid_p, _test_p = percentages

    months = df[keys].drop_duplicates().sort_values(keys)
    n_months = len(months)
    # Multiply before dividing: (p / 100) * n can land just below an integer
    # (0.29 * 100 == 28.999...), which int() would truncate one month short.
    train_end = int(train_p * n_months / 100)
    valid_end = int((train_p + valid_p) * n_months / 100)

    # Positional slicing instead of np.split, which routes DataFrames through
    # the deprecated DataFrame.swapaxes.
    train_ids = months.iloc[:train_end]
    valid_ids = months.iloc[train_end:valid_end]
    test_ids = months.iloc[valid_end:]

    return (
        df.merge(train_ids, on=keys, how='inner'),
        df.merge(valid_ids, on=keys, how='inner'),
        df.merge(test_ids, on=keys, how='inner'),
    )
|
|
|
|
def split_data_by_userdata_percentage(df, percentages, *, year_col=None, month_col=None, user_col=None):
    """Split each user's rows chronologically into train/validation/test.

    Every user's rows are ordered by (year, month) and cut by row count, so
    each user contributes its first ``train_p`` percent of rows to train, the
    next ``valid_p`` percent to validation and the rest to test.

    Args:
        df: DataFrame containing the user, year and month columns.
        percentages: ``(train_p, valid_p, test_p)`` in percent. Only
            ``train_p`` and ``valid_p`` set the cut points; ``test_p`` is not
            read (the test split is the remainder).
        year_col: name of the year column; defaults to ``year_str``.
        month_col: name of the month column; defaults to ``month_str``.
        user_col: name of the user column; defaults to ``user_str``.

    Returns:
        ``(train, valid, test)`` DataFrames with a fresh 0..n-1 index, users
        in order of first appearance. When ``df`` has no users, three empty
        ``pd.DataFrame()`` objects are returned.
    """
    year_col = year_str if year_col is None else year_col
    month_col = month_str if month_col is None else month_col
    user_col = user_str if user_col is None else user_col
    train_p, valid_p, _test_p = percentages

    train_parts, valid_parts, test_parts = [], [], []
    for user_id in df[user_col].unique():
        user_rows = df[df[user_col] == user_id].sort_values([year_col, month_col])
        n_rows = len(user_rows)
        # Multiply before dividing to avoid float truncation such as
        # int(0.29 * 100) == 28.
        train_end = int(train_p * n_rows / 100)
        valid_end = int((train_p + valid_p) * n_rows / 100)
        # .iloc slicing instead of np.split (deprecated swapaxes path).
        train_parts.append(user_rows.iloc[:train_end])
        valid_parts.append(user_rows.iloc[train_end:valid_end])
        test_parts.append(user_rows.iloc[valid_end:])

    if not train_parts:
        return pd.DataFrame(), pd.DataFrame(), pd.DataFrame()

    # One concat per split instead of growing a frame inside the loop, which
    # re-copied every accumulated row on each iteration.
    return (
        pd.concat(train_parts, ignore_index=True),
        pd.concat(valid_parts, ignore_index=True),
        pd.concat(test_parts, ignore_index=True),
    )
|
|
|
|
|
|
def main(*, show_scenario_warnings=False):
    """Run the scenario-based training and evaluation pipeline.

    Steps: load ``DATA_PATH``; select training/validation rows with
    ``filter_data`` for ``TRAINING_SCENARIO`` / ``VALIDATION_SCENARIO``;
    prepare inputs with ``prepare_user_data``; train for the lengths in
    ``SEQUENCE_LENGTHS``; select test rows with ``filter_test_data`` for
    ``TEST_SCENARIO``; evaluate, passing ``OUTPUT_EXCEL_PATH`` as the results
    location.

    Args:
        show_scenario_warnings: if True, first print the 2020-data and
            scenario warnings for ``predefined_training_scenarios`` and
            ``predefined_validation_scenarios``. Off by default, so a plain
            ``main()`` call skips them.
    """
    if show_scenario_warnings:
        print("=== Training Scenario Setup ===")
        display_warning_about_2020_data()
        display_warnings_for_scenarios("training", predefined_training_scenarios, predefined_validation_scenarios)

        print("\n=== Validation Scenario Setup ===")
        display_warning_about_2020_data()
        display_warnings_for_scenarios("validation", predefined_training_scenarios, predefined_validation_scenarios)

    # === Load data ===
    # Rows are chosen purely by the year/month scenarios below; the
    # percentage-based split helpers in this module are not part of this run.
    df = load_dataset(DATA_PATH)

    # Flag forwarded to filter_data / evaluate_models; presumably enables
    # handling specific to this workbook — TODO confirm in pipeline.
    is_allusers32_15min = 'ALLUSERS32_15MIN_WITHOUTTHREHOLD.xlsx' in DATA_PATH

    training_data = filter_data(df, TRAINING_SCENARIO, is_allusers32_15min)
    validation_data = filter_data(df, VALIDATION_SCENARIO, is_allusers32_15min)

    user_data_train = prepare_user_data(training_data)
    user_data_val = prepare_user_data(validation_data)

    # === Train models ===
    best_models = train_models(user_data_train, user_data_val, sequence_lengths=SEQUENCE_LENGTHS)

    # === Load and evaluate test ===
    test_df = filter_test_data(df, TEST_SCENARIO)
    evaluate_models(best_models, test_df, SEQUENCE_LENGTHS, OUTPUT_EXCEL_PATH, is_allusers32_15min)

    print(f"\n✅ All evaluations completed. Results saved to: {OUTPUT_EXCEL_PATH}")
|
|
|
|
if __name__ == '__main__':
    # Execute the pipeline only when run as a script, not when imported.
    main()
|