You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 

62 lines
2.3 KiB

from pipeline import (
load_dataset,
filter_data,
filter_test_data,
prepare_user_data,
train_models,
evaluate_models,
display_warning_about_2020_data,
display_warnings_for_scenarios
)
# === Configurable Parameters ===
# Input workbook. main() recognises this dataset by the substring
# 'ALLUSERS32_15MIN_WITHOUTTHREHOLD.xlsx' (note the "THREHOLD" spelling), so
# keep that check in sync if the file is ever renamed.
DATA_PATH = './Datasets/ALLUSERS32_15MIN_WITHOUTTHREHOLD.xlsx'
OUTPUT_EXCEL_PATH = './working/evaluation_results.xlsx'

# Sequence lengths passed to train_models / evaluate_models.
# More can be added, e.g. [20, 25, 30].
SEQUENCE_LENGTHS = [20]

# Scenarios are lists of (year, [month numbers]) pairs.
TRAINING_SCENARIO = [
    (2018, list(range(1, 13))),  # Jan–Dec 2018
    (2019, list(range(1, 10))),  # Jan–Sep 2019
]
VALIDATION_SCENARIO = [(2019, [10, 11, 12])]  # Oct–Dec 2019
TEST_SCENARIO = [(2020, [1, 2])]  # Jan–Feb 2020 only

# === Optional display only ===
# Only handed to display_warnings_for_scenarios() in main(); the actual
# filtering uses the *_SCENARIO constants above.
predefined_training_scenarios = {
    "Scenario 1": {
        "years_months": [
            (2018, list(range(1, 13))),
            (2019, list(range(1, 10))),
        ],
    },
    "Scenario 2": {
        "years_months": [
            (2017, list(range(1, 13))),
            (2018, list(range(1, 13))),
            (2019, list(range(1, 10))),
        ],
    },
}
predefined_validation_scenarios = {
    "Scenario A": {
        "years_months": [(2019, [10, 11, 12])],
    },
}
def main():
    """Run the train / validate / test pipeline for the configured scenarios.

    Steps:
      1. Print the training and validation setup headers together with the
         output of ``display_warning_about_2020_data`` and
         ``display_warnings_for_scenarios`` for the predefined scenarios.
      2. Ensure the folder containing ``OUTPUT_EXCEL_PATH`` exists.
      3. Load ``DATA_PATH`` and split it into training and validation subsets
         using ``TRAINING_SCENARIO`` and ``VALIDATION_SCENARIO``.
      4. Train models via ``train_models`` with ``SEQUENCE_LENGTHS``.
      5. Select the ``TEST_SCENARIO`` period and evaluate the trained models on
         it; ``evaluate_models`` is given ``OUTPUT_EXCEL_PATH`` as the results
         destination.
    """
    from pathlib import Path  # only needed to prepare the results folder

    print("=== Training Scenario Setup ===")
    display_warning_about_2020_data()
    display_warnings_for_scenarios(
        "training", predefined_training_scenarios, predefined_validation_scenarios
    )

    print("\n=== Validation Scenario Setup ===")
    display_warning_about_2020_data()
    display_warnings_for_scenarios(
        "validation", predefined_training_scenarios, predefined_validation_scenarios
    )

    # Create the results folder up front so a missing or unwritable output
    # location is reported before training rather than after it.
    Path(OUTPUT_EXCEL_PATH).parent.mkdir(parents=True, exist_ok=True)

    # === Load and preprocess ===
    df = load_dataset(DATA_PATH)
    # True when DATA_PATH names the ALLUSERS32_15MIN_WITHOUTTHREHOLD workbook;
    # forwarded to the filtering and evaluation helpers.
    is_allusers_without_threshold = 'ALLUSERS32_15MIN_WITHOUTTHREHOLD.xlsx' in DATA_PATH
    training_data = filter_data(df, TRAINING_SCENARIO, is_allusers_without_threshold)
    validation_data = filter_data(df, VALIDATION_SCENARIO, is_allusers_without_threshold)
    user_data_train = prepare_user_data(training_data)
    user_data_val = prepare_user_data(validation_data)

    # === Train models ===
    best_models = train_models(
        user_data_train, user_data_val, sequence_lengths=SEQUENCE_LENGTHS
    )

    # === Select the test period and evaluate ===
    # NOTE(review): unlike filter_data, filter_test_data is not given the
    # dataset flag — confirm that is intended.
    test_df = filter_test_data(df, TEST_SCENARIO)
    evaluate_models(
        best_models,
        test_df,
        SEQUENCE_LENGTHS,
        OUTPUT_EXCEL_PATH,
        is_allusers_without_threshold,
    )
    print(f"\n✅ All evaluations completed. Results saved to: {OUTPUT_EXCEL_PATH}")


if __name__ == "__main__":
    main()