14 changes: 13 additions & 1 deletion .env
@@ -1,4 +1,16 @@
# LLM Provider: "openai" (default), "anthropic", or "ollama"
LLM_PROVIDER=openai

# OpenAI Configuration
OPENAI_API_KEY=
OPENAI_API_VERSION=
OPENAI_API_TYPE=
OPENAI_MODEL=gpt-5.2

# Anthropic Configuration
ANTHROPIC_API_KEY=
ANTHROPIC_MODEL=claude-sonnet-4-5

# Ollama Configuration (local models)
OLLAMA_MODEL=llama3
OLLAMA_BASE_URL=http://localhost:11434
3 changes: 3 additions & 0 deletions pyproject.toml
@@ -39,6 +39,9 @@ dependencies = [
"numpy>=1.26.4",
]

[project.optional-dependencies]
anthropic = ["langchain-anthropic~=0.3.12"]
Comment on lines +41 to +42
Copilot AI Feb 1, 2026

Consider adding langchain-ollama to the optional-dependencies section similar to langchain-anthropic, since it's no longer imported directly in the ROSA core code and is only used in the turtle_agent example with lazy import. This would reduce the installation footprint for users who don't need Ollama support. For example: ollama = ["langchain-ollama~=0.3.2"]


[project.urls]
"Homepage" = "https://github.com/nasa-jpl/rosa"
"Bug Tracker" = "https://github.com/nasa-jpl/rosa/issues"
75 changes: 48 additions & 27 deletions src/rosa/rosa.py
@@ -12,24 +12,31 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, AsyncIterable, Dict, Literal, Optional, Union
from __future__ import annotations

from langchain.agents import AgentExecutor
from langchain.agents.format_scratchpad.openai_tools import (
format_to_openai_tool_messages,
)
from langchain.agents.output_parsers.openai_tools import OpenAIToolsAgentOutputParser
import logging
from contextlib import contextmanager
from typing import TYPE_CHECKING, Any, AsyncIterable, Dict, Literal, Optional

from langchain.agents import AgentExecutor, create_tool_calling_agent
from langchain.prompts import MessagesPlaceholder
from langchain_community.callbacks import get_openai_callback
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.prompts import ChatPromptTemplate
from langchain_ollama import ChatOllama
from langchain_openai import AzureChatOpenAI, ChatOpenAI

if TYPE_CHECKING:
from langchain_anthropic import ChatAnthropic

Comment on lines 29 to 32
Copilot AI Feb 1, 2026

The ChatAnthropic import under TYPE_CHECKING is not used in any type annotations, only in the docstring. Since the code now uses BaseChatModel for all type hints, this import can be removed to reduce dependencies and avoid potential confusion.

Suggested change
if TYPE_CHECKING:
from langchain_anthropic import ChatAnthropic

from .prompts import RobotSystemPrompts, system_prompts
from .tools import ROSATools

ChatModel = Union[ChatOpenAI, AzureChatOpenAI, ChatOllama]
logger = logging.getLogger(__name__)

# Runtime-safe type alias: accepts any BaseChatModel, covering OpenAI, Azure,
# Anthropic, Ollama and any future langchain provider that implements tool calling.
ChatModel = BaseChatModel
Collaborator

The strict Union was removed in favor of a generic BaseChatModel, but get_llm() in llm.py still effectively hard-codes a strict union.

I am in favor of allowing any tool calling LLM, but that means we need the ability to construct a generic chat model without enumerating all possibilities.

Recommend either reverting to the Union with ChatAnthropic included and removing the "any tool calling model" language, or finding a good solution for constructing generic tool-calling chat models. Either fix is fine with me 👍
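
One possible direction for the second option is langchain's init_chat_model helper, which builds a chat model from a provider string without enumerating provider classes. A minimal sketch, assuming a recent langchain release; the model/provider pair below is only illustrative and not part of this PR:

from langchain.chat_models import init_chat_model

# Returns a BaseChatModel for whichever provider package is installed;
# API keys are still read from the usual provider-specific env variables.
llm = init_chat_model("claude-sonnet-4-5", model_provider="anthropic")

Anything constructed this way could then be passed straight to ROSA(llm=...).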



class ROSA:
@@ -38,7 +45,9 @@ class ROSA:

Args:
ros_version (Literal[1, 2]): The version of ROS that the agent will interact with.
llm (Union[AzureChatOpenAI, ChatOpenAI, ChatOllama]): The language model to use for generating responses.
llm (BaseChatModel): Any langchain chat model that supports tool calling. Tested with
ChatOpenAI, AzureChatOpenAI, ChatOllama, and ChatAnthropic. Note: token usage
tracking only works with ChatOpenAI and AzureChatOpenAI.
tools (Optional[list]): A list of additional LangChain tool functions to use with the agent.
tool_packages (Optional[list]): A list of Python packages containing LangChain tool functions to use.
prompts (Optional[RobotSystemPrompts]): Custom prompts to use with the agent.
@@ -95,7 +104,6 @@ def __init__(
ros_version, packages=tool_packages, tools=tools, blacklist=self.__blacklist
)
self.__prompts = self._get_prompts(prompts)
self.__llm_with_tools = self.__llm.bind_tools(self.__tools.get_tools())
self.__agent = self._get_agent()
self.__executor = self._get_executor(verbose=verbose)
self.__show_token_usage = show_token_usage if not streaming else False
@@ -131,7 +139,7 @@ def invoke(self, query: str) -> str:
- Token usage is printed if the show_token_usage flag is set.
"""
try:
with get_openai_callback() as cb:
with self._token_callback() as cb:
result = self.__executor.invoke(
{"input": query, "chat_history": self.__chat_history}
)
@@ -245,18 +253,15 @@ def _get_executor(self, verbose: bool) -> AgentExecutor:
return executor

def _get_agent(self):
"""Create and return an agent for processing user inputs and generating responses."""
agent = (
{
"input": lambda x: x["input"],
"agent_scratchpad": lambda x: format_to_openai_tool_messages(
x["intermediate_steps"]
),
"chat_history": lambda x: x["chat_history"],
}
| self.__prompts
| self.__llm_with_tools
| OpenAIToolsAgentOutputParser()
"""Create and return an agent for processing user inputs and generating responses.

Uses create_tool_calling_agent which is provider-agnostic and works with
any LLM that supports tool calling (OpenAI, Anthropic, Ollama, etc).
"""
agent = create_tool_calling_agent(
llm=self.__llm,
tools=self.__tools.get_tools(),
prompt=self.__prompts,
)
return agent
Comment on lines 269 to 276
Copilot AI Feb 1, 2026

The refactoring from OpenAI-specific agent pipeline to create_tool_calling_agent is a significant architectural change, but there are no tests to verify that it works correctly with different providers. Consider adding integration tests that validate the agent works with mock ChatOpenAI, ChatAnthropic, and ChatOllama instances to ensure the provider-agnostic implementation behaves correctly across different LLM backends.
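
A hedged sketch of what such a test could look like; the mock-based model and the echo tool below are illustrative stand-ins, not existing fixtures in this repo:

from unittest.mock import MagicMock

from langchain.agents import create_tool_calling_agent
from langchain_core.language_models import BaseChatModel
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.tools import tool


@tool
def echo(text: str) -> str:
    """Trivial tool so the agent has something to bind."""
    return text


def test_agent_builds_for_any_tool_calling_model():
    # Stand-in for ChatOpenAI / ChatAnthropic / ChatOllama: anything that
    # exposes bind_tools should be accepted by create_tool_calling_agent.
    llm = MagicMock(spec=BaseChatModel)
    llm.bind_tools.return_value = llm

    prompt = ChatPromptTemplate.from_messages(
        [
            ("system", "test"),
            MessagesPlaceholder("chat_history"),
            ("human", "{input}"),
            MessagesPlaceholder("agent_scratchpad"),
        ]
    )
    agent = create_tool_calling_agent(llm=llm, tools=[echo], prompt=prompt)

    assert agent is not None
    llm.bind_tools.assert_called_once()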


@@ -296,12 +301,28 @@ def _get_prompts(
)
return template

@contextmanager
def _token_callback(self):
"""Context manager for token usage tracking.

Uses the OpenAI callback when the LLM is an OpenAI-based model,
otherwise yields None so the rest of the flow is unaffected.
"""
if isinstance(self.__llm, (ChatOpenAI, AzureChatOpenAI)):
with get_openai_callback() as cb:
yield cb
else:
if self.__show_token_usage:
logger.warning("Token usage tracking is only supported for OpenAI and Azure models.")
Copilot AI Feb 1, 2026

The warning is logged every time the context manager is entered, so a user who repeatedly calls invoke() with show_token_usage=True on a non-OpenAI model will see it on every call. Consider adding a flag to log this warning only once, or document this limitation more prominently in the constructor's docstring to set user expectations upfront.
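
One hedged way to do that, assuming a new private flag initialized to False in __init__ (the attribute name is illustrative):

if self.__show_token_usage and not self.__warned_token_usage:
    logger.warning(
        "Token usage tracking is only supported for OpenAI and Azure models."
    )
    self.__warned_token_usage = True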

Copilot AI Feb 1, 2026

The _token_callback method uses isinstance() to check if the LLM is OpenAI-based, but this check happens at runtime on every invocation. Since the LLM type doesn't change after initialization, consider checking this once in __init__ and storing a boolean flag (e.g., self.__supports_token_tracking) to avoid repeated isinstance calls. This would be more efficient and clearer.

Suggested change
if isinstance(self.__llm, (ChatOpenAI, AzureChatOpenAI)):
with get_openai_callback() as cb:
yield cb
else:
if self.__show_token_usage:
logger.warning("Token usage tracking is only supported for OpenAI and Azure models.")
# Lazily determine whether the current LLM supports OpenAI-style token tracking.
if not hasattr(self, "_ROSA__supports_token_tracking"):
self.__supports_token_tracking = isinstance(
self.__llm, (ChatOpenAI, AzureChatOpenAI)
)
if self.__supports_token_tracking:
with get_openai_callback() as cb:
yield cb
else:
if self.__show_token_usage:
logger.warning(
"Token usage tracking is only supported for OpenAI and Azure models."
)

yield None

def _print_usage(self, cb):
"""Print the token usage if show_token_usage is enabled."""
if cb and self.__show_token_usage:
print(f"[bold]Prompt Tokens:[/bold] {cb.prompt_tokens}")
print(f"[bold]Completion Tokens:[/bold] {cb.completion_tokens}")
print(f"[bold]Total Cost (USD):[/bold] ${cb.total_cost}")
if cb is None or not self.__show_token_usage:
return
print(f"[bold]Prompt Tokens:[/bold] {cb.prompt_tokens}")
print(f"[bold]Completion Tokens:[/bold] {cb.completion_tokens}")
print(f"[bold]Total Cost (USD):[/bold] ${cb.total_cost}")

def _record_chat_history(self, query: str, response: str):
"""Record the chat history if accumulation is enabled."""
53 changes: 47 additions & 6 deletions src/turtle_agent/scripts/llm.py
Collaborator

I like the LLM_PROVIDER env var idea, but I don't think it should default to OpenAI if it doesn't match anthropic or ollama. If we are going to hard-code these names, prefer also hard-coding openai in the if/else and failing if there is no match to avoid confusion (error message should include the supported options).

@@ -19,14 +19,55 @@


def get_llm(streaming: bool = False):
"""A helper function to get the LLM instance."""
"""A helper function to get the LLM instance.
Supports OpenAI (default), Anthropic and Ollama models.
Set the LLM_PROVIDER env variable to switch between providers:
- "openai" (default): uses OPENAI_API_KEY
- "anthropic": uses ANTHROPIC_API_KEY
- "ollama": uses local Ollama instance
"""
dotenv.load_dotenv(dotenv.find_dotenv())

llm = ChatOpenAI(
api_key=get_env_variable("OPENAI_API_KEY"),
model="gpt-5.1",
streaming=streaming,
)
provider = os.getenv("LLM_PROVIDER", "openai").lower()
Collaborator

Please use the get_env_variable function and remove defaults (likely to cause confusion if this doesn't fail early with explicit error messages).
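
A minimal sketch of what this would look like, assuming get_env_variable (already used elsewhere in this file) raises with an explicit message when a variable is unset, so misconfiguration fails early instead of falling back silently:

provider = get_env_variable("LLM_PROVIDER").lower()

# and, in the Ollama branch, no hardcoded fallbacks either:
llm = ChatOllama(
    model=get_env_variable("OLLAMA_MODEL"),
    base_url=get_env_variable("OLLAMA_BASE_URL"),
    streaming=streaming,
)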


if provider == "openai":
llm = ChatOpenAI(
api_key=get_env_variable("OPENAI_API_KEY"),
model=get_env_variable("OPENAI_MODEL"),
Copilot AI Feb 1, 2026
The PR description claims this change is "backwards compatible" and that "Nothing breaks," but the change from a hardcoded model to requiring the OPENAI_MODEL environment variable is a breaking change for turtle_agent users. Existing users who upgrade will encounter errors if they don't update their .env file. This should be documented as a breaking change in the PR description or a migration guide should be provided.
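
If backwards compatibility is the goal, one hedged option (in tension with the review comment above about removing defaults) is to fall back to the previously hardcoded model when OPENAI_MODEL is unset:

llm = ChatOpenAI(
    api_key=get_env_variable("OPENAI_API_KEY"),
    model=os.getenv("OPENAI_MODEL", "gpt-5.1"),  # previous hardcoded default
    streaming=streaming,
)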

streaming=streaming,
)
elif provider == "anthropic":
try:
from langchain_anthropic import ChatAnthropic
except ImportError:
raise ImportError(
"langchain-anthropic is required for Anthropic support. "
"Install it with: pip install langchain-anthropic"
)
llm = ChatAnthropic(
api_key=get_env_variable("ANTHROPIC_API_KEY"),
model=get_env_variable("ANTHROPIC_MODEL"),
streaming=streaming,
)
elif provider == "ollama":
try:
from langchain_ollama import ChatOllama
except ImportError:
raise ImportError(
"langchain-ollama is required for Ollama support. "
"Install it with: pip install langchain-ollama"
)
llm = ChatOllama(
model=os.getenv("OLLAMA_MODEL", "llama3"),
base_url=os.getenv("OLLAMA_BASE_URL", "http://localhost:11434"),
Copilot AI Jan 31, 2026
The streaming parameter is not passed to ChatOllama, unlike ChatOpenAI and ChatAnthropic. This inconsistency means that streaming won't work properly when using Ollama. Add the streaming parameter to maintain consistent behavior across all providers.

Suggested change
base_url=os.getenv("OLLAMA_BASE_URL", "http://localhost:11434"),
base_url=os.getenv("OLLAMA_BASE_URL", "http://localhost:11434"),
streaming=streaming,

streaming=streaming,
)
else:
raise ValueError(
f"Unknown LLM provider: '{provider}'. "
"Supported providers are: 'openai', 'anthropic', 'ollama'."
)

return llm
