diff --git a/.env b/.env
index d91c274..8bc45f1 100644
--- a/.env
+++ b/.env
@@ -1,4 +1,16 @@
+# LLM Provider: "openai" (default), "anthropic", or "ollama"
+LLM_PROVIDER=openai
+
 # OpenAI Configuration
 OPENAI_API_KEY=
 OPENAI_API_VERSION=
-OPENAI_API_TYPE=
\ No newline at end of file
+OPENAI_API_TYPE=
+OPENAI_MODEL=gpt-5.2
+
+# Anthropic Configuration
+ANTHROPIC_API_KEY=
+ANTHROPIC_MODEL=claude-sonnet-4-5
+
+# Ollama Configuration (local models)
+OLLAMA_MODEL=llama3
+OLLAMA_BASE_URL=http://localhost:11434
diff --git a/pyproject.toml b/pyproject.toml
index cf71403..fba1d25 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -29,7 +29,6 @@ dependencies = [
     "langchain-community~=0.3.21",
     "langchain-core~=0.3.52",
     "langchain-openai~=0.3.14",
-    "langchain-ollama~=0.3.2",
     "pydantic",
     "pyinputplus",
     "azure-identity",
@@ -39,6 +38,11 @@ dependencies = [
     "numpy>=1.26.4",
 ]
 
+[project.optional-dependencies]
+anthropic = ["langchain-anthropic~=0.3.12"]
+ollama = ["langchain-ollama~=0.3.2"]
+all = ["langchain-anthropic~=0.3.12", "langchain-ollama~=0.3.2"]
+
 [project.urls]
 "Homepage" = "https://github.com/nasa-jpl/rosa"
 "Bug Tracker" = "https://github.com/nasa-jpl/rosa/issues"
diff --git a/src/rosa/rosa.py b/src/rosa/rosa.py
index 67ea0f8..7d0e1e7 100644
--- a/src/rosa/rosa.py
+++ b/src/rosa/rosa.py
@@ -12,24 +12,34 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Any, AsyncIterable, Dict, Literal, Optional, Union
+from __future__ import annotations
 
-from langchain.agents import AgentExecutor
-from langchain.agents.format_scratchpad.openai_tools import (
-    format_to_openai_tool_messages,
-)
-from langchain.agents.output_parsers.openai_tools import OpenAIToolsAgentOutputParser
+import logging
+from contextlib import contextmanager
+from typing import TYPE_CHECKING, Any, AsyncIterable, Dict, Literal, Optional, Union
+
+from langchain.agents import AgentExecutor, create_tool_calling_agent
 from langchain.prompts import MessagesPlaceholder
 from langchain_community.callbacks import get_openai_callback
+from langchain_core.language_models import BaseChatModel
 from langchain_core.messages import AIMessage, HumanMessage
 from langchain_core.prompts import ChatPromptTemplate
-from langchain_ollama import ChatOllama
 from langchain_openai import AzureChatOpenAI, ChatOpenAI
 
+if TYPE_CHECKING:
+    from langchain_anthropic import ChatAnthropic
+    from langchain_ollama import ChatOllama
+
 from .prompts import RobotSystemPrompts, system_prompts
 from .tools import ROSATools
 
-ChatModel = Union[ChatOpenAI, AzureChatOpenAI, ChatOllama]
+logger = logging.getLogger(__name__)
+
+# Tested providers for static analysis; BaseChatModel accepted at runtime.
+if TYPE_CHECKING:
+    ChatModel = Union[ChatOpenAI, AzureChatOpenAI, ChatAnthropic, ChatOllama]
+else:
+    ChatModel = BaseChatModel
 
 
 class ROSA:
@@ -38,7 +48,10 @@ class ROSA:
 
     Args:
         ros_version (Literal[1, 2]): The version of ROS that the agent will interact with.
-        llm (Union[AzureChatOpenAI, ChatOpenAI, ChatOllama]): The language model to use for generating responses.
+        llm (ChatModel): The language model to use for generating responses. Tested providers:
+            ChatOpenAI, AzureChatOpenAI, ChatAnthropic, and ChatOllama. Other BaseChatModel
+            subclasses that support tool calling may work but are not officially tested.
+            Note: token usage tracking is only supported for ChatOpenAI and AzureChatOpenAI.
         tools (Optional[list]): A list of additional LangChain tool functions to use with the agent.
         tool_packages (Optional[list]): A list of Python packages containing LangChain tool functions to use.
         prompts (Optional[RobotSystemPrompts]): Custom prompts to use with the agent.
@@ -95,11 +108,20 @@ def __init__(
             ros_version, packages=tool_packages, tools=tools, blacklist=self.__blacklist
         )
         self.__prompts = self._get_prompts(prompts)
-        self.__llm_with_tools = self.__llm.bind_tools(self.__tools.get_tools())
         self.__agent = self._get_agent()
         self.__executor = self._get_executor(verbose=verbose)
+        # cache this check - no need to do isinstance on every invoke
+        self.__supports_token_tracking = isinstance(llm, (ChatOpenAI, AzureChatOpenAI))
         self.__show_token_usage = show_token_usage if not streaming else False
 
+        if self.__show_token_usage and not self.__supports_token_tracking:
+            logger.warning(
+                "Token usage tracking only works with OpenAI/Azure models, not %s. "
+                "Disabling.",
+                type(llm).__name__,
+            )
+            self.__show_token_usage = False
+
     @property
     def chat_history(self):
         """Get the chat history."""
@@ -131,7 +153,7 @@ def invoke(self, query: str) -> str:
         - Token usage is printed if the show_token_usage flag is set.
         """
         try:
-            with get_openai_callback() as cb:
+            with self._token_callback() as cb:
                 result = self.__executor.invoke(
                     {"input": query, "chat_history": self.__chat_history}
                 )
@@ -246,17 +268,10 @@ def _get_executor(self, verbose: bool) -> AgentExecutor:
 
     def _get_agent(self):
         """Create and return an agent for processing user inputs and generating responses."""
-        agent = (
-            {
-                "input": lambda x: x["input"],
-                "agent_scratchpad": lambda x: format_to_openai_tool_messages(
-                    x["intermediate_steps"]
-                ),
-                "chat_history": lambda x: x["chat_history"],
-            }
-            | self.__prompts
-            | self.__llm_with_tools
-            | OpenAIToolsAgentOutputParser()
+        agent = create_tool_calling_agent(
+            llm=self.__llm,
+            tools=self.__tools.get_tools(),
+            prompt=self.__prompts,
         )
         return agent
 
@@ -296,12 +311,26 @@ def _get_prompts(
         )
         return template
 
+    @contextmanager
+    def _token_callback(self):
+        """Context manager for token usage tracking.
+
+        Uses the OpenAI callback when the LLM is an OpenAI-based model,
+        otherwise yields None so the rest of the flow is unaffected.
+        """
+        if self.__supports_token_tracking:
+            with get_openai_callback() as cb:
+                yield cb
+        else:
+            yield None
+
     def _print_usage(self, cb):
         """Print the token usage if show_token_usage is enabled."""
-        if cb and self.__show_token_usage:
-            print(f"[bold]Prompt Tokens:[/bold] {cb.prompt_tokens}")
-            print(f"[bold]Completion Tokens:[/bold] {cb.completion_tokens}")
-            print(f"[bold]Total Cost (USD):[/bold] ${cb.total_cost}")
+        if cb is None or not self.__show_token_usage:
+            return
+        print(f"[bold]Prompt Tokens:[/bold] {cb.prompt_tokens}")
+        print(f"[bold]Completion Tokens:[/bold] {cb.completion_tokens}")
+        print(f"[bold]Total Cost (USD):[/bold] ${cb.total_cost}")
 
     def _record_chat_history(self, query: str, response: str):
         """Record the chat history if accumulation is enabled."""
diff --git a/src/turtle_agent/scripts/llm.py b/src/turtle_agent/scripts/llm.py
index ffc8952..4092de6 100644
--- a/src/turtle_agent/scripts/llm.py
+++ b/src/turtle_agent/scripts/llm.py
@@ -19,14 +19,55 @@
 
 
 def get_llm(streaming: bool = False):
-    """A helper function to get the LLM instance."""
+    """A helper function to get the LLM instance.
+
+    Supports OpenAI (default), Anthropic and Ollama models.
+    Set the LLM_PROVIDER env variable to switch between providers:
+      - "openai" (default): uses OPENAI_API_KEY
+      - "anthropic": uses ANTHROPIC_API_KEY
+      - "ollama": uses a local Ollama instance
+    """
     dotenv.load_dotenv(dotenv.find_dotenv())
 
-    llm = ChatOpenAI(
-        api_key=get_env_variable("OPENAI_API_KEY"),
-        model="gpt-5.1",
-        streaming=streaming,
-    )
+    provider = os.getenv("LLM_PROVIDER", "openai").lower().strip()
+    supported = ("openai", "anthropic", "ollama")
+    if provider not in supported:
+        raise ValueError(
+            f"Unknown LLM_PROVIDER: '{provider}'. Must be one of: {', '.join(supported)}"
+        )
+
+    if provider == "openai":
+        llm = ChatOpenAI(
+            api_key=get_env_variable("OPENAI_API_KEY"),
+            model=os.getenv("OPENAI_MODEL", "gpt-4o"),
+            streaming=streaming,
+        )
+    elif provider == "anthropic":
+        try:
+            from langchain_anthropic import ChatAnthropic
+        except ImportError as e:
+            raise ImportError(
+                "langchain-anthropic is required for Anthropic support. "
+                "Install it with: pip install langchain-anthropic"
+            ) from e
+        llm = ChatAnthropic(
+            api_key=get_env_variable("ANTHROPIC_API_KEY"),
+            model=os.getenv("ANTHROPIC_MODEL", "claude-sonnet-4-5"),
+            streaming=streaming,
+        )
+    elif provider == "ollama":
+        try:
+            from langchain_ollama import ChatOllama
+        except ImportError as e:
+            raise ImportError(
+                "langchain-ollama is required for Ollama support. "
+                "Install it with: pip install langchain-ollama"
+            ) from e
+        llm = ChatOllama(
+            model=os.getenv("OLLAMA_MODEL", "llama3"),
+            base_url=os.getenv("OLLAMA_BASE_URL", "http://localhost:11434"),
+            streaming=streaming,
+        )
     return llm
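For reference, a minimal usage sketch of the LLM_PROVIDER switch introduced above. It assumes ROS 1, a running local Ollama server, and that the matching optional extra is installed (e.g. pip install -e ".[ollama]"); the flat "from llm import get_llm" import path and the example query are illustrative assumptions, while ROSA, get_llm, invoke, show_token_usage, and the environment variable names come from this patch.

import os

from rosa import ROSA      # agent class patched above
from llm import get_llm    # helper patched above; flat path assumed, as used by the turtle_agent scripts

# Pick the backend before building the model; "openai" remains the default.
os.environ["LLM_PROVIDER"] = "ollama"
os.environ["OLLAMA_MODEL"] = "llama3"

llm = get_llm(streaming=False)

# Any tool-calling chat model is accepted; token usage printing stays active
# only for ChatOpenAI / AzureChatOpenAI and is warned about and disabled otherwise.
agent = ROSA(ros_version=1, llm=llm, show_token_usage=True)
print(agent.invoke("What topics are currently active?"))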