Skip to content

Commit fffdad2

Browse files
authored
Fixed #124 (#125)
* Fixed #124
  Signed-off-by: Vivek Joshy <8206808+vivekjoshy@users.noreply.github.com>
* Update dependencies
  Signed-off-by: Vivek Joshy <8206808+vivekjoshy@users.noreply.github.com>
* Format with black
  Signed-off-by: Vivek Joshy <8206808+vivekjoshy@users.noreply.github.com>
* Add changelog fragment
  Signed-off-by: Vivek Joshy <8206808+vivekjoshy@users.noreply.github.com>
---------
Signed-off-by: Vivek Joshy <8206808+vivekjoshy@users.noreply.github.com>
1 parent f76df19 commit fffdad2

29 files changed

+664
-700
lines changed

benchmark/draw.ipynb

Lines changed: 33 additions & 60 deletions
Original file line number | Diff line number | Diff line change
@@ -341,9 +341,9 @@
341341
"data_directory = Path(working_directory / Path(\"data\"))\n",
342342
"data_directory.mkdir(exist_ok=True)\n",
343343
"downloader(\n",
344-
" url=\"doi:10.5281/zenodo.10344773/chess.csv\", \n",
344+
" url=\"doi:10.5281/zenodo.10344773/chess.csv\",\n",
345345
" output_file=data_directory / \"chess.csv\",\n",
346-
" pooch=None\n",
346+
" pooch=None,\n",
347347
")"
348348
]
349349
},
@@ -374,6 +374,7 @@
374374
" BLACK_WINS = 2\n",
375375
" STALEMATE = 3\n",
376376
"\n",
377+
"\n",
377378
"@dataclass(slots=True)\n",
378379
"class Player:\n",
379380
" name: str\n",
@@ -452,8 +453,8 @@
452453
" options=[m.__name__ for m in models],\n",
453454
" value=PlackettLuce.__name__,\n",
454455
" # rows=10,\n",
455-
" description='Model:',\n",
456-
" disabled=False\n",
456+
" description=\"Model:\",\n",
457+
" disabled=False,\n",
457458
")\n",
458459
"display(widget)"
459460
]
@@ -693,31 +694,19 @@
693694
"\n",
694695
"for match_index, row in train.iterrows():\n",
695696
" white_player = Player(name=row[\"white_username\"])\n",
696-
" black_player = Player(name=row[\"black_username\"]) \n",
697-
" players = {\n",
698-
" row[\"white_username\"]: white_player,\n",
699-
" row[\"black_username\"]: black_player\n",
700-
" }\n",
701-
" \n",
697+
" black_player = Player(name=row[\"black_username\"])\n",
698+
" players = {row[\"white_username\"]: white_player, row[\"black_username\"]: black_player}\n",
699+
"\n",
702700
" white_result = row[\"white_result\"]\n",
703701
" black_result = row[\"black_result\"]\n",
704-
" \n",
702+
"\n",
705703
" if white_result == \"win\":\n",
706-
" match = Match(\n",
707-
" result=Result.WHITE_WINS,\n",
708-
" players=players\n",
709-
" )\n",
704+
" match = Match(result=Result.WHITE_WINS, players=players)\n",
710705
" elif black_result == \"win\":\n",
711-
" match = Match(\n",
712-
" result=Result.BLACK_WINS,\n",
713-
" players=players\n",
714-
" )\n",
706+
" match = Match(result=Result.BLACK_WINS, players=players)\n",
715707
" else:\n",
716-
" match = Match(\n",
717-
" result=Result.STALEMATE,\n",
718-
" players=players\n",
719-
" )\n",
720-
" \n",
708+
" match = Match(result=Result.STALEMATE, players=players)\n",
709+
"\n",
721710
" train_matches.append(match)\n",
722711
" t.update(1)\n",
723712
"\n",
@@ -830,20 +819,20 @@
830819
" player_2_rating = openskill_players[player_2]\n",
831820
" team_1 = [player_1_rating]\n",
832821
" team_2 = [player_2_rating]\n",
833-
" \n",
822+
"\n",
834823
" if match.result == Result.WHITE_WINS:\n",
835824
" ranks = [1, 2]\n",
836825
" elif match.result == Result.BLACK_WINS:\n",
837826
" ranks = [2, 1]\n",
838827
" else:\n",
839828
" ranks = [1, 1]\n",
840-
" \n",
829+
"\n",
841830
" rated_teams = model.rate(teams=[team_1, team_2], ranks=ranks)\n",
842831
"\n",
843832
" for team in rated_teams:\n",
844833
" for player in team:\n",
845834
" openskill_players[player.name] = player\n",
846-
" \n",
835+
"\n",
847836
" t.update(1)\n",
848837
"\n",
849838
"_ = gc.collect()"
@@ -970,32 +959,20 @@
970959
"# Test Data\n",
971960
"test_matches: List[Match] = []\n",
972961
"\n",
973-
"for match_index, row in test.iterrows():\n",
962+
"for match_index, row in test.iterrows():\n",
974963
" white_player = Player(name=row[\"white_username\"])\n",
975-
" black_player = Player(name=row[\"black_username\"]) \n",
976-
" players = {\n",
977-
" row[\"white_username\"]: white_player,\n",
978-
" row[\"black_username\"]: black_player\n",
979-
" }\n",
980-
" \n",
964+
" black_player = Player(name=row[\"black_username\"])\n",
965+
" players = {row[\"white_username\"]: white_player, row[\"black_username\"]: black_player}\n",
966+
"\n",
981967
" white_result = row[\"white_result\"]\n",
982968
" black_result = row[\"black_result\"]\n",
983-
" \n",
969+
"\n",
984970
" if white_result == \"win\":\n",
985-
" match = Match(\n",
986-
" result=Result.WHITE_WINS,\n",
987-
" players=players\n",
988-
" )\n",
971+
" match = Match(result=Result.WHITE_WINS, players=players)\n",
989972
" elif black_result == \"win\":\n",
990-
" match = Match(\n",
991-
" result=Result.BLACK_WINS,\n",
992-
" players=players\n",
993-
" )\n",
973+
" match = Match(result=Result.BLACK_WINS, players=players)\n",
994974
" else:\n",
995-
" match = Match(\n",
996-
" result=Result.STALEMATE,\n",
997-
" players=players\n",
998-
" )\n",
975+
" match = Match(result=Result.STALEMATE, players=players)\n",
999976
"\n",
1000977
" test_matches.append(match)\n",
1001978
" t.update(1)"
@@ -1063,7 +1040,6 @@
10631040
}
10641041
],
10651042
"source": [
1066-
"\n",
10671043
"# Predict OpenSkill Matches\n",
10681044
"print(\"Predict Matches in Test Set using OpenSkill:\")\n",
10691045
"t = tqdm(total=len(test_matches))\n",
@@ -1078,24 +1054,24 @@
10781054
" draw = True\n",
10791055
" else:\n",
10801056
" draw = False\n",
1081-
" \n",
1057+
"\n",
10821058
" player_1, player_2 = match.players.keys()\n",
1083-
" \n",
1059+
"\n",
10841060
" if player_1 in openskill_players:\n",
10851061
" player_1_rating = openskill_players[player_1]\n",
10861062
" else:\n",
10871063
" player_1_rating = model.rating(name=player_1)\n",
1088-
" \n",
1064+
"\n",
10891065
" if player_2 in openskill_players:\n",
10901066
" player_2_rating = openskill_players[player_2]\n",
10911067
" else:\n",
10921068
" player_2_rating = model.rating(name=player_2)\n",
1093-
" \n",
1069+
"\n",
10941070
" teams = [[player_1_rating], [player_2_rating]]\n",
1095-
" \n",
1071+
"\n",
10961072
" white_win_probability, black_win_probability = model.predict_win(teams)\n",
10971073
" draw_probability = model.predict_draw(teams)\n",
1098-
" \n",
1074+
"\n",
10991075
" if draw_probability > (white_win_probability + black_win_probability):\n",
11001076
" if draw:\n",
11011077
" openskill_correct_predictions += 1\n",
@@ -1106,7 +1082,7 @@
11061082
" openskill_correct_predictions += 1\n",
11071083
" else:\n",
11081084
" openskill_incorrect_predictions += 1\n",
1109-
" \n",
1085+
"\n",
11101086
" t.update(1)"
11111087
]
11121088
},
@@ -1184,11 +1160,8 @@
11841160
"\n",
11851161
"openskill_accuracy = round(\n",
11861162
" (\n",
1187-
" openskill_correct_predictions\n",
1188-
" / (\n",
1189-
" openskill_incorrect_predictions\n",
1190-
" + openskill_correct_predictions\n",
1191-
" )\n",
1163+
" openskill_correct_predictions\n",
1164+
" / (openskill_incorrect_predictions + openskill_correct_predictions)\n",
11921165
" )\n",
11931166
" * 100,\n",
11941167
" 2,\n",

benchmark/rank.ipynb

Lines changed: 28 additions & 36 deletions
Original file line number | Diff line number | Diff line change
@@ -334,14 +334,14 @@
334334
"data_directory = Path(working_directory / Path(\"data\"))\n",
335335
"data_directory.mkdir(exist_ok=True)\n",
336336
"downloader(\n",
337-
" url=\"doi:10.5281/zenodo.10342317/train.parquet\", \n",
337+
" url=\"doi:10.5281/zenodo.10342317/train.parquet\",\n",
338338
" output_file=data_directory / \"train.parquet\",\n",
339-
" pooch=None\n",
339+
" pooch=None,\n",
340340
")\n",
341341
"downloader(\n",
342-
" url=\"doi:10.5281/zenodo.10342317/test.parquet\", \n",
343-
" output_file=data_directory / \"test.parquet\", \n",
344-
" pooch=None\n",
342+
" url=\"doi:10.5281/zenodo.10342317/test.parquet\",\n",
343+
" output_file=data_directory / \"test.parquet\",\n",
344+
" pooch=None,\n",
345345
")"
346346
]
347347
},
@@ -473,8 +473,8 @@
473473
" options=[m.__name__ for m in models],\n",
474474
" value=PlackettLuce.__name__,\n",
475475
" # rows=10,\n",
476-
" description='Model:',\n",
477-
" disabled=False\n",
476+
" description=\"Model:\",\n",
477+
" disabled=False,\n",
478478
")\n",
479479
"display(widget)"
480480
]
@@ -634,14 +634,14 @@
634634
],
635635
"source": [
636636
"def reduce_memory_usage_pl(df, name):\n",
637-
" \"\"\" \n",
637+
" \"\"\"\n",
638638
" Reduce memory usage by polars dataframe {df} with name {name} by changing its data types.\n",
639-
" Original pandas version of this function: \n",
639+
" Original pandas version of this function:\n",
640640
" https://www.kaggle.com/code/arjanso/reducing-dataframe-memory-size-by-65\n",
641641
" \"\"\"\n",
642642
" print(f\"Memory usage of dataframe {name} is {round(df.estimated_size('mb'), 2)} MB\")\n",
643-
" Numeric_Int_types = [pl.Int8,pl.Int16,pl.Int32,pl.Int64]\n",
644-
" Numeric_Float_types = [pl.Float32,pl.Float64] \n",
643+
" Numeric_Int_types = [pl.Int8, pl.Int16, pl.Int32, pl.Int64]\n",
644+
" Numeric_Float_types = [pl.Float32, pl.Float64]\n",
645645
" for col in df.columns:\n",
646646
" col_type = df[col].dtype\n",
647647
" c_min = df[col].min()\n",
@@ -664,7 +664,9 @@
664664
" df = df.with_columns(df[col].cast(pl.Categorical))\n",
665665
" else:\n",
666666
" pass\n",
667-
" print(f\"Memory usage of dataframe {name} became {round(df.estimated_size('mb'), 2)} MB\")\n",
667+
" print(\n",
668+
" f\"Memory usage of dataframe {name} became {round(df.estimated_size('mb'), 2)} MB\"\n",
669+
" )\n",
668670
" return df\n",
669671
"\n",
670672
"\n",
@@ -814,31 +816,28 @@
814816
" player = Player(\n",
815817
" name=raw_player[\"player_name\"],\n",
816818
" kill_ratio=raw_player[\"kill_ratio\"],\n",
817-
" assist_ratio=raw_player[\"assist_ratio\"]\n",
819+
" assist_ratio=raw_player[\"assist_ratio\"],\n",
818820
" )\n",
819821
"\n",
820-
" match_id = raw_player['match_id']\n",
821-
" team_id = raw_player['team_id']\n",
822+
" match_id = raw_player[\"match_id\"]\n",
823+
" team_id = raw_player[\"team_id\"]\n",
822824
" if match_id not in train_matches:\n",
823825
" team = Team(\n",
824826
" id=raw_player[\"team_id\"],\n",
825827
" match_id=raw_player[\"match_id\"],\n",
826828
" rank=raw_player[\"team_placement\"],\n",
827-
" players={player.name: player}\n",
829+
" players={player.name: player},\n",
828830
" )\n",
829831
"\n",
830-
" match = Match(\n",
831-
" id=raw_player[\"match_id\"],\n",
832-
" teams={team.id: team}\n",
833-
" )\n",
832+
" match = Match(id=raw_player[\"match_id\"], teams={team.id: team})\n",
834833
" else:\n",
835834
" if team_id not in train_matches[match_id].teams:\n",
836835
" match = train_matches[match_id]\n",
837836
" team = Team(\n",
838837
" id=raw_player[\"team_id\"],\n",
839838
" match_id=raw_player[\"match_id\"],\n",
840839
" rank=raw_player[\"team_placement\"],\n",
841-
" players={player.name: player}\n",
840+
" players={player.name: player},\n",
842841
" )\n",
843842
" match.teams[team_id] = team\n",
844843
" else:\n",
@@ -1131,31 +1130,28 @@
11311130
" player = Player(\n",
11321131
" name=raw_player[\"player_name\"],\n",
11331132
" kill_ratio=raw_player[\"kill_ratio\"],\n",
1134-
" assist_ratio=raw_player[\"assist_ratio\"]\n",
1133+
" assist_ratio=raw_player[\"assist_ratio\"],\n",
11351134
" )\n",
11361135
"\n",
1137-
" match_id = raw_player['match_id']\n",
1138-
" team_id = raw_player['team_id']\n",
1136+
" match_id = raw_player[\"match_id\"]\n",
1137+
" team_id = raw_player[\"team_id\"]\n",
11391138
" if match_id not in test_matches:\n",
11401139
" team = Team(\n",
11411140
" id=raw_player[\"team_id\"],\n",
11421141
" match_id=raw_player[\"match_id\"],\n",
11431142
" rank=raw_player[\"team_placement\"],\n",
1144-
" players={player.name: player}\n",
1143+
" players={player.name: player},\n",
11451144
" )\n",
11461145
"\n",
1147-
" match = Match(\n",
1148-
" id=raw_player[\"match_id\"],\n",
1149-
" teams={team.id: team}\n",
1150-
" )\n",
1146+
" match = Match(id=raw_player[\"match_id\"], teams={team.id: team})\n",
11511147
" else:\n",
11521148
" if team_id not in test_matches[match_id].teams:\n",
11531149
" match = test_matches[match_id]\n",
11541150
" team = Team(\n",
11551151
" id=raw_player[\"team_id\"],\n",
11561152
" match_id=raw_player[\"match_id\"],\n",
11571153
" rank=raw_player[\"team_placement\"],\n",
1158-
" players={player.name: player}\n",
1154+
" players={player.name: player},\n",
11591155
" )\n",
11601156
" match.teams[team_id] = team\n",
11611157
" else:\n",
@@ -1228,7 +1224,6 @@
12281224
}
12291225
],
12301226
"source": [
1231-
"\n",
12321227
"# Predict OpenSkill Matches\n",
12331228
"print(\"Predict Matches in Test Set using OpenSkill:\")\n",
12341229
"t = tqdm(total=len(test_matches))\n",
@@ -1358,11 +1353,8 @@
13581353
"\n",
13591354
"openskill_accuracy = round(\n",
13601355
" (\n",
1361-
" openskill_correct_predictions\n",
1362-
" / (\n",
1363-
" openskill_incorrect_predictions\n",
1364-
" + openskill_correct_predictions\n",
1365-
" )\n",
1356+
" openskill_correct_predictions\n",
1357+
" / (openskill_incorrect_predictions + openskill_correct_predictions)\n",
13661358
" )\n",
13671359
" * 100,\n",
13681360
" 2,\n",

0 commit comments

Comments
 (0)