|
4 | 4 |
|
5 | 5 | import pytest |
6 | 6 | from mcp import McpError |
| 7 | +from mcp.client.auth import OAuthClientProvider |
7 | 8 | from pydantic import AnyUrl |
8 | 9 |
|
9 | 10 | from fastmcp.client import Client |
| 11 | +from fastmcp.client.auth.bearer import BearerAuth |
10 | 12 | from fastmcp.client.transports import ( |
11 | 13 | FastMCPTransport, |
12 | 14 | MCPConfigTransport, |
@@ -810,3 +812,73 @@ def test_infer_fastmcp_v1_server(self): |
810 | 812 | server = FastMCP1() |
811 | 813 | transport = infer_transport(server) |
812 | 814 | assert isinstance(transport, FastMCPTransport) |
| 815 | + |
| 816 | + |
| 817 | +class TestAuth: |
| 818 | + def test_default_auth_is_none(self): |
| 819 | + client = Client(transport=StreamableHttpTransport("http://localhost:8000")) |
| 820 | + assert client.transport.auth is None |
| 821 | + |
| 822 | + def test_stdio_doesnt_support_auth(self): |
| 823 | + with pytest.raises(ValueError, match="This transport does not support auth"): |
| 824 | + Client(transport=StdioTransport("echo", ["hello"]), auth="oauth") |
| 825 | + |
| 826 | + def test_oauth_literal_sets_up_oauth_shttp(self): |
| 827 | + client = Client( |
| 828 | + transport=StreamableHttpTransport("http://localhost:8000"), auth="oauth" |
| 829 | + ) |
| 830 | + assert isinstance(client.transport, StreamableHttpTransport) |
| 831 | + assert isinstance(client.transport.auth, OAuthClientProvider) |
| 832 | + |
| 833 | + def test_oauth_literal_pass_direct_to_transport(self): |
| 834 | + client = Client( |
| 835 | + transport=StreamableHttpTransport("http://localhost:8000", auth="oauth"), |
| 836 | + ) |
| 837 | + assert isinstance(client.transport, StreamableHttpTransport) |
| 838 | + assert isinstance(client.transport.auth, OAuthClientProvider) |
| 839 | + |
| 840 | + def test_oauth_literal_sets_up_oauth_sse(self): |
| 841 | + client = Client(transport=SSETransport("http://localhost:8000"), auth="oauth") |
| 842 | + assert isinstance(client.transport, SSETransport) |
| 843 | + assert isinstance(client.transport.auth, OAuthClientProvider) |
| 844 | + |
| 845 | + def test_oauth_literal_pass_direct_to_transport_sse(self): |
| 846 | + client = Client(transport=SSETransport("http://localhost:8000", auth="oauth")) |
| 847 | + assert isinstance(client.transport, SSETransport) |
| 848 | + assert isinstance(client.transport.auth, OAuthClientProvider) |
| 849 | + |
| 850 | + def test_auth_string_sets_up_bearer_auth_shttp(self): |
| 851 | + client = Client( |
| 852 | + transport=StreamableHttpTransport("http://localhost:8000"), |
| 853 | + auth="test_token", |
| 854 | + ) |
| 855 | + assert isinstance(client.transport, StreamableHttpTransport) |
| 856 | + assert isinstance(client.transport.auth, BearerAuth) |
| 857 | + assert client.transport.auth.token.get_secret_value() == "test_token" |
| 858 | + |
| 859 | + def test_auth_string_pass_direct_to_transport_shttp(self): |
| 860 | + client = Client( |
| 861 | + transport=StreamableHttpTransport( |
| 862 | + "http://localhost:8000", auth="test_token" |
| 863 | + ), |
| 864 | + ) |
| 865 | + assert isinstance(client.transport, StreamableHttpTransport) |
| 866 | + assert isinstance(client.transport.auth, BearerAuth) |
| 867 | + assert client.transport.auth.token.get_secret_value() == "test_token" |
| 868 | + |
| 869 | + def test_auth_string_sets_up_bearer_auth_sse(self): |
| 870 | + client = Client( |
| 871 | + transport=SSETransport("http://localhost:8000"), |
| 872 | + auth="test_token", |
| 873 | + ) |
| 874 | + assert isinstance(client.transport, SSETransport) |
| 875 | + assert isinstance(client.transport.auth, BearerAuth) |
| 876 | + assert client.transport.auth.token.get_secret_value() == "test_token" |
| 877 | + |
| 878 | + def test_auth_string_pass_direct_to_transport_sse(self): |
| 879 | + client = Client( |
| 880 | + transport=SSETransport("http://localhost:8000", auth="test_token"), |
| 881 | + ) |
| 882 | + assert isinstance(client.transport, SSETransport) |
| 883 | + assert isinstance(client.transport.auth, BearerAuth) |
| 884 | + assert client.transport.auth.token.get_secret_value() == "test_token" |
0 commit comments