1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110import os
import pytest
import sqlalchemy
from starlette.applications import Starlette
from starlette.responses import JSONResponse
from starlette.testclient import TestClient
from databases import Database, DatabaseURL
# Comma-separated list of database URLs the integration test runs against.
# An explicit check (rather than ``assert``) keeps the message even when
# Python runs with -O, where asserts are stripped.
if "TEST_DATABASE_URLS" not in os.environ:
    raise RuntimeError("TEST_DATABASE_URLS is not set.")

# Strip whitespace and skip blank entries so a trailing comma (e.g. "a,b,")
# does not produce an empty URL that fails later with an obscure parse error.
DATABASE_URLS = [
    url
    for url in (part.strip() for part in os.environ["TEST_DATABASE_URLS"].split(","))
    if url
]
if not DATABASE_URLS:
    # An empty parametrize list would silently skip the test; fail instead.
    raise RuntimeError("TEST_DATABASE_URLS does not contain any database URLs.")

# Schema shared by the fixture (DDL) and the Starlette app (queries).
metadata = sqlalchemy.MetaData()

notes = sqlalchemy.Table(
    "notes",
    metadata,
    sqlalchemy.Column("id", sqlalchemy.Integer, primary_key=True),
    sqlalchemy.Column("text", sqlalchemy.String(length=100)),
    sqlalchemy.Column("completed", sqlalchemy.Boolean),
)
def _sync_database_url(url):
    """Return a URL for *url* that a blocking SQLAlchemy engine can use.

    The DDL below runs on ``sqlalchemy.create_engine`` (a synchronous engine),
    so async drivers are swapped out:

    * ``mysql`` / ``mysql+aiomysql`` -> the ``pymysql`` driver;
    * ``postgresql+aiopg`` / ``sqlite+aiosqlite`` / ``postgresql+asyncpg`` ->
      the driver suffix is removed, leaving SQLAlchemy's default driver for
      that dialect.

    Any other URL is returned unchanged.
    """
    database_url = DatabaseURL(url)
    if database_url.scheme in ("mysql", "mysql+aiomysql"):
        return str(database_url.replace(driver="pymysql"))
    if database_url.scheme in (
        "postgresql+aiopg",
        "sqlite+aiosqlite",
        "postgresql+asyncpg",
    ):
        return str(database_url.replace(driver=None))
    return url


def _for_each_sync_engine(operation):
    """Call ``operation(engine)`` once per entry in ``DATABASE_URLS``.

    A fresh synchronous engine is built for each URL and always disposed
    afterwards, so pooled connections opened for the DDL do not outlive it.
    """
    for url in DATABASE_URLS:
        engine = sqlalchemy.create_engine(_sync_database_url(url))
        try:
            operation(engine)
        finally:
            engine.dispose()


@pytest.fixture(autouse=True, scope="module")
def create_test_database():
    """Create the ``notes`` schema in every test database around this module.

    Tables are created before the module's tests run and dropped after they
    finish. pytest runs the code after ``yield`` as teardown even when tests
    fail.
    """
    _for_each_sync_engine(metadata.create_all)
    yield  # run the test suite
    _for_each_sync_engine(metadata.drop_all)
def get_app(database_url):
    """Build a Starlette app serving ``GET /notes`` and ``POST /notes``.

    The app talks to *database_url* through a ``Database`` created with
    ``force_rollback=True``. The database is connected on application startup
    and disconnected on shutdown.
    """
    database = Database(database_url, force_rollback=True)
    app = Starlette()

    # Tie the database connection to the application lifespan.
    app.add_event_handler("startup", database.connect)
    app.add_event_handler("shutdown", database.disconnect)

    async def list_notes(request):
        rows = await database.fetch_all(notes.select())
        payload = [
            {"text": row["text"], "completed": row["completed"]} for row in rows
        ]
        return JSONResponse(payload)

    async def add_note(request):
        body = await request.json()
        text, completed = body["text"], body["completed"]
        await database.execute(notes.insert().values(text=text, completed=completed))
        return JSONResponse({"text": text, "completed": completed})

    # Handler names are kept so the default route names stay the same.
    app.add_route("/notes", list_notes, methods=["GET"])
    app.add_route("/notes", add_note, methods=["POST"])

    return app
@pytest.mark.parametrize("database_url", DATABASE_URLS)
def test_integration(database_url):
    """Round-trip a note through the app, then check a new session sees none."""
    note = {"text": "example", "completed": True}
    app = get_app(database_url)

    with TestClient(app) as client:
        created = client.post("/notes", json=note)
        assert created.status_code == 200
        assert created.json() == note

        listed = client.get("/notes")
        assert listed.status_code == 200
        assert listed.json() == [note]

    # Ensure sessions are isolated: get_app uses force_rollback=True, so the
    # note written above must not be visible after the app restarts.
    with TestClient(app) as client:
        listed = client.get("/notes")
        assert listed.status_code == 200
        assert listed.json() == []