from schedule_class import NewSchedule, generate_schedules, service_time_with_no_shows
import pickle
import random
import numpy as np
import xgboost as xgb
from sklearn.model_selection import train_test_split, cross_val_score, StratifiedKFold, GridSearchCV
from sklearn.metrics import mean_squared_error, make_scorer, accuracy_score
from sklearn.base import clone
from itertools import chain, combinations
2024-07-18
Discussed with Joost:
- Build a model for pairwise ranking.
- Compare its performance with the cardinal ML model.
- Compare computation times (see the timing sketch below):
  - Lindley recursion vs. cardinal ML
  - Cardinal-ML-based pairwise ranking vs. direct pairwise ranking
- Develop a cardinal ML model with a loss function that penalizes large objective values.
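For the computation-time comparison, a rough sketch along these lines could work. It is illustrative only: it assumes the sch_data object and the trained clf defined further down in this notebook, and the iteration count is arbitrary.

import timeit
import numpy as np

# Time the exact Lindley-style recursion against a single ML pairwise prediction.
# Assumes sch_data and clf from the cells below (sketch, not a final benchmark).
schedule_a = [3, 1, 1, 1, 0, 1, 2]
schedule_b = [2, 1, 1, 1, 0, 1, 3]

t_exact = timeit.timeit(
    lambda: sch_data.calculate_objective(schedule_a, [0.0, 0.27, 0.28, 0.2, 0.15, 0.1], 3, 0.2),
    number=100)
t_ml = timeit.timeit(
    lambda: clf.predict(np.array([schedule_a + schedule_b])),
    number=100)

print(f"Lindley recursion: {t_exact / 100:.6f} s/call")
print(f"ML pairwise rank:  {t_ml / 100:.6f} s/call")

Note that calculate_objective prints on every call, which inflates its timing; silencing that print would give a fairer comparison.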
Setup and load data
class ScheduleData:
    def __init__(self, N: int, T: int, samples, labels):
        self.N = N
        self.T = T
        self.samples = samples
        self.labels = labels

    def describe_data(self):
        print(f'N = {self.N}', f'T = {self.T}', '\nSamples', self.samples.tail(10), '\nLabels', self.labels.tail(10), sep="\n")

    def create_pair_list(self, n):  # Create a set of randomly selected pairs of schedules
        S = list(range(len(self.samples)))
        Q = random.choices(S, k=n)
        P = []
        for q in Q:
            # Create a list of possible choices excluding q
            possible_choices = [s for s in S if s != q]
            # Choose a random element from the possible choices
            p = random.choice(possible_choices)
            # Add the chosen element to the result list
            P.append(p)
        samples_s1 = self.samples.iloc[Q, :].values.tolist()
        samples_s2 = self.samples.iloc[P, :].values.tolist()
        self.pair_list = list(zip(samples_s1, samples_s2))
        self.lables_s1 = self.labels.loc[Q, 'obj'].values.tolist()
        self.lables_s2 = self.labels.loc[P, 'obj'].values.tolist()
        self.lables_rank = [1 * (self.lables_s1[i] > self.lables_s2[i]) for i in range(len(self.lables_s1))]
        print(self.pair_list[:15], "\n", self.lables_s1[:15], "\n", self.lables_s2[:15], "\n", self.lables_rank[:15])
    def create_neighbors_list(self, n):  # Create a set of pairs of schedules that are from the same neighborhood
        # Build a subset of random schedules with length n
        S = list(range(len(self.samples)))
        Q = random.choices(S, k=n)
        samples_sub = self.samples.iloc[Q, :].values.tolist()
        labels_sub = self.labels.iloc[Q, 7]
        self.neighbors_list = []
        # For each schedule in the subset choose 2 random intervals i, j and swap 1 patient
        for s in samples_sub:
            i = random.choice(range(len(s)))  # Ensure i is a valid index in s
            j = [index for index, element in enumerate(s) if element > 0 and index != i]
            if not j:  # Ensure j is not empty
                continue
            j = random.choice(j)  # Choose a random valid index from j
            s_pair = s.copy()  # Create a copy of s to modify
            s_pair[i] = s[i] + 1
            s_pair[j] = s[j] - 1
            self.neighbors_list.append((s, s_pair))
        print(samples_sub, "\n", self.neighbors_list, "\n", labels_sub)
    def calculate_objective(self, schedule, s, d, q):
        s = service_time_with_no_shows(s, q)  # Adjust service time distribution for no-shows
        sp = np.array([1], dtype=np.int64)  # Set probability of first spillover time being zero to 1
        wt_list = []  # Initialize wt_list for saving all waiting times for all patients in the schedule
        ewt = 0  # Initialize sum of expected waiting times
        for x in schedule:  # For each interval -
            if x == 0:  # In case there are no patients,
                wt_temp = [np.array(sp)]  # the spillover from the previous interval is recorded,
                wt_list.append([])  # but there are no waiting times.
                sp = []  # Initialize the spillover time distribution
                sp.append(np.sum(wt_temp[-1][:d+1]))  # All the work from the previous interval's spillover that could not be processed is added to this interval's spillover.
                sp[1:] = wt_temp[-1][d+1:]
            else:
                wt_temp = [np.array(sp)]  # Initialize wt_temp for saving all waiting times for all patients in the interval. The first patient has to wait for the spillover work from the previous period.
                ewt += np.dot(range(len(sp)), sp)  # Add waiting time for first patient in interval
                for i in range(x - 1):  # For each patient
                    wt = np.convolve(wt_temp[i], s)  # Calculate the waiting time distribution
                    wt_temp.append(wt)
                    ewt += np.dot(range(len(wt)), wt)
                wt_list.append(wt_temp)
                sp = []
                sp.append(np.sum(np.convolve(wt_temp[-1], s)[:d+1]))  # Calculate the spillover
                sp[1:] = np.convolve(wt_temp[-1], s)[d+1:]
        print(f"Schedule: {schedule}, ewt = {ewt}")
with open('./experiments/data.pickle', 'rb') as file:
    sch_data: ScheduleData = pickle.load(file)

sch_data.description
sch_data.describe_data()
sch_data.create_pair_list(16000)
sch_data.create_neighbors_list(10)
test_sch = sch_data.neighbors_list[1][0]
sch_data.calculate_objective(test_sch, [0.0, 0.27, 0.28, 0.2, 0.15, 0.1], 3, 0.2)
N = 10
T = 7
Samples
x_0 x_1 x_2 x_3 x_4 x_5 x_6
7998 8 1 0 1 0 0 0
7999 8 1 1 0 0 0 0
8000 8 2 0 0 0 0 0
8001 9 0 0 0 0 0 1
8002 9 0 0 0 0 1 0
8003 9 0 0 0 1 0 0
8004 9 0 0 1 0 0 0
8005 9 0 1 0 0 0 0
8006 9 1 0 0 0 0 0
8007 10 0 0 0 0 0 0
Labels
ew_0 ew_1 ew_2 ew_3 ew_4 ew_5 ew_6 \
7998 56.672 13.192158 0.000000 9.241482 0.000000 0.000000 0.0000
7999 56.672 13.192158 12.217911 0.000000 0.000000 0.000000 0.0000
8000 56.672 28.408317 0.000000 0.000000 0.000000 0.000000 0.0000
8001 72.864 0.000000 0.000000 0.000000 0.000000 0.000000 1.9547
8002 72.864 0.000000 0.000000 0.000000 0.000000 3.858871 0.0000
8003 72.864 0.000000 0.000000 0.000000 6.380606 0.000000 0.0000
8004 72.864 0.000000 0.000000 9.241482 0.000000 0.000000 0.0000
8005 72.864 0.000000 12.217883 0.000000 0.000000 0.000000 0.0000
8006 72.864 15.216038 0.000000 0.000000 0.000000 0.000000 0.0000
8007 91.080 0.000000 0.000000 0.000000 0.000000 0.000000 0.0000
obj obj_rank
7998 79.105640 7964.5
7999 82.082070 7978.0
8000 85.080317 7987.5
8001 74.818700 7911.0
8002 76.722871 7945.5
8003 79.244606 7968.0
8004 82.105482 7982.5
8005 85.081883 7993.0
8006 88.080038 7998.5
8007 91.080000 8005.0
[([5, 1, 0, 0, 4, 0, 0], [0, 0, 0, 1, 3, 1, 5]), ([1, 1, 1, 5, 0, 1, 1], [0, 0, 0, 4, 2, 0, 4]), ([1, 5, 2, 1, 1, 0, 0], [2, 0, 1, 3, 4, 0, 0]), ([5, 0, 0, 0, 2, 2, 1], [1, 2, 5, 0, 1, 1, 0]), ([1, 3, 2, 2, 0, 0, 2], [1, 0, 3, 0, 3, 1, 2]), ([4, 2, 0, 2, 0, 0, 2], [0, 0, 2, 6, 1, 0, 1]), ([2, 2, 1, 0, 0, 5, 0], [1, 0, 0, 1, 0, 4, 4]), ([0, 2, 5, 0, 0, 2, 1], [0, 7, 0, 2, 0, 1, 0]), ([0, 4, 1, 3, 1, 1, 0], [0, 0, 3, 5, 2, 0, 0]), ([0, 4, 0, 1, 0, 2, 3], [0, 0, 0, 0, 7, 2, 1]), ([1, 0, 3, 0, 0, 6, 0], [1, 4, 0, 0, 1, 1, 3]), ([0, 5, 3, 0, 0, 0, 2], [1, 3, 1, 0, 0, 4, 1]), ([1, 4, 1, 1, 0, 3, 0], [0, 6, 1, 0, 0, 2, 1]), ([1, 0, 0, 4, 0, 5, 0], [0, 3, 4, 0, 0, 1, 2]), ([1, 1, 4, 0, 4, 0, 0], [0, 0, 1, 1, 0, 5, 3])]
[45.83469008087551, 31.925612362191632, 54.48618002021776, 32.93638214007037, 31.903947786090264, 37.57795844587057, 30.66425452134402, 42.27419483282949, 50.09276646866054, 30.186601175812214, 37.626163200000036, 54.50807016895987, 34.52406351149628, 45.021314621440034, 37.594862272023136]
[44.48115841679363, 50.77680848523785, 35.95409687035906, 42.260359046783755, 28.748981569010112, 58.61060969191143, 44.81995776000003, 67.28540959597044, 64.99914126947522, 79.08352442141943, 23.652114129588547, 28.693674219611005, 51.50006045332245, 40.97090533376224, 48.03287604838404]
[1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 1, 1, 0, 1, 0]
[[2, 2, 1, 3, 1, 0, 1], [0, 4, 2, 0, 1, 3, 0], [1, 0, 0, 0, 2, 7, 0], [3, 0, 0, 0, 2, 5, 0], [0, 6, 0, 0, 1, 0, 3], [0, 0, 0, 5, 4, 0, 1], [0, 0, 0, 0, 1, 1, 8], [3, 1, 2, 3, 0, 0, 1], [3, 0, 1, 1, 1, 4, 0], [1, 3, 1, 4, 1, 0, 0]]
[([2, 2, 1, 3, 1, 0, 1], [2, 2, 1, 3, 0, 0, 2]), ([0, 4, 2, 0, 1, 3, 0], [0, 5, 1, 0, 1, 3, 0]), ([1, 0, 0, 0, 2, 7, 0], [1, 0, 0, 0, 3, 6, 0]), ([3, 0, 0, 0, 2, 5, 0], [2, 0, 0, 1, 2, 5, 0]), ([0, 6, 0, 0, 1, 0, 3], [1, 6, 0, 0, 0, 0, 3]), ([0, 0, 0, 5, 4, 0, 1], [0, 0, 0, 6, 4, 0, 0]), ([0, 0, 0, 0, 1, 1, 8], [0, 0, 0, 0, 2, 0, 8]), ([3, 1, 2, 3, 0, 0, 1], [3, 1, 2, 2, 0, 0, 2]), ([3, 0, 1, 1, 1, 4, 0], [3, 0, 1, 1, 1, 3, 1]), ([1, 3, 1, 4, 1, 0, 0], [2, 3, 1, 4, 0, 0, 0])]
5963 29.935620
2689 42.942470
3029 54.867392
6312 35.795831
2882 43.738361
248 70.151718
12 60.141760
6793 35.779291
6450 24.214863
4681 41.492811
Name: obj, dtype: float64
Schedule: [0, 4, 2, 0, 1, 3, 0], ewt = 42.94246990549784
Flow chart of objective calculation method
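The imported service_time_with_no_shows is not shown in this notebook. A plausible minimal sketch of what it does (an assumption; the actual implementation in schedule_class may differ): with no-show probability q, a patient's service time is zero with probability q and follows the original distribution otherwise.

# Hypothetical sketch of the no-show adjustment; the real
# service_time_with_no_shows lives in schedule_class and may differ.
def service_time_with_no_shows_sketch(s, q):
    s_adj = [(1 - q) * p for p in s]  # scale the original distribution by the show probability
    s_adj[0] += q                     # a no-show contributes zero service time
    return s_adj

print(service_time_with_no_shows_sketch([0.0, 0.27, 0.28, 0.2, 0.15, 0.1], 0.2))
# [0.2, 0.216, 0.224, 0.16, 0.12, 0.08]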
Prepare data
# Prepare the dataset
X = []
for pair in sch_data.pair_list:
    X.append(pair[0] + pair[1])

X = np.array(X)
y = np.array(sch_data.lables_rank)

# Split the dataset into training and test sets
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
Train and evaluate model
def fit_and_score(estimator, X_train, X_test, y_train, y_test):
    """Fit the estimator on the train set and score it on both sets"""
    estimator.fit(X_train, y_train, eval_set=[(X_test, y_test)])
    train_score = estimator.score(X_train, y_train)
    test_score = estimator.score(X_test, y_test)
    return estimator, train_score, test_score

cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=94)

# Initialize the XGBClassifier with early stopping; the eval set is supplied in fit_and_score
clf = xgb.XGBClassifier(
    tree_method="hist",
    max_depth=6,
    min_child_weight=1,
    gamma=0.1,
    subsample=0.8,
    colsample_bytree=0.8,
    learning_rate=0.1,
    n_estimators=100,
    early_stopping_rounds=10
)

results = []

for train_idx, test_idx in cv.split(X, y):
    X_train, X_test = X[train_idx], X[test_idx]
    y_train, y_test = y[train_idx], y[test_idx]

    est, train_score, test_score = fit_and_score(
        clone(clf), X_train, X_test, y_train, y_test
    )
    results.append((est, train_score, test_score))

# Print results
for i, (est, train_score, test_score) in enumerate(results):
    print(f"Fold {i+1} - Train Score: {train_score:.4f}, Test Score: {test_score:.4f}")
[0] validation_0-logloss:0.66704
[1] validation_0-logloss:0.64197
[2] validation_0-logloss:0.62004
[3] validation_0-logloss:0.60129
[4] validation_0-logloss:0.58462
[5] validation_0-logloss:0.56863
[6] validation_0-logloss:0.55463
[7] validation_0-logloss:0.54278
[8] validation_0-logloss:0.53183
[9] validation_0-logloss:0.52202
[10] validation_0-logloss:0.51265
[11] validation_0-logloss:0.50330
[12] validation_0-logloss:0.49556
[13] validation_0-logloss:0.48886
[14] validation_0-logloss:0.48178
[15] validation_0-logloss:0.47352
[16] validation_0-logloss:0.46686
[17] validation_0-logloss:0.45823
[18] validation_0-logloss:0.45144
[19] validation_0-logloss:0.44638
[20] validation_0-logloss:0.44170
[21] validation_0-logloss:0.43606
[22] validation_0-logloss:0.43121
[23] validation_0-logloss:0.42626
[24] validation_0-logloss:0.42026
[25] validation_0-logloss:0.41398
[26] validation_0-logloss:0.40977
[27] validation_0-logloss:0.40508
[28] validation_0-logloss:0.40069
[29] validation_0-logloss:0.39565
[30] validation_0-logloss:0.39144
[31] validation_0-logloss:0.38700
[32] validation_0-logloss:0.38395
[33] validation_0-logloss:0.38027
[34] validation_0-logloss:0.37639
[35] validation_0-logloss:0.37305
[36] validation_0-logloss:0.36950
[37] validation_0-logloss:0.36542
[38] validation_0-logloss:0.36197
[39] validation_0-logloss:0.35907
[40] validation_0-logloss:0.35659
[41] validation_0-logloss:0.35246
[42] validation_0-logloss:0.35002
[43] validation_0-logloss:0.34677
[44] validation_0-logloss:0.34350
[45] validation_0-logloss:0.34061
[46] validation_0-logloss:0.33761
[47] validation_0-logloss:0.33511
[48] validation_0-logloss:0.33288
[49] validation_0-logloss:0.33082
[50] validation_0-logloss:0.32858
[51] validation_0-logloss:0.32658
[52] validation_0-logloss:0.32332
[53] validation_0-logloss:0.32046
[54] validation_0-logloss:0.31715
[55] validation_0-logloss:0.31561
[56] validation_0-logloss:0.31377
[57] validation_0-logloss:0.31075
[58] validation_0-logloss:0.30913
[59] validation_0-logloss:0.30729
[60] validation_0-logloss:0.30555
[61] validation_0-logloss:0.30361
[62] validation_0-logloss:0.30166
[63] validation_0-logloss:0.29909
[64] validation_0-logloss:0.29692
[65] validation_0-logloss:0.29539
[66] validation_0-logloss:0.29365
[67] validation_0-logloss:0.29163
[68] validation_0-logloss:0.28929
[69] validation_0-logloss:0.28798
[70] validation_0-logloss:0.28646
[71] validation_0-logloss:0.28538
[72] validation_0-logloss:0.28410
[73] validation_0-logloss:0.28237
[74] validation_0-logloss:0.28089
[75] validation_0-logloss:0.27960
[76] validation_0-logloss:0.27821
[77] validation_0-logloss:0.27676
[78] validation_0-logloss:0.27519
[79] validation_0-logloss:0.27396
[80] validation_0-logloss:0.27272
[81] validation_0-logloss:0.27075
[82] validation_0-logloss:0.26935
[83] validation_0-logloss:0.26826
[84] validation_0-logloss:0.26589
[85] validation_0-logloss:0.26479
[86] validation_0-logloss:0.26373
[87] validation_0-logloss:0.26270
[88] validation_0-logloss:0.26089
[89] validation_0-logloss:0.25961
[90] validation_0-logloss:0.25764
[91] validation_0-logloss:0.25653
[92] validation_0-logloss:0.25540
[93] validation_0-logloss:0.25509
[94] validation_0-logloss:0.25369
[95] validation_0-logloss:0.25280
[96] validation_0-logloss:0.25173
[97] validation_0-logloss:0.25119
[98] validation_0-logloss:0.24985
[99] validation_0-logloss:0.24888
[0] validation_0-logloss:0.66754
[1] validation_0-logloss:0.64366
[2] validation_0-logloss:0.62159
[3] validation_0-logloss:0.60301
[4] validation_0-logloss:0.58811
[5] validation_0-logloss:0.57304
[6] validation_0-logloss:0.55969
[7] validation_0-logloss:0.54666
[8] validation_0-logloss:0.53685
[9] validation_0-logloss:0.52678
[10] validation_0-logloss:0.51696
[11] validation_0-logloss:0.50823
[12] validation_0-logloss:0.49949
[13] validation_0-logloss:0.49091
[14] validation_0-logloss:0.48345
[15] validation_0-logloss:0.47603
[16] validation_0-logloss:0.46972
[17] validation_0-logloss:0.46359
[18] validation_0-logloss:0.45784
[19] validation_0-logloss:0.45178
[20] validation_0-logloss:0.44530
[21] validation_0-logloss:0.43673
[22] validation_0-logloss:0.43163
[23] validation_0-logloss:0.42648
[24] validation_0-logloss:0.42185
[25] validation_0-logloss:0.41540
[26] validation_0-logloss:0.41039
[27] validation_0-logloss:0.40655
[28] validation_0-logloss:0.40221
[29] validation_0-logloss:0.39804
[30] validation_0-logloss:0.39286
[31] validation_0-logloss:0.38754
[32] validation_0-logloss:0.38334
[33] validation_0-logloss:0.37859
[34] validation_0-logloss:0.37467
[35] validation_0-logloss:0.37123
[36] validation_0-logloss:0.36770
[37] validation_0-logloss:0.36266
[38] validation_0-logloss:0.35911
[39] validation_0-logloss:0.35462
[40] validation_0-logloss:0.35186
[41] validation_0-logloss:0.34792
[42] validation_0-logloss:0.34556
[43] validation_0-logloss:0.34182
[44] validation_0-logloss:0.33925
[45] validation_0-logloss:0.33623
[46] validation_0-logloss:0.33320
[47] validation_0-logloss:0.32991
[48] validation_0-logloss:0.32819
[49] validation_0-logloss:0.32630
[50] validation_0-logloss:0.32388
[51] validation_0-logloss:0.32128
[52] validation_0-logloss:0.31842
[53] validation_0-logloss:0.31606
[54] validation_0-logloss:0.31387
[55] validation_0-logloss:0.31268
[56] validation_0-logloss:0.30981
[57] validation_0-logloss:0.30799
[58] validation_0-logloss:0.30569
[59] validation_0-logloss:0.30370
[60] validation_0-logloss:0.30175
[61] validation_0-logloss:0.29983
[62] validation_0-logloss:0.29820
[63] validation_0-logloss:0.29610
[64] validation_0-logloss:0.29431
[65] validation_0-logloss:0.29239
[66] validation_0-logloss:0.29053
[67] validation_0-logloss:0.28896
[68] validation_0-logloss:0.28771
[69] validation_0-logloss:0.28484
[70] validation_0-logloss:0.28368
[71] validation_0-logloss:0.28234
[72] validation_0-logloss:0.28114
[73] validation_0-logloss:0.28019
[74] validation_0-logloss:0.27826
[75] validation_0-logloss:0.27630
[76] validation_0-logloss:0.27476
[77] validation_0-logloss:0.27332
[78] validation_0-logloss:0.27228
[79] validation_0-logloss:0.27049
[80] validation_0-logloss:0.26843
[81] validation_0-logloss:0.26696
[82] validation_0-logloss:0.26552
[83] validation_0-logloss:0.26416
[84] validation_0-logloss:0.26295
[85] validation_0-logloss:0.26173
[86] validation_0-logloss:0.26066
[87] validation_0-logloss:0.25959
[88] validation_0-logloss:0.25843
[89] validation_0-logloss:0.25704
[90] validation_0-logloss:0.25605
[91] validation_0-logloss:0.25484
[92] validation_0-logloss:0.25361
[93] validation_0-logloss:0.25264
[94] validation_0-logloss:0.25173
[95] validation_0-logloss:0.25067
[96] validation_0-logloss:0.24935
[97] validation_0-logloss:0.24842
[98] validation_0-logloss:0.24678
[99] validation_0-logloss:0.24569
[0] validation_0-logloss:0.66716
[1] validation_0-logloss:0.64197
[2] validation_0-logloss:0.61964
[3] validation_0-logloss:0.60098
[4] validation_0-logloss:0.58559
[5] validation_0-logloss:0.57006
[6] validation_0-logloss:0.55697
[7] validation_0-logloss:0.54281
[8] validation_0-logloss:0.53207
[9] validation_0-logloss:0.52178
[10] validation_0-logloss:0.51093
[11] validation_0-logloss:0.50144
[12] validation_0-logloss:0.49250
[13] validation_0-logloss:0.48505
[14] validation_0-logloss:0.47757
[15] validation_0-logloss:0.46979
[16] validation_0-logloss:0.46208
[17] validation_0-logloss:0.45400
[18] validation_0-logloss:0.44816
[19] validation_0-logloss:0.44282
[20] validation_0-logloss:0.43705
[21] validation_0-logloss:0.43032
[22] validation_0-logloss:0.42467
[23] validation_0-logloss:0.41924
[24] validation_0-logloss:0.41367
[25] validation_0-logloss:0.40847
[26] validation_0-logloss:0.40533
[27] validation_0-logloss:0.40131
[28] validation_0-logloss:0.39791
[29] validation_0-logloss:0.39193
[30] validation_0-logloss:0.38833
[31] validation_0-logloss:0.38346
[32] validation_0-logloss:0.37982
[33] validation_0-logloss:0.37538
[34] validation_0-logloss:0.37042
[35] validation_0-logloss:0.36736
[36] validation_0-logloss:0.36423
[37] validation_0-logloss:0.35971
[38] validation_0-logloss:0.35530
[39] validation_0-logloss:0.35203
[40] validation_0-logloss:0.34851
[41] validation_0-logloss:0.34415
[42] validation_0-logloss:0.34082
[43] validation_0-logloss:0.33693
[44] validation_0-logloss:0.33408
[45] validation_0-logloss:0.33168
[46] validation_0-logloss:0.32832
[47] validation_0-logloss:0.32464
[48] validation_0-logloss:0.32291
[49] validation_0-logloss:0.32065
[50] validation_0-logloss:0.31809
[51] validation_0-logloss:0.31611
[52] validation_0-logloss:0.31384
[53] validation_0-logloss:0.31161
[54] validation_0-logloss:0.30955
[55] validation_0-logloss:0.30677
[56] validation_0-logloss:0.30409
[57] validation_0-logloss:0.30121
[58] validation_0-logloss:0.29908
[59] validation_0-logloss:0.29649
[60] validation_0-logloss:0.29427
[61] validation_0-logloss:0.29233
[62] validation_0-logloss:0.29096
[63] validation_0-logloss:0.28932
[64] validation_0-logloss:0.28800
[65] validation_0-logloss:0.28596
[66] validation_0-logloss:0.28452
[67] validation_0-logloss:0.28181
[68] validation_0-logloss:0.28003
[69] validation_0-logloss:0.27823
[70] validation_0-logloss:0.27653
[71] validation_0-logloss:0.27482
[72] validation_0-logloss:0.27305
[73] validation_0-logloss:0.27103
[74] validation_0-logloss:0.26926
[75] validation_0-logloss:0.26845
[76] validation_0-logloss:0.26723
[77] validation_0-logloss:0.26573
[78] validation_0-logloss:0.26405
[79] validation_0-logloss:0.26299
[80] validation_0-logloss:0.26201
[81] validation_0-logloss:0.26009
[82] validation_0-logloss:0.25872
[83] validation_0-logloss:0.25748
[84] validation_0-logloss:0.25512
[85] validation_0-logloss:0.25418
[86] validation_0-logloss:0.25303
[87] validation_0-logloss:0.25181
[88] validation_0-logloss:0.25062
[89] validation_0-logloss:0.24941
[90] validation_0-logloss:0.24796
[91] validation_0-logloss:0.24648
[92] validation_0-logloss:0.24520
[93] validation_0-logloss:0.24342
[94] validation_0-logloss:0.24220
[95] validation_0-logloss:0.24141
[96] validation_0-logloss:0.24023
[97] validation_0-logloss:0.23914
[98] validation_0-logloss:0.23845
[99] validation_0-logloss:0.23752
[0] validation_0-logloss:0.66855
[1] validation_0-logloss:0.64549
[2] validation_0-logloss:0.62624
[3] validation_0-logloss:0.60934
[4] validation_0-logloss:0.59387
[5] validation_0-logloss:0.57887
[6] validation_0-logloss:0.56657
[7] validation_0-logloss:0.55343
[8] validation_0-logloss:0.54458
[9] validation_0-logloss:0.53453
[10] validation_0-logloss:0.52476
[11] validation_0-logloss:0.51582
[12] validation_0-logloss:0.50666
[13] validation_0-logloss:0.49987
[14] validation_0-logloss:0.49205
[15] validation_0-logloss:0.48464
[16] validation_0-logloss:0.47812
[17] validation_0-logloss:0.46834
[18] validation_0-logloss:0.46099
[19] validation_0-logloss:0.45513
[20] validation_0-logloss:0.45025
[21] validation_0-logloss:0.44236
[22] validation_0-logloss:0.43700
[23] validation_0-logloss:0.43229
[24] validation_0-logloss:0.42635
[25] validation_0-logloss:0.42189
[26] validation_0-logloss:0.41804
[27] validation_0-logloss:0.41100
[28] validation_0-logloss:0.40673
[29] validation_0-logloss:0.40305
[30] validation_0-logloss:0.39805
[31] validation_0-logloss:0.39435
[32] validation_0-logloss:0.39012
[33] validation_0-logloss:0.38675
[34] validation_0-logloss:0.38231
[35] validation_0-logloss:0.37796
[36] validation_0-logloss:0.37341
[37] validation_0-logloss:0.36886
[38] validation_0-logloss:0.36559
[39] validation_0-logloss:0.36265
[40] validation_0-logloss:0.35960
[41] validation_0-logloss:0.35668
[42] validation_0-logloss:0.35365
[43] validation_0-logloss:0.34970
[44] validation_0-logloss:0.34689
[45] validation_0-logloss:0.34424
[46] validation_0-logloss:0.34120
[47] validation_0-logloss:0.33856
[48] validation_0-logloss:0.33524
[49] validation_0-logloss:0.33194
[50] validation_0-logloss:0.32913
[51] validation_0-logloss:0.32666
[52] validation_0-logloss:0.32481
[53] validation_0-logloss:0.32262
[54] validation_0-logloss:0.31989
[55] validation_0-logloss:0.31722
[56] validation_0-logloss:0.31539
[57] validation_0-logloss:0.31337
[58] validation_0-logloss:0.31084
[59] validation_0-logloss:0.30813
[60] validation_0-logloss:0.30521
[61] validation_0-logloss:0.30331
[62] validation_0-logloss:0.30162
[63] validation_0-logloss:0.29953
[64] validation_0-logloss:0.29724
[65] validation_0-logloss:0.29560
[66] validation_0-logloss:0.29387
[67] validation_0-logloss:0.29233
[68] validation_0-logloss:0.28957
[69] validation_0-logloss:0.28768
[70] validation_0-logloss:0.28600
[71] validation_0-logloss:0.28472
[72] validation_0-logloss:0.28296
[73] validation_0-logloss:0.28075
[74] validation_0-logloss:0.27859
[75] validation_0-logloss:0.27731
[76] validation_0-logloss:0.27532
[77] validation_0-logloss:0.27388
[78] validation_0-logloss:0.27285
[79] validation_0-logloss:0.27144
[80] validation_0-logloss:0.27034
[81] validation_0-logloss:0.26838
[82] validation_0-logloss:0.26693
[83] validation_0-logloss:0.26556
[84] validation_0-logloss:0.26442
[85] validation_0-logloss:0.26327
[86] validation_0-logloss:0.26220
[87] validation_0-logloss:0.26138
[88] validation_0-logloss:0.26030
[89] validation_0-logloss:0.25889
[90] validation_0-logloss:0.25731
[91] validation_0-logloss:0.25634
[92] validation_0-logloss:0.25515
[93] validation_0-logloss:0.25355
[94] validation_0-logloss:0.25144
[95] validation_0-logloss:0.25061
[96] validation_0-logloss:0.24973
[97] validation_0-logloss:0.24850
[98] validation_0-logloss:0.24769
[99] validation_0-logloss:0.24648
[0] validation_0-logloss:0.66795
[1] validation_0-logloss:0.64307
[2] validation_0-logloss:0.62180
[3] validation_0-logloss:0.60377
[4] validation_0-logloss:0.58766
[5] validation_0-logloss:0.57282
[6] validation_0-logloss:0.55964
[7] validation_0-logloss:0.54693
[8] validation_0-logloss:0.53543
[9] validation_0-logloss:0.52383
[10] validation_0-logloss:0.51443
[11] validation_0-logloss:0.50582
[12] validation_0-logloss:0.49701
[13] validation_0-logloss:0.48842
[14] validation_0-logloss:0.48054
[15] validation_0-logloss:0.47320
[16] validation_0-logloss:0.46709
[17] validation_0-logloss:0.46060
[18] validation_0-logloss:0.45414
[19] validation_0-logloss:0.44762
[20] validation_0-logloss:0.44270
[21] validation_0-logloss:0.43368
[22] validation_0-logloss:0.42713
[23] validation_0-logloss:0.42058
[24] validation_0-logloss:0.41481
[25] validation_0-logloss:0.41125
[26] validation_0-logloss:0.40636
[27] validation_0-logloss:0.40197
[28] validation_0-logloss:0.39744
[29] validation_0-logloss:0.39349
[30] validation_0-logloss:0.38787
[31] validation_0-logloss:0.38425
[32] validation_0-logloss:0.37907
[33] validation_0-logloss:0.37483
[34] validation_0-logloss:0.37053
[35] validation_0-logloss:0.36611
[36] validation_0-logloss:0.36280
[37] validation_0-logloss:0.35829
[38] validation_0-logloss:0.35528
[39] validation_0-logloss:0.35034
[40] validation_0-logloss:0.34683
[41] validation_0-logloss:0.34295
[42] validation_0-logloss:0.33934
[43] validation_0-logloss:0.33620
[44] validation_0-logloss:0.33224
[45] validation_0-logloss:0.32995
[46] validation_0-logloss:0.32745
[47] validation_0-logloss:0.32488
[48] validation_0-logloss:0.32174
[49] validation_0-logloss:0.31918
[50] validation_0-logloss:0.31699
[51] validation_0-logloss:0.31453
[52] validation_0-logloss:0.31189
[53] validation_0-logloss:0.30938
[54] validation_0-logloss:0.30604
[55] validation_0-logloss:0.30330
[56] validation_0-logloss:0.30019
[57] validation_0-logloss:0.29851
[58] validation_0-logloss:0.29711
[59] validation_0-logloss:0.29530
[60] validation_0-logloss:0.29295
[61] validation_0-logloss:0.29074
[62] validation_0-logloss:0.28918
[63] validation_0-logloss:0.28726
[64] validation_0-logloss:0.28528
[65] validation_0-logloss:0.28291
[66] validation_0-logloss:0.28100
[67] validation_0-logloss:0.27956
[68] validation_0-logloss:0.27750
[69] validation_0-logloss:0.27565
[70] validation_0-logloss:0.27258
[71] validation_0-logloss:0.27093
[72] validation_0-logloss:0.26875
[73] validation_0-logloss:0.26720
[74] validation_0-logloss:0.26488
[75] validation_0-logloss:0.26316
[76] validation_0-logloss:0.26194
[77] validation_0-logloss:0.26030
[78] validation_0-logloss:0.25878
[79] validation_0-logloss:0.25753
[80] validation_0-logloss:0.25643
[81] validation_0-logloss:0.25464
[82] validation_0-logloss:0.25307
[83] validation_0-logloss:0.25178
[84] validation_0-logloss:0.25069
[85] validation_0-logloss:0.24891
[86] validation_0-logloss:0.24698
[87] validation_0-logloss:0.24570
[88] validation_0-logloss:0.24463
[89] validation_0-logloss:0.24359
[90] validation_0-logloss:0.24207
[91] validation_0-logloss:0.24063
[92] validation_0-logloss:0.23935
[93] validation_0-logloss:0.23778
[94] validation_0-logloss:0.23666
[95] validation_0-logloss:0.23569
[96] validation_0-logloss:0.23482
[97] validation_0-logloss:0.23359
[98] validation_0-logloss:0.23215
[99] validation_0-logloss:0.23141
Fold 1 - Train Score: 0.9662, Test Score: 0.9116
Fold 2 - Train Score: 0.9666, Test Score: 0.9175
Fold 3 - Train Score: 0.9677, Test Score: 0.9187
Fold 4 - Train Score: 0.9670, Test Score: 0.9113
Fold 5 - Train Score: 0.9663, Test Score: 0.9291
Test
The model seems to have trouble consistently ranking schedules that are very similar. This could indicate overfitting. Possible solutions: build training sets from close pairs, or add close pairs to the existing training set, as sketched below.
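One possible augmentation, reusing create_neighbors_list (a sketch, not applied below; it assumes calculate_objective is changed to return ewt instead of only printing it):

# Sketch: extend the training data with neighbor (close) pairs.
# Assumes calculate_objective returns ewt rather than just printing it.
s_dist, d, q = [0.0, 0.27, 0.28, 0.2, 0.15, 0.1], 3, 0.2

sch_data.create_neighbors_list(1000)
X_close, y_close = [], []
for s1, s2 in sch_data.neighbors_list:
    obj1 = sch_data.calculate_objective(s1, s_dist, d, q)
    obj2 = sch_data.calculate_objective(s2, s_dist, d, q)
    X_close.append(s1 + s2)
    y_close.append(1 * (obj1 > obj2))

X_aug = np.vstack([X, np.array(X_close)])
y_aug = np.concatenate([y, np.array(y_close)])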
# Fit the model on the entire dataset
# Initialize the XGBClassifier without early stopping here
clf = xgb.XGBClassifier(
    tree_method="hist",
    max_depth=6,
    min_child_weight=1,
    gamma=0.1,
    subsample=0.8,
    colsample_bytree=0.8,
    learning_rate=0.1,
    n_estimators=100
)
clf.fit(X, y)

input_X = [
    ([9, 1, 0, 0, 0, 0, 0], [8, 2, 0, 0, 0, 0, 0]),
    ([8, 2, 0, 0, 0, 0, 0], [9, 1, 0, 0, 0, 0, 0]),
    ([3, 1, 1, 1, 0, 1, 2], [2, 1, 1, 1, 0, 1, 3]),
    ([2, 1, 1, 1, 0, 1, 3], [3, 1, 1, 1, 0, 1, 2]),
    ([3, 1, 1, 1, 0, 1, 2], [9, 1, 0, 0, 0, 0, 0]),
    ([8, 2, 0, 0, 0, 0, 0], [3, 1, 1, 1, 0, 1, 2])
]

X_new = []
for pair in input_X:
    X_new.append(pair[0] + pair[1])

# Predict the target for new data
y_pred = clf.predict(X_new)

# If you want to get the probability estimates
y_pred_proba = clf.predict_proba(X_new)

print(f"y_pred = {y_pred}, \ny_pred_proba = \n{y_pred_proba}")
y_pred = [0 0 1 0 0 1],
y_pred_proba =
[[0.76242477 0.23757525]
[0.69124234 0.30875763]
[0.3242296 0.6757704 ]
[0.5743403 0.42565972]
[0.99786097 0.00213903]
[0.00121516 0.99878484]]
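Pairs 1-2 and 3-4 in input_X are the same schedules in both orders, yet the predicted probabilities are not complementary (P(class 1) = 0.238 for pair 1 vs. P(class 0) = 0.691 for its mirror), which illustrates the inconsistency on close pairs noted above. A quick antisymmetry check (sketch):

# Sketch: check ranking antisymmetry. For a pair (a, b) and its mirror (b, a),
# a consistent ranker should give P(label=1 | a, b) close to P(label=0 | b, a).
def rank_consistency(model, a, b):
    p_ab = model.predict_proba(np.array([a + b]))[0, 1]  # P(obj(a) > obj(b))
    p_ba = model.predict_proba(np.array([b + a]))[0, 0]  # P(obj(b) <= obj(a))
    return p_ab, p_ba

a = [9, 1, 0, 0, 0, 0, 0]
b = [8, 2, 0, 0, 0, 0, 0]
p_ab, p_ba = rank_consistency(clf, a, b)
print(f"P(a worse | a,b) = {p_ab:.3f} vs. mirrored {p_ba:.3f}")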