diff --git a/scripts/data_collector/utils.py b/scripts/data_collector/utils.py index f25b1ec7a2..eeeb207968 100644 --- a/scripts/data_collector/utils.py +++ b/scripts/data_collector/utils.py @@ -301,17 +301,56 @@ def get_us_stock_symbols(qlib_data_path: [str, Path] = None) -> list: @deco_retry def _get_eastmoney(): - url = "http://4.push2.eastmoney.com/api/qt/clist/get?pn=1&pz=10000&fs=m:105,m:106,m:107&fields=f12" - resp = requests.get(url, timeout=None) - if resp.status_code != 200: - raise ValueError("request error") + base_url = "http://4.push2.eastmoney.com/api/qt/clist/get" + params = { + "pn": 1, # page number + "pz": 100, # page size, default to 100 + "fs": "m:105,m:106,m:107", + "fields": "f12", + } - try: - _symbols = [_v["f12"].replace("_", "-P") for _v in resp.json()["data"]["diff"].values()] - except Exception as e: - logger.warning(f"request error: {e}") - raise + _symbols = [] + page = 1 + while True: + params["pn"] = page + try: + resp = requests.get(base_url, params=params, timeout=None) + resp.raise_for_status() + data = resp.json() + + # Check if response contains valid data + if not data or "data" not in data or not data["data"] or "diff" not in data["data"]: + logger.warning(f"Invalid response structure on page {page}") + break + + # fetch the current page data + current_symbols = [_v["f12"] for _v in data["data"]["diff"].values()] + + if not current_symbols: # It's the last page if there is no data in current page + logger.info(f"Last page reached: {page - 1}") + break + + _symbols.extend(current_symbols) + + # show progress + logger.info( + f"Page {page}: fetch {len(current_symbols)} stocks:[{current_symbols[0]} ... {current_symbols[-1]}]" + ) + + page += 1 + + # sleep time to avoid overloading the server + time.sleep(0.5) + + except requests.exceptions.HTTPError as e: + raise requests.exceptions.HTTPError( + f"Request to {base_url} failed with status code {resp.status_code}" + ) from e + except Exception as e: + logger.warning("An error occurred while extracting data from the response.") + raise + # If the number of symbols is less than the minimum required, raise an error if len(_symbols) < 8000: raise ValueError("request error") @@ -321,7 +360,7 @@ def _get_eastmoney(): def _get_nasdaq(): _res_symbols = [] for _name in ["otherlisted", "nasdaqtraded"]: - url = f"ftp://ftp.nasdaqtrader.com/SymbolDirectory/{_name}.txt" + url = f"https://www.nasdaqtrader.com/dynamic/SymDir/{_name}.txt" df = pd.read_csv(url, sep="|") df = df.rename(columns={"ACT Symbol": "Symbol"}) _symbols = df["Symbol"].dropna()