Library for using pose models created with Teachable Machine.
Teachable Machine provides one link related to your model:

https://teachablemachine.withgoogle.com/models/MODEL_ID/

which you can use to access:
- The model topology:
https://teachablemachine.withgoogle.com/models/MODEL_ID/model.json
- The model metadata:
https://teachablemachine.withgoogle.com/models/MODEL_ID/metadata.json
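For example, you can derive both file URLs from the base link in code (a minimal sketch; MODEL_ID is the placeholder from your export panel):

```js
// base URL copied from the Teachable Machine export panel
const baseURL = 'https://teachablemachine.withgoogle.com/models/MODEL_ID/';
const modelURL = baseURL + 'model.json';       // model topology
const metadataURL = baseURL + 'metadata.json'; // class labels and extra info
```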
There are two ways to easily use the model provided by Teachable Machine in your JavaScript project: by using this library via script tags, or by installing it from NPM and using a build tool like Parcel, webpack, or Rollup.
<script src="https://cdn.jsdelivr.net/npm/@tensorflow/[email protected]/dist/tf.min.js"></script>
<script src="https://cdn.jsdelivr.net/npm/@teachablemachine/[email protected]/dist/teachablemachine-pose.min.js"></script>
Via NPM:

```sh
npm i @tensorflow/tfjs
npm i @teachablemachine/pose
```

```js
import * as tf from '@tensorflow/tfjs';
import * as tmPose from '@teachablemachine/pose';
```
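With the NPM route, loading then works the same as in the script-tag snippet below; a minimal sketch, assuming your bundler resolves the packages and using a placeholder model URL:

```js
import '@tensorflow/tfjs'; // peer dependency of the pose library
import * as tmPose from '@teachablemachine/pose';

const baseURL = 'https://teachablemachine.withgoogle.com/models/MODEL_ID/';

async function main() {
  // tmPose.load() is documented in the API section below
  const model = await tmPose.load(baseURL + 'model.json', baseURL + 'metadata.json');
  console.log('Loaded model with', model.getTotalClasses(), 'classes');
}

main();
```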
Sample snippet:

```html
<div>Teachable Machine Pose Model</div>
<button type='button' onclick='init()'>Start</button>
<div><canvas id='canvas'></canvas></div>
<div id='label-container'></div>
<script src="https://cdn.jsdelivr.net/npm/@tensorflow/[email protected]/dist/tf.min.js"></script>
<script src="https://cdn.jsdelivr.net/npm/@teachablemachine/[email protected]/dist/teachablemachine-pose.min.js"></script>
<script type="text/javascript">
// More API functions here:
// https://github.com/googlecreativelab/teachablemachine-community/tree/master/libraries/pose
// the link to your model provided by Teachable Machine export panel
const URL = '{{URL}}';
let model, webcam, ctx, labelContainer, maxPredictions;
async function init() {
const modelURL = URL + 'model.json';
const metadataURL = URL + 'metadata.json';
// load the model and metadata
// Refer to tmPose.loadFromFiles() in the API to support files from a file picker
model = await tmPose.load(modelURL, metadataURL);
maxPredictions = model.getTotalClasses();
// Convenience function to setup a webcam
const flip = true; // whether to flip the webcam
webcam = new tmPose.Webcam(200, 200, flip); // width, height, flip
await webcam.setup(); // request access to the webcam
webcam.play();
window.requestAnimationFrame(loop);
// append/get elements to the DOM
const canvas = document.getElementById('canvas');
canvas.width = 200; canvas.height = 200;
ctx = canvas.getContext('2d');
labelContainer = document.getElementById('label-container');
for (let i = 0; i < maxPredictions; i++) { // and class labels
labelContainer.appendChild(document.createElement('div'));
}
}
async function loop(timestamp) {
webcam.update(); // update the webcam frame
await predict();
window.requestAnimationFrame(loop);
}
async function predict() {
// Prediction #1: run input through posenet
// estimatePose can take in an image, video or canvas html element
const { pose, posenetOutput } = await model.estimatePose(webcam.canvas);
// Prediction 2: run input through teachable machine classification model
const prediction = await model.predict(posenetOutput);
for (let i = 0; i < maxPredictions; i++) {
const classPrediction =
prediction[i].className + ': ' + prediction[i].probability.toFixed(2);
labelContainer.childNodes[i].innerHTML = classPrediction;
}
// finally draw the poses
drawPose(pose);
}
function drawPose(pose) {
ctx.drawImage(webcam.canvas, 0, 0);
// draw the keypoints and skeleton
if (pose) {
const minPartConfidence = 0.5;
tmPose.drawKeypoints(pose.keypoints, minPartConfidence, ctx);
tmPose.drawSkeleton(pose.keypoints, minPartConfidence, ctx);
}
}
</script>
```
`tmPose` is the module name, which is automatically included when you use the `<script src>` method. It gets added as an object to your window, so you can access it via `window.tmPose` or simply `tmPose`.
```ts
tmPose.load(
  checkpoint: string,
  metadata?: string | Metadata
)
```
Args:
- checkpoint: a URL to a json file that contains the model topology and a reference to a bin file (model weights)
- metadata: a URL to a json file that contains the text labels of your model and additional information
Usage:
```js
await tmPose.load(checkpointURL, metadataURL);
```
You can upload your model files from a local hard drive by using a file picker and the File interface.
```ts
tmPose.loadFromFiles(
  model: File,
  weights: File,
  metadata: File
)
```
Args:
- model: a File object that contains the model topology (.json)
- weights: a File object with the model weights (.bin)
- metadata: a File object that contains the text labels of your model and additional information (.json)
Usage:
```js
// you need to create File objects, like with file input elements (<input type="file" ...>)
const uploadModel = document.getElementById('upload-model');
const uploadWeights = document.getElementById('upload-weights');
const uploadMetadata = document.getElementById('upload-metadata');
model = await tmPose.loadFromFiles(uploadModel.files[0], uploadWeights.files[0], uploadMetadata.files[0]);
```
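The element ids in this snippet ('upload-model' and so on) are illustrative; a matching markup sketch could look like:

```html
<!-- hypothetical file pickers matching the ids used above -->
<input type="file" id="upload-model" accept=".json" />
<input type="file" id="upload-weights" accept=".bin" />
<input type="file" id="upload-metadata" accept=".json" />
```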
Once you have loaded a model, you can obtain the total number of classes in the model.
This method exists on the model loaded from `tmPose.load`.

```ts
model.getTotalClasses()
```

Returns a number representing the total number of classes.
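For example (assuming `model` was loaded with `tmPose.load` as above):

```js
const maxPredictions = model.getTotalClasses();
console.log('The model has', maxPredictions, 'classes');
```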
You'll have to run your input through two models to make a prediction: first through posenet and then through the classification model created via Teachable Machine.
This method exists on the model loaded from `tmPose.load`.

```ts
model.estimatePose(
  sample: ImageData | HTMLImageElement | HTMLCanvasElement | HTMLVideoElement | tf.Tensor3D,
  flipHorizontal = false
)
```
Args:
- sample: an image, canvas, or video to pass through posenet
- flipHorizontal: a boolean indicating whether to flip the pose keypoints horizontally (on X)
Usage:
```js
const flipHorizontal = false;
const { pose, posenetOutput } = await model.estimatePose(webcamElement, flipHorizontal);
```
The function returns `pose`, an object with the keypoints data (for drawing), and `posenetOutput`, a Float32Array of concatenated posenet output data (for the classification prediction).
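For example, to inspect the returned keypoints; this sketch assumes the standard PoseNet keypoint shape (`part`, `score`, and `position` with `x`/`y`):

```js
const { pose, posenetOutput } = await model.estimatePose(webcamElement);
if (pose) {
  for (const k of pose.keypoints) {
    // each keypoint carries a body part name, a confidence score, and x/y coordinates
    console.log(k.part, k.score.toFixed(2), Math.round(k.position.x), Math.round(k.position.y));
  }
}
```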
Once you have the output from posenet, you can make a classification with the Teachable Machine model you trained.
This method exists on the model loaded from `tmPose.load`.

```ts
model.predict(
  poseOutput: Float32Array
)
```
Args:
- poseOutput: an array representing the output of posenet from the `model.estimatePose` function
Usage:
```js
// run the input through posenet first, then classify the concatenated
// posenet output with the Teachable Machine model
const flipHorizontal = false;
const { pose, posenetOutput } = await model.estimatePose(webcamElement, flipHorizontal);
const prediction = await model.predict(posenetOutput);
```
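The result is an array of `{ className, probability }` entries, one per class, as used in the sample snippet above. For example, to pick the most likely class:

```js
// reduce the predictions to the entry with the highest probability
const best = prediction.reduce((a, b) => (a.probability > b.probability ? a : b));
console.log('Best match:', best.className, best.probability.toFixed(2));
```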
An alternative to `predict()` which returns probabilities for all classes. This method exists on the model loaded from `tmPose.load`.
```ts
model.predictTopK(
  poseOutput: Float32Array,
  maxPredictions = 10
)
```
Args:
- poseOutput: an array representing the output of posenet from the `model.estimatePose` function
- maxPredictions: total number of predictions to return
Usage:
```js
// as with predict(), run the input through posenet first, then classify
// the posenet output; here we request one prediction per class
const maxPredictions = model.getTotalClasses();
const flipHorizontal = false;
const { pose, posenetOutput } = await model.estimatePose(webcamElement, flipHorizontal);
const prediction = await model.predictTopK(posenetOutput, maxPredictions);
```
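For example, to keep only the two most likely classes (assuming, as the name suggests, that results come back ordered from most to least probable):

```js
const top2 = await model.predictTopK(posenetOutput, 2);
top2.forEach((p) => console.log(p.className, p.probability.toFixed(2)));
```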
You can optionally use a webcam class that comes with the library, or spin up your own webcam. This class exists on the `tmPose` module. Please note that the default webcam used in Teachable Machine was flipped on X, so you should probably set `flip = true` when creating your own webcam, unless you flipped it manually in Teachable Machine.
```ts
new tmPose.Webcam(
  width = 400,
  height = 400,
  flip = false
)
```
Args:
- width: width of the webcam. It should ideally be square since that's how the model was trained with Teachable Machine.
- height: height of the webcam. It should ideally be square since that's how the model was trained with Teachable Machine.
- flip: boolean to signal whether the webcam should be flipped on X. Please note this flip is applied only in CSS.
Usage:
```js
// webcam has a square ratio and is flipped by default to match training
const webcam = new tmPose.Webcam(200, 200, true);
await webcam.setup();
webcam.play();
document.body.appendChild(webcam.canvas);
```
After creating a Webcam object, you need to call setup just once to initialize it.
```ts
webcam.setup(
  options: MediaTrackConstraints = {}
)
```
Args:
- options: optional media track constraints for the webcam
Usage:

```js
await webcam.setup();
```
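Since the options are standard MediaTrackConstraints, you can, for example, request a particular camera; `facingMode` is a standard browser constraint, not something specific to this library:

```js
// prefer the front-facing camera on devices with more than one
await webcam.setup({ facingMode: 'user' });
```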
The webcam object also provides playback controls. `webcam.play()` loads and starts playback of the media resource and returns a promise.

```js
webcam.play();
webcam.pause();
webcam.stop();
```
Call `update()` on each frame to refresh the webcam canvas:

```js
webcam.update();
```
You can optionally use a utility function to draw the pose keypoints returned by `model.estimatePose`.
```ts
tmPose.drawKeypoints(
  keypoints: Keypoint[],
  minConfidence: number,
  ctx: CanvasRenderingContext2D,
  keypointSize: number = 4,
  fillColor: string = 'aqua',
  strokeColor: string = 'aqua',
  scale = 1
)
```
Args:
- keypoints: keypoints array
- minConfidence: keypoints below this confidence score will not be drawn
- ctx: canvas context to draw on
- keypointSize: size of the keypoints for drawing
- fillColor: CSS fill color
- strokeColor: CSS stroke color
- scale: a scale factor for the drawing
Usage:
```js
const flipHorizontal = false;
const { pose, posenetOutput } = await model.estimatePose(webcamEl, flipHorizontal);
const minPartConfidence = 0.5;
tmPose.drawKeypoints(pose.keypoints, minPartConfidence, canvasContext);
```
You can optionally use a utility function to draw the skeleton connecting the pose keypoints returned by `model.estimatePose`.
```ts
tmPose.drawSkeleton(
  keypoints: Keypoint[],
  minConfidence: number,
  ctx: CanvasRenderingContext2D,
  lineWidth: number = 2,
  strokeColor: string = 'aqua',
  scale = 1
)
```
Args:
- keypoints: keypoints array
- minConfidence: keypoints below this confidence score will not be drawn
- ctx: canvas context to draw on
- lineWidth: width of the segment lines to draw
- strokeColor: CSS stroke color
- scale: a scale factor for the drawing
Usage:
```js
const flipHorizontal = false;
const { pose, posenetOutput } = await model.estimatePose(webcamEl, flipHorizontal);
const minPartConfidence = 0.5;
tmPose.drawKeypoints(pose.keypoints, minPartConfidence, canvasContext);
tmPose.drawSkeleton(pose.keypoints, minPartConfidence, canvasContext);
```
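The optional parameters in the signatures above can be used to style the overlay; for example:

```js
// larger red keypoints and thicker yellow skeleton lines
tmPose.drawKeypoints(pose.keypoints, minPartConfidence, canvasContext, 6, 'red', 'red');
tmPose.drawSkeleton(pose.keypoints, minPartConfidence, canvasContext, 4, 'yellow');
```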