diff --git a/runner/gen_openapi.py b/runner/gen_openapi.py index b2934159..e5e94707 100644 --- a/runner/gen_openapi.py +++ b/runner/gen_openapi.py @@ -1,22 +1,64 @@ -from fastapi.openapi.utils import get_openapi -from app.main import app, use_route_names_as_operation_ids -from app.routes import health, text_to_image, image_to_image, image_to_video +import argparse import json + import yaml -import argparse +from app.main import app, use_route_names_as_operation_ids +from app.routes import health, image_to_image, image_to_video, text_to_image +from fastapi.openapi.utils import get_openapi # Specify Endpoints for OpenAPI schema generation. SERVERS = [ - {"url": "http://gateway-endpoint.ai/", "description": "Example Gateway"}, + { + "url": "https://dream-gateway.livepeer.cloud", + "description": "Livepeer Cloud Community Gateway", + }, ] -def write_openapi(fname): +def translate_to_gateway(openapi): + """Translate the OpenAPI schema from the 'runner' entrypoint to the 'gateway' + entrypoint created by the https://github.com/livepeer/go-livepeer package. + + .. note:: + Differences between 'runner' and 'gateway' entrypoints: + - 'health' endpoint is removed. + - 'model_id' is enforced in all endpoints. + + Args: + openapi (dict): The OpenAPI schema to be translated. + + Returns: + dict: The translated OpenAPI schema. + """ + # Remove 'health' endpoint + openapi["paths"].pop("/health") + + # Enforce 'model_id' in all endpoints + for _, methods in openapi["paths"].items(): + for _, details in methods.items(): + if "requestBody" in details: + for _, content_details in details["requestBody"]["content"].items(): + if ( + "schema" in content_details + and "$ref" in content_details["schema"] + ): + ref = content_details["schema"]["$ref"] + schema_name = ref.split("/")[-1] + schema = openapi["components"]["schemas"][schema_name] + if "model_id" in schema["properties"]: + schema["required"].append("model_id") + + return openapi + + +def write_openapi(fname, entrypoint="runner"): """Write OpenAPI schema to file. Args: fname (str): The file name to write to. The file extension determines the file type. Either 'json' or 'yaml'. + entrypoint (str): The entrypoint to generate the OpenAPI schema for, either + 'gateway' or 'runner'. Default is 'runner'. """ app.include_router(health.router) app.include_router(text_to_image.router) @@ -25,32 +67,34 @@ def write_openapi(fname): use_route_names_as_operation_ids(app) + print(f"Generating OpenAPI schema for '{entrypoint}' entrypoint...") + openapi = get_openapi( + title="Livepeer AI Runner", + version="0.1.0", + openapi_version=app.openapi_version, + description="An application to run AI pipelines", + routes=app.routes, + servers=SERVERS, + ) + + # Translate OpenAPI schema to 'gateway' side entrypoint if requested. + if entrypoint == "gateway": + print("Translating OpenAPI schema from 'runner' to 'gateway' entrypoint...") + openapi = translate_to_gateway(openapi) + fname = fname.replace(".yaml", "_gateway.yaml") + # Write OpenAPI schema to file. with open(fname, "w") as f: print(f"Writing OpenAPI schema to '{fname}'...") if fname.endswith(".yaml"): yaml.dump( - get_openapi( - title="Livepeer AI Runner", - version="0.1.0", - openapi_version=app.openapi_version, - description="An application to run AI pipelines", - routes=app.routes, - servers=SERVERS, - ), + openapi, f, sort_keys=False, ) else: json.dump( - get_openapi( - title="Livepeer AI Runner", - version="0.1.0", - openapi_version=app.openapi_version, - description="An application to run AI pipelines", - routes=app.routes, - servers=SERVERS, - ), + openapi, f, indent=4, # Make human readable. ) @@ -66,6 +110,16 @@ def write_openapi(fname): default="json", help="File type to write to, either 'json' or 'yaml'. Default is 'json'", ) + parser.add_argument( + "--entrypoint", + type=str, + choices=["gateway", "runner"], + default="runner", + help=( + "The entrypoint to generate the OpenAPI schema for, either 'gateway' or " + "'runner'. Default is 'runner'", + ), + ) args = parser.parse_args() - write_openapi(f"openapi.{args.type.lower()}") + write_openapi(f"openapi.{args.type.lower()}", args.entrypoint) diff --git a/worker/runner.gen.go b/worker/runner.gen.go index 9e2e74b3..5a704de4 100644 --- a/worker/runner.gen.go +++ b/worker/runner.gen.go @@ -1077,25 +1077,26 @@ func HandlerWithOptions(si ServerInterface, options ChiServerOptions) http.Handl // Base64 encoded, gzipped, json marshaled Swagger object var swaggerSpec = []string{ - "H4sIAAAAAAAC/+xXS2/jNhD+KwTbozd23E1T+Jb0tUabbhC720MQGIw0lrkrkSw5TNcI/N8LkrZEvSqn", - "yKZAkZNew5lvZr556JEmslBSgEBDZ4/UJBsomL+9uJ7/qLXU7l5pqUAjB/+lMJm7IMcc6IxemYyOKG6V", - "ezCoucjobjeiGv60XENKZ7f+yN2oPFLqLs/J+4+QIN2N6KVMtytesAxWKPc3jUclDbZhZZanTCSwMglz", - "Vh5pCmtmc6Sz85OzyvjPezmy8HIlBGGLe9AOgrfiFKylLhjSGb3nguktrZTMvUjL7REtZAr5iqc1+zQ6", - "eeUEyDztOiwgY8gfYKW0LBT26vhtL0eug1yXKluEaJmVAt2l8DTSZwviPTLkGnRLKxcIWQhNpedwth+C", - "AUhjyYV77lJqUIPIcFODNzn5rgK4OEi0stUgmjqgCTmMOHcsrwYp+cBTkM3HbkqulanzsILzkzKdsdgA", - "zzb1RJ2df1udexe+dx39z2hbSORSrO5t8gmwqeR0eh5rcZLk0kvWtEV+CMkNrJjNVj3EmEwj6jphcmEz", - "0s+RJ1DxL542zJ1Opm8rc3/47+2TDRoOsK+fQh3se7dcXvd04hSQ8dzdfa1hTWf0q3HVz8f7Zj4uu20T", - "5f54BLOy1QPkA8t5ylwSByFxhMIMYWvq21VYfgiaSiBMa7b1PsRomwq6cAPLcfP9BpJPbbwGGdp6ldL3", - "v9C49XiBrglX1WRloMO+L7obMEoKA20EoUsfHbErSDmL4xQad1ecWow0ca7rsDpwB0vtiB1bS1bnsdzv", - "Oh/cE6yX8RYipAFIB8IlfMal9I5cM81C8L7UVlB15iN68esa8PQ1oOy9T2y2ezARYdq86CDPYCvLZVKr", - "Sia279d0dvvY8vGxBfEuKtBfZeLNtEp01FqlwZieAR1eVKIeM1m6t0NF5fwIpvaSUaSOaJ8f3HTqb19r", - "zYpG+3piH2vEpNyQguKBvrY3H7tUw9tyyDMysZrjduGgBOxulFwC06DL3yB36D68KpVsEBXdOR1crGWo", - "I5Nornx+Z/RCEKZUzkPCCUqirSAXc6K4gpyL4M+BF/wBFIB232+sEN7QA2gTdE1OTk8mLiBSgWCK0xn9", - "xr8aUcVw42GPN370+EYHvh5darzxeVpOJupCFuLhT00nE3dJpEAQ/lQEevzROPOHf8GhNMazzwemHpCF", - "TRIwZm1zUqbEp8AWhVtNS4ju5dh3qjco35Sr7GGvrrvlK3tf4DTwAQy6HavhV2Fz5IppHLud+E3KkB3v", - "2rF/DLs6J1Fb2H3BiNfn9rExH9G3z5n1ck/ssH/JUnITUuLtTqfPare1MrYRVCKkXCvPXsr9uUDQguVk", - "AfoBNKl270Pf8TMk7ji3d7u7uCZ8islShmncqA3/tzBYG74LvlRt9P/PvHBt1Hv/a238n2sjMNzXBsJn", - "PGJsRGvhP1bGv3e+vXi+DofXAnjeAnAci2fDbvd3AAAA//92bsyPxhcAAA==", + "H4sIAAAAAAAC/+xX3W/bNhD/VwRuj07seM0y+C3ZutbYsgax1z0EgcFIZ5mtRHLkKY0R+H8feLQl6sOT", + "M6QZUORJEnUfv/s+PrJY5VpJkGjZ5JHZeAU5p9fzq+lbY5Rx79ooDQYF0J/cpu6BAjNgE3ZpUzZguNbu", + "w6IRMmWbzYAZ+LsQBhI2uSGW20HJUsou+dTdJ4iRbQbsQiXrhch5CgtU25fGp1YW27DSQiRcxrCwMXda", + "HlkCS15kyCZnx6eV8ndbumhGdCUEWeR3YBwE0uIELJXJObIJuxOSmzWrhEyJpGX2gOUqgWwhkpp+FnBe", + "OoJomnQxS0g5intYaKNyjXtl/LGli648XZeoIvfesgsNpkvgSSCvyCOyyEZXYFpShURIvWsqOTve/RAs", + "QBJSztx3l1CLBmSKqxq80fFPFcDZjqIVrUai6R0aH8Mg5w7Nq96UvBcJqOZnd0outa3Z9GMF51dtO32x", + "ApGu6oE6PQv43vv/Xaz/W9rmCoWSi7si/gzYFHIyPgulOMrogihr0gI7pBIWFrxIF3sSYzQOUtcRR+dF", + "Gu3PkSek4heRNNSdjMZvKnV/0f82ZyMNe7Jvfwp1ZN/7+fxqTydOALnI3Nv3BpZswr4bVv18uG3mw7Lb", + "NlFu2QOYla49QD7yTCTcBbEXkkDIbR+2prxNheUXL6kEwo3ha7IhRNsU0IUbeIarn1cQf27jtcixqFcp", + "+/AbC1sPEXRNuKomKwUd+qnorsFqJS20EfgufbDHLiERPPSTb9xdfmplpA1jXYfVgdtranvs0FoqTBbS", + "/Wmy3j2hIBrSECD1QDoQzuEB54oMueKGe+d9ra2g6swH9OLXNeDpa0DZe5/YbLdggoRp50VH8vS2skzF", + "tarkcv1hySY3jy0bH1sQb4MC/V3FpKZVooPWKg3W7hnQ/qAiJczR3J32FZWzw6vaUgaeOqB9fnTTaX/7", + "WhqeN9rXE/tYwyflhuQF9/S1rfrQpBrelkGUkXFhBK5nDorH7kbJBXADprwGOaY7f1QKWSFqtnEyhFwq", + "X0c2NkJTfCfsXEZc60z4gEeoIlPI6HwaaaEhE9Lbs8sLcQ8awLj/14WUpOgejPWyRscnxyPnEKVBci3Y", + "hP1ARwOmOa4I9nBFo4caHVA9utCQ8mlSTibmXOb9QVzj0cg9YiURJHEFoIefrFO/uwv2hTGcfeSYukNm", + "RRyDtcsii8qQUAiKPHeraQnRHQ6pUx2hOipX2d1eXTeLKntb4MznA1h0O1bDrrzIUGhucOh24qOEIz/c", + "tENvDJt6TqIpYPMVPV6f24f6fMDePGfUyz2xQ/8FT6JrHxLSOx4/q97WythGUJFE5Vp5+lLmTyWCkTyL", + "ZmDuwUTV7r3rOzRDwo5zc7u5DWuCQhzNlZ/Gjdqg20JvbVAXfKna2H+feeHaqPf+19r4lmvDZzjVBsID", + "HjA2grXwXyvjvxvfXjxfh8NrATxvAbgcC2cD8Tphlljr+t4+8FxnEL3jCF/4mm1v5bTITobD1B8fgUy0", + "EhKPuRi6y8s/AQAA///zRBgJGBgAAA==", } // GetSwagger returns the content of the embedded swagger specification file