diff --git a/CHANGELOG.md b/CHANGELOG.md index b7755a76..a455f231 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,9 +1,8 @@ +## [1.32.0-beta.5](https://github.com/ScrapeGraphAI/Scrapegraph-ai/compare/v1.32.0-beta.4...v1.32.0-beta.5) (2024-12-02) ## [1.32.0](https://github.com/ScrapeGraphAI/Scrapegraph-ai/compare/v1.31.1...v1.32.0) (2024-12-02) -### Features -* add API integration ([46373af](https://github.com/ScrapeGraphAI/Scrapegraph-ai/commit/46373afe6d8c05ad26039e68190f13d82b20a349)) ## [1.32.0-beta.4](https://github.com/ScrapeGraphAI/Scrapegraph-ai/compare/v1.32.0-beta.3...v1.32.0-beta.4) (2024-12-02) diff --git a/examples/extras/authenticated_playwright.py b/examples/extras/authenticated_playwright.py new file mode 100644 index 00000000..a4926bc7 --- /dev/null +++ b/examples/extras/authenticated_playwright.py @@ -0,0 +1,93 @@ +""" +Example leveraging a state file containing session cookies which +might be leveraged to authenticate to a website and scrape protected +content. +""" + +import os +import random +from dotenv import load_dotenv + +# import playwright so we can use it to create the state file +from playwright.async_api import async_playwright + +from scrapegraphai.graphs import OmniScraperGraph +from scrapegraphai.utils import prettify_exec_info + +load_dotenv() + +# ************************************************ +# Leveraging Playwright external to the invocation of the graph to +# login and create the state file +# ************************************************ + + +# note this is just an example and probably won't actually work on +# LinkedIn, the implementation of the login is highly dependent on the website +async def do_login(): + async with async_playwright() as playwright: + browser = await playwright.chromium.launch( + timeout=30000, + headless=False, + slow_mo=random.uniform(500, 1500), + ) + page = await browser.new_page() + + # very basic implementation of a login, in reality it may be trickier + await page.goto("https://www.linkedin.com/login") + await page.get_by_label("Email or phone").fill("some_bloke@some_domain.com") + await page.get_by_label("Password").fill("test1234") + await page.get_by_role("button", name="Sign in").click() + await page.wait_for_timeout(3000) + + # assuming a successful login, we save the cookies to a file + await page.context.storage_state(path="./state.json") + + +async def main(): + await do_login() + + # ************************************************ + # Define the configuration for the graph + # ************************************************ + + openai_api_key = os.getenv("OPENAI_APIKEY") + + graph_config = { + "llm": { + "api_key": openai_api_key, + "model": "openai/gpt-4o", + }, + "max_images": 10, + "headless": False, + # provide the path to the state file + "storage_state": "./state.json", + } + + # ************************************************ + # Create the OmniScraperGraph instance and run it + # ************************************************ + + omni_scraper_graph = OmniScraperGraph( + prompt="List me all the projects with their description.", + source="https://www.linkedin.com/feed/", + config=graph_config, + ) + + # the storage_state is used to load the cookies from the state file + # so we are authenticated and able to scrape protected content + result = omni_scraper_graph.run() + print(result) + + # ************************************************ + # Get graph execution info + # ************************************************ + + graph_exec_info = omni_scraper_graph.get_execution_info() + print(prettify_exec_info(graph_exec_info)) + + +if __name__ == "__main__": + import asyncio + + asyncio.run(main()) diff --git a/examples/extras/undected_playwrigth.py b/examples/extras/undected_playwright.py similarity index 100% rename from examples/extras/undected_playwrigth.py rename to examples/extras/undected_playwright.py diff --git a/pyproject.toml b/pyproject.toml index 0034235d..1cd4a7b5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -3,7 +3,8 @@ name = "scrapegraphai" -version = "1.32.0" +version = "1.32.0b5" + diff --git a/scrapegraphai/docloaders/chromium.py b/scrapegraphai/docloaders/chromium.py index 3cc49e7f..255fc3c2 100644 --- a/scrapegraphai/docloaders/chromium.py +++ b/scrapegraphai/docloaders/chromium.py @@ -8,6 +8,7 @@ logger = get_logger("web-loader") + class ChromiumLoader(BaseLoader): """Scrapes HTML pages from URLs using a (headless) instance of the Chromium web driver with proxy protection. @@ -33,6 +34,7 @@ def __init__( proxy: Optional[Proxy] = None, load_state: str = "domcontentloaded", requires_js_support: bool = False, + storage_state: Optional[str] = None, **kwargs: Any, ): """Initialize the loader with a list of URL paths. @@ -62,6 +64,7 @@ def __init__( self.urls = urls self.load_state = load_state self.requires_js_support = requires_js_support + self.storage_state = storage_state async def ascrape_undetected_chromedriver(self, url: str) -> str: """ @@ -91,7 +94,9 @@ async def ascrape_undetected_chromedriver(self, url: str) -> str: attempt += 1 logger.error(f"Attempt {attempt} failed: {e}") if attempt == self.RETRY_LIMIT: - results = f"Error: Network error after {self.RETRY_LIMIT} attempts - {e}" + results = ( + f"Error: Network error after {self.RETRY_LIMIT} attempts - {e}" + ) finally: driver.quit() @@ -113,7 +118,9 @@ async def ascrape_playwright(self, url: str) -> str: browser = await p.chromium.launch( headless=self.headless, proxy=self.proxy, **self.browser_config ) - context = await browser.new_context() + context = await browser.new_context( + storage_state=self.storage_state + ) await Malenia.apply_stealth(context) page = await context.new_page() await page.goto(url, wait_until="domcontentloaded") @@ -125,10 +132,12 @@ async def ascrape_playwright(self, url: str) -> str: attempt += 1 logger.error(f"Attempt {attempt} failed: {e}") if attempt == self.RETRY_LIMIT: - raise RuntimeError(f"Failed to fetch {url} after {self.RETRY_LIMIT} attempts: {e}") + raise RuntimeError( + f"Failed to fetch {url} after {self.RETRY_LIMIT} attempts: {e}" + ) finally: - if 'browser' in locals(): - await browser.close() + if "browser" in locals(): + async def ascrape_with_js_support(self, url: str) -> str: """ @@ -138,7 +147,7 @@ async def ascrape_with_js_support(self, url: str) -> str: url (str): The URL to scrape. Returns: - str: The fully rendered HTML content after JavaScript execution, + str: The fully rendered HTML content after JavaScript execution, or an error message if an exception occurs. """ from playwright.async_api import async_playwright @@ -153,7 +162,9 @@ async def ascrape_with_js_support(self, url: str) -> str: browser = await p.chromium.launch( headless=self.headless, proxy=self.proxy, **self.browser_config ) - context = await browser.new_context() + context = await browser.new_context( + storage_state=self.storage_state + ) page = await context.new_page() await page.goto(url, wait_until="networkidle") results = await page.content() @@ -163,7 +174,9 @@ async def ascrape_with_js_support(self, url: str) -> str: attempt += 1 logger.error(f"Attempt {attempt} failed: {e}") if attempt == self.RETRY_LIMIT: - results = f"Error: Network error after {self.RETRY_LIMIT} attempts - {e}" + results = ( + f"Error: Network error after {self.RETRY_LIMIT} attempts - {e}" + ) finally: await browser.close() @@ -180,7 +193,9 @@ def lazy_load(self) -> Iterator[Document]: Document: The scraped content encapsulated within a Document object. """ scraping_fn = ( - self.ascrape_with_js_support if self.requires_js_support else getattr(self, f"ascrape_{self.backend}") + self.ascrape_with_js_support + if self.requires_js_support + else getattr(self, f"ascrape_{self.backend}") ) for url in self.urls: @@ -202,7 +217,9 @@ async def alazy_load(self) -> AsyncIterator[Document]: source URL as metadata. """ scraping_fn = ( - self.ascrape_with_js_support if self.requires_js_support else getattr(self, f"ascrape_{self.backend}") + self.ascrape_with_js_support + if self.requires_js_support + else getattr(self, f"ascrape_{self.backend}") ) tasks = [scraping_fn(url) for url in self.urls] diff --git a/scrapegraphai/graphs/abstract_graph.py b/scrapegraphai/graphs/abstract_graph.py index df5352b3..1148cc29 100644 --- a/scrapegraphai/graphs/abstract_graph.py +++ b/scrapegraphai/graphs/abstract_graph.py @@ -1,6 +1,7 @@ """ AbstractGraph Module """ + from abc import ABC, abstractmethod from typing import Optional import uuid @@ -9,12 +10,10 @@ from langchain.chat_models import init_chat_model from langchain_core.rate_limiters import InMemoryRateLimiter from ..helpers import models_tokens -from ..models import ( - OneApi, - DeepSeek -) +from ..models import OneApi, DeepSeek from ..utils.logging import set_verbosity_warning, set_verbosity_info + class AbstractGraph(ABC): """ Scaffolding class for creating a graph representation and executing it. @@ -39,14 +38,18 @@ class AbstractGraph(ABC): ... # Implementation of graph creation here ... return graph ... - >>> my_graph = MyGraph("Example Graph", + >>> my_graph = MyGraph("Example Graph", {"llm": {"model": "gpt-3.5-turbo"}}, "example_source") >>> result = my_graph.run() """ - def __init__(self, prompt: str, config: dict, - source: Optional[str] = None, schema: Optional[BaseModel] = None): - + def __init__( + self, + prompt: str, + config: dict, + source: Optional[str] = None, + schema: Optional[BaseModel] = None, + ): if config.get("llm").get("temperature") is None: config["llm"]["temperature"] = 0 @@ -55,14 +58,13 @@ def __init__(self, prompt: str, config: dict, self.config = config self.schema = schema self.llm_model = self._create_llm(config["llm"]) - self.verbose = False if config is None else config.get( - "verbose", False) - self.headless = True if self.config is None else config.get( - "headless", True) + self.verbose = False if config is None else config.get("verbose", False) + self.headless = True if self.config is None else config.get("headless", True) self.loader_kwargs = self.config.get("loader_kwargs", {}) self.cache_path = self.config.get("cache_path", False) self.browser_base = self.config.get("browser_base") self.scrape_do = self.config.get("scrape_do") + self.storage_state = self.config.get("storage_state") self.graph = self._create_graph() self.final_state = None @@ -81,7 +83,7 @@ def __init__(self, prompt: str, config: dict, "loader_kwargs": self.loader_kwargs, "llm_model": self.llm_model, "cache_path": self.cache_path, - } + } self.set_common_params(common_params, overwrite=True) @@ -129,7 +131,8 @@ def _create_llm(self, llm_config: dict) -> object: with warnings.catch_warnings(): warnings.simplefilter("ignore") llm_params["rate_limiter"] = InMemoryRateLimiter( - requests_per_second=requests_per_second) + requests_per_second=requests_per_second + ) if max_retries is not None: llm_params["max_retries"] = max_retries @@ -140,22 +143,45 @@ def _create_llm(self, llm_config: dict) -> object: raise KeyError("model_tokens not specified") from exc return llm_params["model_instance"] - known_providers = {"openai", "azure_openai", "google_genai", "google_vertexai", - "ollama", "oneapi", "nvidia", "groq", "anthropic", "bedrock", "mistralai", - "hugging_face", "deepseek", "ernie", "fireworks", "togetherai"} - - if '/' in llm_params["model"]: - split_model_provider = llm_params["model"].split("/", 1) - llm_params["model_provider"] = split_model_provider[0] - llm_params["model"] = split_model_provider[1] + known_providers = { + "openai", + "azure_openai", + "google_genai", + "google_vertexai", + "ollama", + "oneapi", + "nvidia", + "groq", + "anthropic", + "bedrock", + "mistralai", + "hugging_face", + "deepseek", + "ernie", + "fireworks", + "togetherai", + } + + if "/" in llm_params["model"]: + split_model_provider = llm_params["model"].split("/", 1) + llm_params["model_provider"] = split_model_provider[0] + llm_params["model"] = split_model_provider[1] else: - possible_providers = [provider for provider, models_d in models_tokens.items() if llm_params["model"] in models_d] + possible_providers = [ + provider + for provider, models_d in models_tokens.items() + if llm_params["model"] in models_d + ] if len(possible_providers) <= 0: raise ValueError(f"""Provider {llm_params['model_provider']} is not supported. If possible, try to use a model instance instead.""") llm_params["model_provider"] = possible_providers[0] - print((f"Found providers {possible_providers} for model {llm_params['model']}, using {llm_params['model_provider']}.\n" - "If it was not intended please specify the model provider in the graph configuration")) + print( + ( + f"Found providers {possible_providers} for model {llm_params['model']}, using {llm_params['model_provider']}.\n" + "If it was not intended please specify the model provider in the graph configuration" + ) + ) if llm_params["model_provider"] not in known_providers: raise ValueError(f"""Provider {llm_params['model_provider']} is not supported. @@ -163,7 +189,9 @@ def _create_llm(self, llm_config: dict) -> object: if "model_tokens" not in llm_params: try: - self.model_token = models_tokens[llm_params["model_provider"]][llm_params["model"]] + self.model_token = models_tokens[llm_params["model_provider"]][ + llm_params["model"] + ] except KeyError: print(f"""Model {llm_params['model_provider']}/{llm_params['model']} not found, using default token size (8192)""") @@ -172,10 +200,17 @@ def _create_llm(self, llm_config: dict) -> object: self.model_token = llm_params["model_tokens"] try: - if llm_params["model_provider"] not in \ - {"oneapi","nvidia","ernie","deepseek","togetherai"}: + if llm_params["model_provider"] not in { + "oneapi", + "nvidia", + "ernie", + "deepseek", + "togetherai", + }: if llm_params["model_provider"] == "bedrock": - llm_params["model_kwargs"] = { "temperature" : llm_params.pop("temperature") } + llm_params["model_kwargs"] = { + "temperature": llm_params.pop("temperature") + } with warnings.catch_warnings(): warnings.simplefilter("ignore") return init_chat_model(**llm_params) @@ -187,6 +222,7 @@ def _create_llm(self, llm_config: dict) -> object: if model_provider == "ernie": from langchain_community.chat_models import ErnieBotChat + return ErnieBotChat(**llm_params) elif model_provider == "oneapi": @@ -211,7 +247,6 @@ def _create_llm(self, llm_config: dict) -> object: except Exception as e: raise Exception(f"Error instancing model: {e}") - def get_state(self, key=None) -> dict: """ "" Get the final state of the graph. diff --git a/scrapegraphai/graphs/code_generator_graph.py b/scrapegraphai/graphs/code_generator_graph.py index b8329f18..fe94e9d5 100644 --- a/scrapegraphai/graphs/code_generator_graph.py +++ b/scrapegraphai/graphs/code_generator_graph.py @@ -1,6 +1,7 @@ """ SmartScraperGraph Module """ + from typing import Optional import logging from pydantic import BaseModel @@ -16,11 +17,12 @@ GenerateCodeNode, ) + class CodeGeneratorGraph(AbstractGraph): """ - CodeGeneratorGraph is a script generator pipeline that generates + CodeGeneratorGraph is a script generator pipeline that generates the function extract_data(html: str) -> dict() for - extracting the wanted information from a HTML page. + extracting the wanted information from a HTML page. The code generated is in Python and uses the library BeautifulSoup. It requires a user prompt, a source URL, and an output schema. @@ -52,8 +54,9 @@ class CodeGeneratorGraph(AbstractGraph): ) """ - def __init__(self, prompt: str, source: str, config: dict, schema: Optional[BaseModel] = None): - + def __init__( + self, prompt: str, source: str, config: dict, schema: Optional[BaseModel] = None + ): super().__init__(prompt, config, source, schema) self.input_key = "url" if source.startswith("http") else "local_dir" @@ -78,16 +81,14 @@ def _create_graph(self) -> BaseGraph: "cut": self.config.get("cut", True), "loader_kwargs": self.config.get("loader_kwargs", {}), "browser_base": self.config.get("browser_base"), - "scrape_do": self.config.get("scrape_do") - } + "scrape_do": self.config.get("scrape_do"), + "storage_state": self.config.get("storage_state"), + }, ) parse_node = ParseNode( input="doc", output=["parsed_doc"], - node_config={ - "llm_model": self.llm_model, - "chunk_size": self.model_token - } + node_config={"llm_model": self.llm_model, "chunk_size": self.model_token}, ) generate_validation_answer_node = GenerateAnswerNode( @@ -97,7 +98,7 @@ def _create_graph(self) -> BaseGraph: "llm_model": self.llm_model, "additional_info": self.config.get("additional_info"), "schema": self.schema, - } + }, ) prompt_refier_node = PromptRefinerNode( @@ -106,8 +107,8 @@ def _create_graph(self) -> BaseGraph: node_config={ "llm_model": self.llm_model, "chunk_size": self.model_token, - "schema": self.schema - } + "schema": self.schema, + }, ) html_analyzer_node = HtmlAnalyzerNode( @@ -117,8 +118,8 @@ def _create_graph(self) -> BaseGraph: "llm_model": self.llm_model, "additional_info": self.config.get("additional_info"), "schema": self.schema, - "reduction": self.config.get("reduction", 0) - } + "reduction": self.config.get("reduction", 0), + }, ) generate_code_node = GenerateCodeNode( @@ -128,14 +129,17 @@ def _create_graph(self) -> BaseGraph: "llm_model": self.llm_model, "additional_info": self.config.get("additional_info"), "schema": self.schema, - "max_iterations": self.config.get("max_iterations", { - "overall": 10, - "syntax": 3, - "execution": 3, - "validation": 3, - "semantic": 3 - }), - } + "max_iterations": self.config.get( + "max_iterations", + { + "overall": 10, + "syntax": 3, + "execution": 3, + "validation": 3, + "semantic": 3, + }, + ), + }, ) return BaseGraph( @@ -152,10 +156,10 @@ def _create_graph(self) -> BaseGraph: (parse_node, generate_validation_answer_node), (generate_validation_answer_node, prompt_refier_node), (prompt_refier_node, html_analyzer_node), - (html_analyzer_node, generate_code_node) + (html_analyzer_node, generate_code_node), ], entry_point=fetch_node, - graph_name=self.__class__.__name__ + graph_name=self.__class__.__name__, ) def run(self) -> str: diff --git a/scrapegraphai/graphs/depth_search_graph.py b/scrapegraphai/graphs/depth_search_graph.py index 56cb2f16..0df9c061 100644 --- a/scrapegraphai/graphs/depth_search_graph.py +++ b/scrapegraphai/graphs/depth_search_graph.py @@ -1,6 +1,7 @@ """ depth search graph Module """ + from typing import Optional import logging from pydantic import BaseModel @@ -11,14 +12,15 @@ ParseNodeDepthK, DescriptionNode, RAGNode, - GenerateAnswerNodeKLevel + GenerateAnswerNodeKLevel, ) + class DepthSearchGraph(AbstractGraph): """ - CodeGeneratorGraph is a script generator pipeline that generates + CodeGeneratorGraph is a script generator pipeline that generates the function extract_data(html: str) -> dict() for - extracting the wanted information from a HTML page. The + extracting the wanted information from a HTML page. The code generated is in Python and uses the library BeautifulSoup. It requires a user prompt, a source URL, and an output schema. @@ -50,8 +52,9 @@ class DepthSearchGraph(AbstractGraph): ) """ - def __init__(self, prompt: str, source: str, config: dict, schema: Optional[BaseModel] = None): - + def __init__( + self, prompt: str, source: str, config: dict, schema: Optional[BaseModel] = None + ): super().__init__(prompt, config, source, schema) self.input_key = "url" if source.startswith("http") else "local_dir" @@ -72,17 +75,16 @@ def _create_graph(self) -> BaseGraph: "force": self.config.get("force", False), "cut": self.config.get("cut", True), "browser_base": self.config.get("browser_base"), + "storage_state": self.config.get("storage_state"), "depth": self.config.get("depth", 1), - "only_inside_links": self.config.get("only_inside_links", False) - } + "only_inside_links": self.config.get("only_inside_links", False), + }, ) parse_node_k = ParseNodeDepthK( input="docs", output=["docs"], - node_config={ - "verbose": self.config.get("verbose", False) - } + node_config={"verbose": self.config.get("verbose", False)}, ) description_node = DescriptionNode( @@ -91,18 +93,18 @@ def _create_graph(self) -> BaseGraph: node_config={ "llm_model": self.llm_model, "verbose": self.config.get("verbose", False), - "cache_path": self.config.get("cache_path", False) - } + "cache_path": self.config.get("cache_path", False), + }, ) - rag_node = RAGNode ( + rag_node = RAGNode( input="docs", output=["vectorial_db"], node_config={ "llm_model": self.llm_model, "embedder_model": self.config.get("embedder_model", False), "verbose": self.config.get("verbose", False), - } + }, ) generate_answer_k = GenerateAnswerNodeKLevel( @@ -112,8 +114,7 @@ def _create_graph(self) -> BaseGraph: "llm_model": self.llm_model, "embedder_model": self.config.get("embedder_model", False), "verbose": self.config.get("verbose", False), - } - + }, ) return BaseGraph( @@ -122,16 +123,16 @@ def _create_graph(self) -> BaseGraph: parse_node_k, description_node, rag_node, - generate_answer_k + generate_answer_k, ], edges=[ (fetch_node_k, parse_node_k), (parse_node_k, description_node), (description_node, rag_node), - (rag_node, generate_answer_k) + (rag_node, generate_answer_k), ], entry_point=fetch_node_k, - graph_name=self.__class__.__name__ + graph_name=self.__class__.__name__, ) def run(self) -> str: diff --git a/scrapegraphai/graphs/document_scraper_graph.py b/scrapegraphai/graphs/document_scraper_graph.py index 48664f7f..db3244c5 100644 --- a/scrapegraphai/graphs/document_scraper_graph.py +++ b/scrapegraphai/graphs/document_scraper_graph.py @@ -1,6 +1,7 @@ """ This module implements the Document Scraper Graph for the ScrapeGraphAI application. """ + from typing import Optional import logging from pydantic import BaseModel @@ -8,10 +9,11 @@ from .abstract_graph import AbstractGraph from ..nodes import FetchNode, ParseNode, GenerateAnswerNode + class DocumentScraperGraph(AbstractGraph): """ - DocumentScraperGraph is a scraping pipeline that automates the process of - extracting information from web pages using a natural language model to interpret + DocumentScraperGraph is a scraping pipeline that automates the process of + extracting information from web pages using a natural language model to interpret and answer prompts. Attributes: @@ -20,7 +22,7 @@ class DocumentScraperGraph(AbstractGraph): config (dict): Configuration parameters for the graph. schema (BaseModel): The schema for the graph output. llm_model: An instance of a language model client, configured for generating answers. - embedder_model: An instance of an embedding model client, + embedder_model: An instance of an embedding model client, configured for generating embeddings. verbose (bool): A flag indicating whether to show print statements during execution. headless (bool): A flag indicating whether to run the graph in headless mode. @@ -40,7 +42,9 @@ class DocumentScraperGraph(AbstractGraph): >>> result = smart_scraper.run() """ - def __init__(self, prompt: str, source: str, config: dict, schema: Optional[BaseModel] = None): + def __init__( + self, prompt: str, source: str, config: dict, schema: Optional[BaseModel] = None + ): super().__init__(prompt, config, source, schema) self.input_key = "md" if source.endswith("md") else "md_dir" @@ -57,7 +61,8 @@ def _create_graph(self) -> BaseGraph: output=["doc"], node_config={ "loader_kwargs": self.config.get("loader_kwargs", {}), - } + "storage_state": self.config.get("storage_state", None), + }, ) parse_node = ParseNode( input="doc", @@ -65,8 +70,8 @@ def _create_graph(self) -> BaseGraph: node_config={ "parse_html": False, "chunk_size": self.model_token, - "llm_model": self.llm_model - } + "llm_model": self.llm_model, + }, ) generate_answer_node = GenerateAnswerNode( input="user_prompt & (relevant_chunks | parsed_doc | doc)", @@ -75,8 +80,8 @@ def _create_graph(self) -> BaseGraph: "llm_model": self.llm_model, "additional_info": self.config.get("additional_info"), "schema": self.schema, - "is_md_scraper": True - } + "is_md_scraper": True, + }, ) return BaseGraph( @@ -85,12 +90,9 @@ def _create_graph(self) -> BaseGraph: parse_node, generate_answer_node, ], - edges=[ - (fetch_node, parse_node), - (parse_node, generate_answer_node) - ], + edges=[(fetch_node, parse_node), (parse_node, generate_answer_node)], entry_point=fetch_node, - graph_name=self.__class__.__name__ + graph_name=self.__class__.__name__, ) def run(self) -> str: diff --git a/scrapegraphai/graphs/omni_scraper_graph.py b/scrapegraphai/graphs/omni_scraper_graph.py index be909ba2..035ad6a7 100644 --- a/scrapegraphai/graphs/omni_scraper_graph.py +++ b/scrapegraphai/graphs/omni_scraper_graph.py @@ -1,21 +1,18 @@ """ This module implements the Omni Scraper Graph for the ScrapeGraphAI application. """ + from typing import Optional from pydantic import BaseModel from .base_graph import BaseGraph from .abstract_graph import AbstractGraph -from ..nodes import ( - FetchNode, - ParseNode, - ImageToTextNode, - GenerateAnswerOmniNode -) +from ..nodes import FetchNode, ParseNode, ImageToTextNode, GenerateAnswerOmniNode from ..models import OpenAIImageToText + class OmniScraperGraph(AbstractGraph): """ - OmniScraper is a scraping pipeline that automates the process of + OmniScraper is a scraping pipeline that automates the process of extracting information from web pages using a natural language model to interpret and answer prompts. @@ -25,7 +22,7 @@ class OmniScraperGraph(AbstractGraph): config (dict): Configuration parameters for the graph. schema (BaseModel): The schema for the graph output. llm_model: An instance of a language model client, configured for generating answers. - embedder_model: An instance of an embedding model client, + embedder_model: An instance of an embedding model client, configured for generating embeddings. verbose (bool): A flag indicating whether to show print statements during execution. headless (bool): A flag indicating whether to run the graph in headless mode. @@ -47,8 +44,9 @@ class OmniScraperGraph(AbstractGraph): ) """ - def __init__(self, prompt: str, source: str, config: dict, schema: Optional[BaseModel] = None): - + def __init__( + self, prompt: str, source: str, config: dict, schema: Optional[BaseModel] = None + ): self.max_images = 5 if config is None else config.get("max_images", 5) super().__init__(prompt, config, source, schema) @@ -68,7 +66,8 @@ def _create_graph(self) -> BaseGraph: output=["doc"], node_config={ "loader_kwargs": self.config.get("loader_kwargs", {}), - } + "storage_state": self.config.get("storage_state"), + }, ) parse_node = ParseNode( @@ -77,8 +76,8 @@ def _create_graph(self) -> BaseGraph: node_config={ "chunk_size": self.model_token, "parse_urls": True, - "llm_model": self.llm_model - } + "llm_model": self.llm_model, + }, ) image_to_text_node = ImageToTextNode( @@ -86,8 +85,8 @@ def _create_graph(self) -> BaseGraph: output=["img_desc"], node_config={ "llm_model": OpenAIImageToText(self.config["llm"]), - "max_images": self.max_images - } + "max_images": self.max_images, + }, ) generate_answer_omni_node = GenerateAnswerOmniNode( @@ -96,8 +95,8 @@ def _create_graph(self) -> BaseGraph: node_config={ "llm_model": self.llm_model, "additional_info": self.config.get("additional_info"), - "schema": self.schema - } + "schema": self.schema, + }, ) return BaseGraph( @@ -110,10 +109,10 @@ def _create_graph(self) -> BaseGraph: edges=[ (fetch_node, parse_node), (parse_node, image_to_text_node), - (image_to_text_node, generate_answer_omni_node) + (image_to_text_node, generate_answer_omni_node), ], entry_point=fetch_node, - graph_name=self.__class__.__name__ + graph_name=self.__class__.__name__, ) def run(self) -> str: diff --git a/scrapegraphai/graphs/script_creator_graph.py b/scrapegraphai/graphs/script_creator_graph.py index f6a884a1..1e785c92 100644 --- a/scrapegraphai/graphs/script_creator_graph.py +++ b/scrapegraphai/graphs/script_creator_graph.py @@ -1,15 +1,13 @@ """ ScriptCreatorGraph Module """ + from typing import Optional from pydantic import BaseModel from .base_graph import BaseGraph from .abstract_graph import AbstractGraph -from ..nodes import ( - FetchNode, - ParseNode, - GenerateScraperNode -) +from ..nodes import FetchNode, ParseNode, GenerateScraperNode + class ScriptCreatorGraph(AbstractGraph): """ @@ -21,7 +19,7 @@ class ScriptCreatorGraph(AbstractGraph): config (dict): Configuration parameters for the graph. schema (BaseModel): The schema for the graph output. llm_model: An instance of a language model client, configured for generating answers. - embedder_model: An instance of an embedding model client, + embedder_model: An instance of an embedding model client, configured for generating embeddings. verbose (bool): A flag indicating whether to show print statements during execution. headless (bool): A flag indicating whether to run the graph in headless mode. @@ -43,9 +41,10 @@ class ScriptCreatorGraph(AbstractGraph): >>> result = script_creator.run() """ - def __init__(self, prompt: str, source: str, config: dict, schema: Optional[BaseModel] = None): - - self.library = config['library'] + def __init__( + self, prompt: str, source: str, config: dict, schema: Optional[BaseModel] = None + ): + self.library = config["library"] super().__init__(prompt, config, source, schema) @@ -65,17 +64,19 @@ def _create_graph(self) -> BaseGraph: node_config={ "llm_model": self.llm_model, "loader_kwargs": self.config.get("loader_kwargs", {}), - "script_creator": True - } + "script_creator": True, + "storage_state": self.config.get("storage_state"), + }, ) parse_node = ParseNode( input="doc", output=["parsed_doc"], - node_config={"chunk_size": self.model_token, - "parse_html": False, - "llm_model": self.llm_model - } + node_config={ + "chunk_size": self.model_token, + "parse_html": False, + "llm_model": self.llm_model, + }, ) generate_scraper_node = GenerateScraperNode( @@ -87,7 +88,7 @@ def _create_graph(self) -> BaseGraph: "schema": self.schema, }, library=self.library, - website=self.source + website=self.source, ) return BaseGraph( @@ -101,7 +102,7 @@ def _create_graph(self) -> BaseGraph: (parse_node, generate_scraper_node), ], entry_point=fetch_node, - graph_name=self.__class__.__name__ + graph_name=self.__class__.__name__, ) def run(self) -> str: diff --git a/scrapegraphai/graphs/search_graph.py b/scrapegraphai/graphs/search_graph.py index 987eab8b..313cb768 100644 --- a/scrapegraphai/graphs/search_graph.py +++ b/scrapegraphai/graphs/search_graph.py @@ -1,21 +1,19 @@ -""" +""" SearchGraph Module """ + from copy import deepcopy from typing import Optional, List from pydantic import BaseModel from .base_graph import BaseGraph from .abstract_graph import AbstractGraph from .smart_scraper_graph import SmartScraperGraph -from ..nodes import ( - SearchInternetNode, - GraphIteratorNode, - MergeAnswersNode -) +from ..nodes import SearchInternetNode, GraphIteratorNode, MergeAnswersNode from ..utils.copy import safe_deepcopy + class SearchGraph(AbstractGraph): - """ + """ SearchGraph is a scraping pipeline that searches the internet for answers to a given prompt. It only requires a user prompt to search the internet and generate an answer. @@ -66,9 +64,10 @@ def _create_graph(self) -> BaseGraph: "llm_model": self.llm_model, "max_results": self.max_results, "loader_kwargs": self.loader_kwargs, + "storage_state": self.copy_config.get("storage_state"), "search_engine": self.copy_config.get("search_engine"), - "serper_api_key": self.copy_config.get("serper_api_key") - } + "serper_api_key": self.copy_config.get("serper_api_key"), + }, ) graph_iterator_node = GraphIteratorNode( @@ -76,32 +75,25 @@ def _create_graph(self) -> BaseGraph: output=["results"], node_config={ "graph_instance": SmartScraperGraph, - "scraper_config": self.copy_config + "scraper_config": self.copy_config, }, - schema=self.copy_schema + schema=self.copy_schema, ) merge_answers_node = MergeAnswersNode( input="user_prompt & results", output=["answer"], - node_config={ - "llm_model": self.llm_model, - "schema": self.copy_schema - } + node_config={"llm_model": self.llm_model, "schema": self.copy_schema}, ) return BaseGraph( - nodes=[ - search_internet_node, - graph_iterator_node, - merge_answers_node - ], + nodes=[search_internet_node, graph_iterator_node, merge_answers_node], edges=[ (search_internet_node, graph_iterator_node), - (graph_iterator_node, merge_answers_node) + (graph_iterator_node, merge_answers_node), ], entry_point=search_internet_node, - graph_name=self.__class__.__name__ + graph_name=self.__class__.__name__, ) def run(self) -> str: @@ -116,8 +108,8 @@ def run(self) -> str: self.final_state, self.execution_info = self.graph.execute(inputs) # Store the URLs after execution - if 'urls' in self.final_state: - self.considered_urls = self.final_state['urls'] + if "urls" in self.final_state: + self.considered_urls = self.final_state["urls"] return self.final_state.get("answer", "No answer found.") diff --git a/scrapegraphai/graphs/search_link_graph.py b/scrapegraphai/graphs/search_link_graph.py index 1a83feb1..e8baf1d8 100644 --- a/scrapegraphai/graphs/search_link_graph.py +++ b/scrapegraphai/graphs/search_link_graph.py @@ -1,18 +1,18 @@ -""" -SearchLinkGraph Module """ +SearchLinkGraph Module +""" + from typing import Optional import logging from pydantic import BaseModel from .base_graph import BaseGraph from .abstract_graph import AbstractGraph -from ..nodes import (FetchNode, - SearchLinkNode, - SearchLinksWithContext) +from ..nodes import FetchNode, SearchLinkNode, SearchLinksWithContext + class SearchLinkGraph(AbstractGraph): - """ - SearchLinkGraph is a scraping pipeline that automates the process of + """ + SearchLinkGraph is a scraping pipeline that automates the process of extracting information from web pages using a natural language model to interpret and answer prompts. @@ -22,7 +22,7 @@ class SearchLinkGraph(AbstractGraph): config (dict): Configuration parameters for the graph. schema (BaseModel): The schema for the graph output. llm_model: An instance of a language model client, configured for generating answers. - embedder_model: An instance of an embedding model client, + embedder_model: An instance of an embedding model client, configured for generating embeddings. verbose (bool): A flag indicating whether to show print statements during execution. headless (bool): A flag indicating whether to run the graph in headless mode. @@ -49,14 +49,15 @@ def _create_graph(self) -> BaseGraph: """ fetch_node = FetchNode( - input="url| local_dir", - output=["doc"], - node_config={ - "force": self.config.get("force", False), - "cut": self.config.get("cut", True), - "loader_kwargs": self.config.get("loader_kwargs", {}), - } - ) + input="url| local_dir", + output=["doc"], + node_config={ + "force": self.config.get("force", False), + "cut": self.config.get("cut", True), + "loader_kwargs": self.config.get("loader_kwargs", {}), + "storage_state": self.config.get("storage_state"), + }, + ) if self.config.get("llm_style") == (True, None): search_link_node = SearchLinksWithContext( @@ -65,7 +66,7 @@ def _create_graph(self) -> BaseGraph: node_config={ "llm_model": self.llm_model, "chunk_size": self.model_token, - } + }, ) else: search_link_node = SearchLinkNode( @@ -73,19 +74,14 @@ def _create_graph(self) -> BaseGraph: output=["parsed_doc"], node_config={ "chunk_size": self.model_token, - } + }, ) return BaseGraph( - nodes=[ - fetch_node, - search_link_node - ], - edges=[ - (fetch_node, search_link_node) - ], + nodes=[fetch_node, search_link_node], + edges=[(fetch_node, search_link_node)], entry_point=fetch_node, - graph_name=self.__class__.__name__ + graph_name=self.__class__.__name__, ) def run(self) -> str: diff --git a/scrapegraphai/graphs/smart_scraper_graph.py b/scrapegraphai/graphs/smart_scraper_graph.py index 340f69bb..f7ce5dde 100644 --- a/scrapegraphai/graphs/smart_scraper_graph.py +++ b/scrapegraphai/graphs/smart_scraper_graph.py @@ -1,6 +1,7 @@ """ SmartScraperGraph Module """ + from typing import Optional from pydantic import BaseModel from .base_graph import BaseGraph @@ -10,14 +11,14 @@ ParseNode, ReasoningNode, GenerateAnswerNode, - ConditionalNode + ConditionalNode, ) from ..prompts import REGEN_ADDITIONAL_INFO from scrapegraph_py import SyncClient class SmartScraperGraph(AbstractGraph): """ - SmartScraper is a scraping pipeline that automates the process of + SmartScraper is a scraping pipeline that automates the process of extracting information from web pages using a natural language model to interpret and answer prompts. @@ -27,7 +28,7 @@ class SmartScraperGraph(AbstractGraph): config (dict): Configuration parameters for the graph. schema (BaseModel): The schema for the graph output. llm_model: An instance of a language model client, configured for generating answers. - embedder_model: An instance of an embedding model client, + embedder_model: An instance of an embedding model client, configured for generating embeddings. verbose (bool): A flag indicating whether to show print statements during execution. headless (bool): A flag indicating whether to run the graph in headless mode. @@ -48,7 +49,9 @@ class SmartScraperGraph(AbstractGraph): ) """ - def __init__(self, prompt: str, source: str, config: dict, schema: Optional[BaseModel] = None): + def __init__( + self, prompt: str, source: str, config: dict, schema: Optional[BaseModel] = None + ): super().__init__(prompt, config, source, schema) self.input_key = "url" if source.startswith("http") else "local_dir" @@ -79,16 +82,14 @@ def _create_graph(self) -> BaseGraph: "cut": self.config.get("cut", True), "loader_kwargs": self.config.get("loader_kwargs", {}), "browser_base": self.config.get("browser_base"), - "scrape_do": self.config.get("scrape_do") - } + "scrape_do": self.config.get("scrape_do"), + "storage_state": self.config.get("storage_state"), + }, ) parse_node = ParseNode( input="doc", output=["parsed_doc"], - node_config={ - "llm_model": self.llm_model, - "chunk_size": self.model_token - } + node_config={"llm_model": self.llm_model, "chunk_size": self.model_token}, ) generate_answer_node = GenerateAnswerNode( @@ -98,7 +99,7 @@ def _create_graph(self) -> BaseGraph: "llm_model": self.llm_model, "additional_info": self.config.get("additional_info"), "schema": self.schema, - } + }, ) cond_node = None @@ -111,7 +112,7 @@ def _create_graph(self) -> BaseGraph: node_config={ "key_name": "answer", "condition": 'not answer or answer=="NA"', - } + }, ) regen_node = GenerateAnswerNode( input="user_prompt & answer", @@ -120,7 +121,7 @@ def _create_graph(self) -> BaseGraph: "llm_model": self.llm_model, "additional_info": REGEN_ADDITIONAL_INFO, "schema": self.schema, - } + }, ) if self.config.get("html_mode") is False: @@ -129,61 +130,107 @@ def _create_graph(self) -> BaseGraph: output=["parsed_doc"], node_config={ "llm_model": self.llm_model, - "chunk_size": self.model_token - } + "chunk_size": self.model_token, + }, ) reasoning_node = None if self.config.get("reasoning"): - reasoning_node = ReasoningNode( + reasoning_node = ReasoningNode( input="user_prompt & (relevant_chunks | parsed_doc | doc)", output=["answer"], node_config={ "llm_model": self.llm_model, "additional_info": self.config.get("additional_info"), "schema": self.schema, - } + }, ) - + # Define the graph variation configurations # (html_mode, reasoning, reattempt) graph_variation_config = { (False, True, False): { - "nodes": [fetch_node, parse_node, reasoning_node, generate_answer_node], - "edges": [(fetch_node, parse_node), (parse_node, reasoning_node), (reasoning_node, generate_answer_node)] + "nodes": [fetch_node, parse_node, reasoning_node, generate_answer_node], + "edges": [ + (fetch_node, parse_node), + (parse_node, reasoning_node), + (reasoning_node, generate_answer_node), + ], }, (True, True, False): { "nodes": [fetch_node, reasoning_node, generate_answer_node], - "edges": [(fetch_node, reasoning_node), (reasoning_node, generate_answer_node)] + "edges": [ + (fetch_node, reasoning_node), + (reasoning_node, generate_answer_node), + ], }, (True, False, False): { "nodes": [fetch_node, generate_answer_node], - "edges": [(fetch_node, generate_answer_node)] + "edges": [(fetch_node, generate_answer_node)], }, (False, False, False): { "nodes": [fetch_node, parse_node, generate_answer_node], - "edges": [(fetch_node, parse_node), (parse_node, generate_answer_node)] + "edges": [(fetch_node, parse_node), (parse_node, generate_answer_node)], }, (False, True, True): { - "nodes": [fetch_node, parse_node, reasoning_node, generate_answer_node, cond_node, regen_node], - "edges": [(fetch_node, parse_node), (parse_node, reasoning_node), (reasoning_node, generate_answer_node), - (generate_answer_node, cond_node), (cond_node, regen_node), (cond_node, None)] + "nodes": [ + fetch_node, + parse_node, + reasoning_node, + generate_answer_node, + cond_node, + regen_node, + ], + "edges": [ + (fetch_node, parse_node), + (parse_node, reasoning_node), + (reasoning_node, generate_answer_node), + (generate_answer_node, cond_node), + (cond_node, regen_node), + (cond_node, None), + ], }, (True, True, True): { - "nodes": [fetch_node, reasoning_node, generate_answer_node, cond_node, regen_node], - "edges": [(fetch_node, reasoning_node), (reasoning_node, generate_answer_node), - (generate_answer_node, cond_node), (cond_node, regen_node), (cond_node, None)] + "nodes": [ + fetch_node, + reasoning_node, + generate_answer_node, + cond_node, + regen_node, + ], + "edges": [ + (fetch_node, reasoning_node), + (reasoning_node, generate_answer_node), + (generate_answer_node, cond_node), + (cond_node, regen_node), + (cond_node, None), + ], }, (True, False, True): { "nodes": [fetch_node, generate_answer_node, cond_node, regen_node], - "edges": [(fetch_node, generate_answer_node), (generate_answer_node, cond_node), - (cond_node, regen_node), (cond_node, None)] + "edges": [ + (fetch_node, generate_answer_node), + (generate_answer_node, cond_node), + (cond_node, regen_node), + (cond_node, None), + ], }, (False, False, True): { - "nodes": [fetch_node, parse_node, generate_answer_node, cond_node, regen_node], - "edges": [(fetch_node, parse_node), (parse_node, generate_answer_node), - (generate_answer_node, cond_node), (cond_node, regen_node), (cond_node, None)] - } + "nodes": [ + fetch_node, + parse_node, + generate_answer_node, + cond_node, + regen_node, + ], + "edges": [ + (fetch_node, parse_node), + (parse_node, generate_answer_node), + (generate_answer_node, cond_node), + (cond_node, regen_node), + (cond_node, None), + ], + }, } # Get the current conditions @@ -199,7 +246,7 @@ def _create_graph(self) -> BaseGraph: nodes=config["nodes"], edges=config["edges"], entry_point=fetch_node, - graph_name=self.__class__.__name__ + graph_name=self.__class__.__name__, ) # Default return if no conditions match @@ -207,9 +254,9 @@ def _create_graph(self) -> BaseGraph: nodes=[fetch_node, parse_node, generate_answer_node], edges=[(fetch_node, parse_node), (parse_node, generate_answer_node)], entry_point=fetch_node, - graph_name=self.__class__.__name__ + graph_name=self.__class__.__name__, ) - + def run(self) -> str: """ Executes the scraping process and returns the answer to the prompt. diff --git a/scrapegraphai/graphs/smart_scraper_lite_graph.py b/scrapegraphai/graphs/smart_scraper_lite_graph.py index 77437145..b751a8c3 100644 --- a/scrapegraphai/graphs/smart_scraper_lite_graph.py +++ b/scrapegraphai/graphs/smart_scraper_lite_graph.py @@ -1,6 +1,7 @@ """ SmartScraperGraph Module """ + from typing import Optional from pydantic import BaseModel from .base_graph import BaseGraph @@ -10,9 +11,10 @@ ParseNode, ) + class SmartScraperLiteGraph(AbstractGraph): """ - SmartScraperLiteGraph is a scraping pipeline that automates the process of + SmartScraperLiteGraph is a scraping pipeline that automates the process of extracting information from web pages. Attributes: @@ -38,8 +40,13 @@ class SmartScraperLiteGraph(AbstractGraph): ) """ - def __init__(self, source: str, config: dict, prompt: str = "", - schema: Optional[BaseModel] = None): + def __init__( + self, + source: str, + config: dict, + prompt: str = "", + schema: Optional[BaseModel] = None, + ): super().__init__(prompt, config, source, schema) self.input_key = "url" if source.startswith("http") else "local_dir" @@ -60,17 +67,15 @@ def _create_graph(self) -> BaseGraph: "cut": self.config.get("cut", True), "loader_kwargs": self.config.get("loader_kwargs", {}), "browser_base": self.config.get("browser_base"), - "scrape_do": self.config.get("scrape_do") - } + "scrape_do": self.config.get("scrape_do"), + "storage_state": self.config.get("storage_state"), + }, ) parse_node = ParseNode( input="doc", output=["parsed_doc"], - node_config={ - "llm_model": self.llm_model, - "chunk_size": self.model_token - } + node_config={"llm_model": self.llm_model, "chunk_size": self.model_token}, ) return BaseGraph( @@ -82,7 +87,7 @@ def _create_graph(self) -> BaseGraph: (fetch_node, parse_node), ], entry_point=fetch_node, - graph_name=self.__class__.__name__ + graph_name=self.__class__.__name__, ) def run(self) -> str: diff --git a/scrapegraphai/graphs/speech_graph.py b/scrapegraphai/graphs/speech_graph.py index d491d4bc..d9d107c0 100644 --- a/scrapegraphai/graphs/speech_graph.py +++ b/scrapegraphai/graphs/speech_graph.py @@ -1,6 +1,7 @@ -""" +""" SpeechGraph Module """ + from typing import Optional from pydantic import BaseModel from .base_graph import BaseGraph @@ -14,9 +15,10 @@ from ..utils.save_audio_from_bytes import save_audio_from_bytes from ..models import OpenAITextToSpeech + class SpeechGraph(AbstractGraph): """ - SpeechyGraph is a scraping pipeline that scrapes the web, provide an answer + SpeechyGraph is a scraping pipeline that scrapes the web, provide an answer to a given prompt, and generate an audio file. Attributes: @@ -44,7 +46,9 @@ class SpeechGraph(AbstractGraph): ... {"llm": {"model": "openai/gpt-3.5-turbo"}} """ - def __init__(self, prompt: str, source: str, config: dict, schema: Optional[BaseModel] = None): + def __init__( + self, prompt: str, source: str, config: dict, schema: Optional[BaseModel] = None + ): super().__init__(prompt, config, source, schema) self.input_key = "url" if source.startswith("http") else "local_dir" @@ -57,18 +61,12 @@ def _create_graph(self) -> BaseGraph: BaseGraph: A graph instance representing the web scraping and audio generation workflow. """ - fetch_node = FetchNode( - input="url | local_dir", - output=["doc"] - ) + fetch_node = FetchNode(input="url | local_dir", output=["doc"]) parse_node = ParseNode( input="doc", output=["parsed_doc"], - node_config={ - "chunk_size": self.model_token, - "llm_model": self.llm_model - } + node_config={"chunk_size": self.model_token, "llm_model": self.llm_model}, ) generate_answer_node = GenerateAnswerNode( @@ -77,32 +75,25 @@ def _create_graph(self) -> BaseGraph: node_config={ "llm_model": self.llm_model, "additional_info": self.config.get("additional_info"), - "schema": self.schema - } + "schema": self.schema, + }, ) text_to_speech_node = TextToSpeechNode( input="answer", output=["audio"], - node_config={ - "tts_model": OpenAITextToSpeech(self.config["tts_model"]) - } + node_config={"tts_model": OpenAITextToSpeech(self.config["tts_model"])}, ) return BaseGraph( - nodes=[ - fetch_node, - parse_node, - generate_answer_node, - text_to_speech_node - ], + nodes=[fetch_node, parse_node, generate_answer_node, text_to_speech_node], edges=[ (fetch_node, parse_node), (parse_node, generate_answer_node), - (generate_answer_node, text_to_speech_node) + (generate_answer_node, text_to_speech_node), ], entry_point=fetch_node, - graph_name=self.__class__.__name__ + graph_name=self.__class__.__name__, ) def run(self) -> str: @@ -119,8 +110,7 @@ def run(self) -> str: audio = self.final_state.get("audio", None) if not audio: raise ValueError("No audio generated from the text.") - save_audio_from_bytes(audio, self.config.get( - "output_path", "output.mp3")) + save_audio_from_bytes(audio, self.config.get("output_path", "output.mp3")) print(f"Audio saved to {self.config.get('output_path', 'output.mp3')}") return self.final_state.get("answer", "No answer found.") diff --git a/scrapegraphai/nodes/fetch_node.py b/scrapegraphai/nodes/fetch_node.py index f964eb8b..88225a20 100644 --- a/scrapegraphai/nodes/fetch_node.py +++ b/scrapegraphai/nodes/fetch_node.py @@ -1,6 +1,7 @@ """ FetchNode Module """ + import json from typing import List, Optional from langchain_openai import ChatOpenAI, AzureChatOpenAI @@ -14,6 +15,7 @@ from ..utils.logging import get_logger from .base_node import BaseNode + class FetchNode(BaseNode): """ A node responsible for fetching the HTML content of a specified URL and updating @@ -55,22 +57,18 @@ def __init__( self.loader_kwargs = ( {} if node_config is None else node_config.get("loader_kwargs", {}) ) - self.llm_model = ( - {} if node_config is None else node_config.get("llm_model", {}) - ) - self.force = ( - False if node_config is None else node_config.get("force", False) - ) + self.llm_model = {} if node_config is None else node_config.get("llm_model", {}) + self.force = False if node_config is None else node_config.get("force", False) self.script_creator = ( False if node_config is None else node_config.get("script_creator", False) ) self.openai_md_enabled = ( - False if node_config is None else node_config.get("openai_md_enabled", False) + False + if node_config is None + else node_config.get("openai_md_enabled", False) ) - self.cut = ( - False if node_config is None else node_config.get("cut", True) - ) + self.cut = False if node_config is None else node_config.get("cut", True) self.browser_base = ( None if node_config is None else node_config.get("browser_base", None) @@ -80,20 +78,27 @@ def __init__( None if node_config is None else node_config.get("scrape_do", None) ) + self.storage_state = ( + None if node_config is None else node_config.get("storage_state", None) + ) + def is_valid_url(self, source: str) -> bool: """ Validates if the source string is a valid URL using regex. - + Parameters: source (str): The URL string to validate - + Raises: ValueError: If the URL is invalid """ import re - url_pattern = r'^https?://[^\s/$.?#].[^\s]*$' + + url_pattern = r"^https?://[^\s/$.?#].[^\s]*$" if not bool(re.match(url_pattern, source)): - raise ValueError(f"Invalid URL format: {source}. URL must start with http(s):// and contain a valid domain.") + raise ValueError( + f"Invalid URL format: {source}. URL must start with http(s):// and contain a valid domain." + ) return True def execute(self, state): @@ -126,7 +131,7 @@ def execute(self, state): return handlers[input_type](state, input_type, source) elif self.input == "pdf_dir": return state - + # For web sources, validate URL before proceeding try: if self.is_valid_url(source): @@ -134,7 +139,7 @@ def execute(self, state): except ValueError as e: # Re-raise the exception from is_valid_url raise - + return self.handle_local_source(state, source) def handle_directory(self, state, input_type, source): @@ -150,9 +155,7 @@ def handle_directory(self, state, input_type, source): dict: The updated state with the compressed document. """ - compressed_document = [ - source - ] + compressed_document = [source] state.update({self.output[0]: compressed_document}) return state @@ -181,6 +184,7 @@ def handle_file(self, state, input_type, source): # return self.update_state(state, compressed_document) state.update({self.output[0]: compressed_document}) return state + def load_file_content(self, source, input_type): """ Loads the content of a file based on its input type. @@ -197,10 +201,18 @@ def load_file_content(self, source, input_type): loader = PyPDFLoader(source) return loader.load() elif input_type == "csv": - return [Document(page_content=str(pd.read_csv(source)), metadata={"source": "csv"})] + return [ + Document( + page_content=str(pd.read_csv(source)), metadata={"source": "csv"} + ) + ] elif input_type == "json": with open(source, encoding="utf-8") as f: - return [Document(page_content=str(json.load(f)), metadata={"source": "json"})] + return [ + Document( + page_content=str(json.load(f)), metadata={"source": "json"} + ) + ] elif input_type == "xml" or input_type == "md": with open(source, "r", encoding="utf-8") as f: data = f.read() @@ -228,9 +240,15 @@ def handle_local_source(self, state, source): parsed_content = source - if (isinstance(self.llm_model, ChatOpenAI) or \ - isinstance(self.llm_model, AzureChatOpenAI)) \ - and not self.script_creator or self.force and not self.script_creator: + if ( + ( + isinstance(self.llm_model, ChatOpenAI) + or isinstance(self.llm_model, AzureChatOpenAI) + ) + and not self.script_creator + or self.force + and not self.script_creator + ): parsed_content = convert_to_md(source) else: parsed_content = source @@ -242,9 +260,10 @@ def handle_local_source(self, state, source): # return self.update_state(state, compressed_document) state.update({self.output[0]: compressed_document}) return state + def handle_web_source(self, state, source): """ - Handles the web source by fetching HTML content from a URL, + Handles the web source by fetching HTML content from a URL, optionally converting it to Markdown, and updating the state. Parameters: @@ -268,8 +287,11 @@ def handle_web_source(self, state, source): if not self.cut: parsed_content = cleanup_html(response, source) - if isinstance(self.llm_model, (ChatOpenAI, AzureChatOpenAI)) \ - and not self.script_creator or (self.force and not self.script_creator): + if ( + isinstance(self.llm_model, (ChatOpenAI, AzureChatOpenAI)) + and not self.script_creator + or (self.force and not self.script_creator) + ): parsed_content = convert_to_md(source, parsed_content) compressed_document = [Document(page_content=parsed_content)] @@ -290,28 +312,42 @@ def handle_web_source(self, state, source): raise ImportError("""The browserbase module is not installed. Please install it using `pip install browserbase`.""") - data = browser_base_fetch(self.browser_base.get("api_key"), - self.browser_base.get("project_id"), [source]) + data = browser_base_fetch( + self.browser_base.get("api_key"), + self.browser_base.get("project_id"), + [source], + ) - document = [Document(page_content=content, - metadata={"source": source}) for content in data] + document = [ + Document(page_content=content, metadata={"source": source}) + for content in data + ] elif self.scrape_do: from ..docloaders.scrape_do import scrape_do_fetch - if (self.scrape_do.get("use_proxy") is None) or \ - self.scrape_do.get("geoCode") is None or \ - self.scrape_do.get("super_proxy") is None: - data = scrape_do_fetch(self.scrape_do.get("api_key"), - source) - else: - data = scrape_do_fetch(self.scrape_do.get("api_key"), - source, self.scrape_do.get("use_proxy"), - self.scrape_do.get("geoCode"), - self.scrape_do.get("super_proxy")) - document = [Document(page_content=data, - metadata={"source": source})] + if ( + (self.scrape_do.get("use_proxy") is None) + or self.scrape_do.get("geoCode") is None + or self.scrape_do.get("super_proxy") is None + ): + data = scrape_do_fetch(self.scrape_do.get("api_key"), source) + else: + data = scrape_do_fetch( + self.scrape_do.get("api_key"), + source, + self.scrape_do.get("use_proxy"), + self.scrape_do.get("geoCode"), + self.scrape_do.get("super_proxy"), + ) + + document = [Document(page_content=data, metadata={"source": source})] else: - loader = ChromiumLoader([source], headless=self.headless, **loader_kwargs) + loader = ChromiumLoader( + [source], + headless=self.headless, + storage_state=self.storage_state, + **loader_kwargs, + ) document = loader.load() if not document or not document[0].page_content.strip(): @@ -320,15 +356,25 @@ def handle_web_source(self, state, source): parsed_content = document[0].page_content - if (isinstance(self.llm_model, ChatOpenAI) \ - or isinstance(self.llm_model, AzureChatOpenAI)) \ - and not self.script_creator or self.force \ - and not self.script_creator and not self.openai_md_enabled: + if ( + ( + isinstance(self.llm_model, ChatOpenAI) + or isinstance(self.llm_model, AzureChatOpenAI) + ) + and not self.script_creator + or self.force + and not self.script_creator + and not self.openai_md_enabled + ): parsed_content = convert_to_md(document[0].page_content, parsed_content) compressed_document = [ Document(page_content=parsed_content, metadata={"source": "html file"}) ] state["original_html"] = document - state.update({self.output[0]: compressed_document,}) + state.update( + { + self.output[0]: compressed_document, + } + ) return state diff --git a/scrapegraphai/nodes/fetch_node_level_k.py b/scrapegraphai/nodes/fetch_node_level_k.py index ce8e4042..3307f129 100644 --- a/scrapegraphai/nodes/fetch_node_level_k.py +++ b/scrapegraphai/nodes/fetch_node_level_k.py @@ -1,6 +1,7 @@ """ fetch_node_level_k module """ + from typing import List, Optional from urllib.parse import urljoin from langchain_core.documents import Document @@ -8,9 +9,10 @@ from .base_node import BaseNode from ..docloaders import ChromiumLoader + class FetchNodeLevelK(BaseNode): """ - A node responsible for fetching the HTML content of a specified URL and all its sub-links + A node responsible for fetching the HTML content of a specified URL and all its sub-links recursively up to a certain level of hyperlink the graph. This content is then used to update the graph's state. It uses ChromiumLoader to fetch the content from a web page asynchronously (with proxy protection). @@ -58,8 +60,11 @@ def __init__( self.loader_kwargs = node_config.get("loader_kwargs", {}) if node_config else {} self.browser_base = node_config.get("browser_base", None) self.scrape_do = node_config.get("scrape_do", None) + self.storage_state = node_config.get("storage_state", None) self.depth = node_config.get("depth", 1) if node_config else 1 - self.only_inside_links = node_config.get("only_inside_links", False) if node_config else False + self.only_inside_links = ( + node_config.get("only_inside_links", False) if node_config else False + ) self.min_input_len = 1 def execute(self, state: dict) -> dict: @@ -83,12 +88,14 @@ def execute(self, state: dict) -> dict: source = input_data[0] documents = [{"source": source}] - loader_kwargs = self.node_config.get("loader_kwargs", {}) if self.node_config else {} + loader_kwargs = ( + self.node_config.get("loader_kwargs", {}) if self.node_config else {} + ) for _ in range(self.depth): documents = self.obtain_content(documents, loader_kwargs) - filtered_documents = [doc for doc in documents if 'document' in doc] + filtered_documents = [doc for doc in documents if "document" in doc] state.update({self.output[0]: filtered_documents}) return state @@ -112,17 +119,27 @@ def fetch_content(self, source: str, loader_kwargs) -> Optional[str]: raise ImportError("""The browserbase module is not installed. Please install it using `pip install browserbase`.""") - data = browser_base_fetch(self.browser_base.get("api_key"), - self.browser_base.get("project_id"), [source]) - document = [Document(page_content=content, - metadata={"source": source}) for content in data] + data = browser_base_fetch( + self.browser_base.get("api_key"), + self.browser_base.get("project_id"), + [source], + ) + document = [ + Document(page_content=content, metadata={"source": source}) + for content in data + ] elif self.scrape_do: from ..docloaders.scrape_do import scrape_do_fetch + data = scrape_do_fetch(self.scrape_do.get("api_key"), source) - document = [Document(page_content=data, - metadata={"source": source})] + document = [Document(page_content=data, metadata={"source": source})] else: - loader = ChromiumLoader([source], headless=self.headless, **loader_kwargs) + loader = ChromiumLoader( + [source], + headless=self.headless, + storage_state=self.storage_state, + **loader_kwargs, + ) document = loader.load() return document @@ -136,8 +153,8 @@ def extract_links(self, html_content: str) -> list: Returns: list: A list of extracted hyperlinks. """ - soup = BeautifulSoup(html_content, 'html.parser') - links = [link['href'] for link in soup.find_all('a', href=True)] + soup = BeautifulSoup(html_content, "html.parser") + links = [link["href"] for link in soup.find_all("a", href=True)] self.logger.info(f"Extracted {len(links)} links.") return links @@ -173,8 +190,8 @@ def obtain_content(self, documents: List, loader_kwargs) -> List: """ new_documents = [] for doc in documents: - source = doc['source'] - if 'document' not in doc: + source = doc["source"] + if "document" not in doc: document = self.fetch_content(source, loader_kwargs) if not document or not document[0].page_content.strip(): @@ -182,20 +199,27 @@ def obtain_content(self, documents: List, loader_kwargs) -> List: documents.remove(doc) continue - doc['document'] = document - links = self.extract_links(doc['document'][0].page_content) + doc["document"] = document + links = self.extract_links(doc["document"][0].page_content) full_links = self.get_full_links(source, links) for link in full_links: - if not any(d.get('source', '') == link for d in documents) \ - and not any(d.get('source', '') == link for d in new_documents): + if not any( + d.get("source", "") == link for d in documents + ) and not any(d.get("source", "") == link for d in new_documents): new_documents.append({"source": link}) documents.extend(new_documents) return documents - def process_links(self, base_url: str, links: list, - loader_kwargs, depth: int, current_depth: int = 1) -> dict: + def process_links( + self, + base_url: str, + links: list, + loader_kwargs, + depth: int, + current_depth: int = 1, + ) -> dict: """ Processes a list of links recursively up to a given depth. @@ -217,8 +241,11 @@ def process_links(self, base_url: str, links: list, if current_depth < depth: new_links = self.extract_links(link_content) - content_dict.update(self.process_links(full_link, new_links, - loader_kwargs, depth, current_depth + 1)) + content_dict.update( + self.process_links( + full_link, new_links, loader_kwargs, depth, current_depth + 1 + ) + ) else: self.logger.warning(f"Failed to fetch content for {full_link}") return content_dict