@wellflat
Last active January 29, 2024 09:01
Image Classification using ONNX Runtime for Web (ORT Web)
<template>
  <div class="hello">
    <h1>{{ msg }}</h1>
    <canvas width="32" height="32" ref="canvas"></canvas>
    <button type="button" @click="inference">inference</button>
    <span>{{ infoLabel }}</span>
  </div>
</template>

<script lang="ts">
import { defineComponent } from 'vue';
import { InferenceSession, Tensor } from 'onnxruntime-web';

interface DataType {
  modelPath: string,
  imagePath: string,
  imageData: ImageData | null,
  ctx: CanvasRenderingContext2D | null,
  session: InferenceSession | null,
  infoLabel: string
}

export default defineComponent({
  name: 'HelloONNX',
  props: {
    msg: String,
  },
  data(): DataType {
    return {
      modelPath: 'cifar10_net.onnx',           // ONNX model file name
      imagePath: require("@/assets/cat9.png"), // bundled test image
      imageData: null,
      ctx: null,
      session: null,
      infoLabel: ""
    };
  },
  async mounted() {
    const option = { executionProviders: ['wasm', 'webgl'] };
    this.session = await InferenceSession.create(this.modelPath, option);
    this.infoLabel = "loading model complete.";
    const image = new Image();
    image.src = this.imagePath;
    const isCanvas = (x: any): x is HTMLCanvasElement => x instanceof HTMLCanvasElement;
    image.onload = () => {
      // Draw the image file onto the canvas element and read back its pixel data
      const ref = this.$refs;
      if (!isCanvas(ref.canvas)) return;
      this.ctx = ref.canvas.getContext("2d");
      if (this.ctx == null) return;
      const [w, h] = [this.ctx.canvas.width, this.ctx.canvas.height];
      this.ctx.drawImage(image, 0, 0, w, h);
      this.imageData = this.ctx.getImageData(0, 0, w, h);
    };
  },
  methods: {
    async inference(): Promise<void> {
      if (this.session == null || this.imageData == null) return;
      const { data, width, height } = this.imageData;
      // Normalize (standardize) the input data
      const processed = this.normalize((data as Uint8ClampedArray), width, height);
      // Wrap the data as an ORT Web tensor (NCHW layout) and run inference
      const tensor = new Tensor("float32", processed, [1, 3, height, width]);
      const feed = { input: tensor };
      const result = await this.session.run(feed);
      // Get the prediction (class and confidence)
      const predicted = this.softmax((result.output.data as Float32Array));
      // Display the class with the highest confidence together with its score
      this.infoLabel = this.getClass(predicted).toString();
    },
    normalize(src: Uint8ClampedArray, width: number, height: number): Float32Array {
      const dst = new Float32Array(width * height * 3);
      const transforms = [[0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010]];
      const step = width * height;
      for (let y = 0; y < height; y++) {
        for (let x = 0; x < width; x++) {
          const [di, si] = [y * width + x, (y * width + x) * 4];
          // For each channel, subtract the mean and divide by the standard deviation (standardization),
          // and repack the data from interleaved RGBARGBA... into planar RRR...GGG...BBB... order
          dst[di] = ((src[si + 0] / 255) - transforms[0][0]) / transforms[1][0];
          dst[di + step] = ((src[si + 1] / 255) - transforms[0][1]) / transforms[1][1];
          dst[di + step * 2] = ((src[si + 2] / 255) - transforms[0][2]) / transforms[1][2];
        }
      }
      return dst;
    },
    softmax(data: Float32Array): Float32Array {
      const max = Math.max(...data);
      const d = data.map(y => Math.exp(y - max)).reduce((a, b) => a + b);
      return data.map(value => Math.exp(value - max) / d);
    },
    getClass(data: Float32Array): [string, number] {
      const classes = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck'];
      const maxProb = Math.max(...data);
      return [classes[data.indexOf(maxProb)], maxProb];
    }
  }
});
</script>
... CSS section omitted
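
For reference, the same ORT Web calls can be exercised outside the Vue component, e.g. to sanity-check the exported model in a plain script. The sketch below is an assumption-laden minimal example: it reuses the model file name and the input/output tensor names ('input'/'output') from the component above, and expects a planar, normalized Float32Array such as the one produced by normalize().

// Standalone sketch (assumptions: model path 'cifar10_net.onnx',
// input name 'input', output name 'output', 32x32 RGB input as above).
import { InferenceSession, Tensor } from 'onnxruntime-web';

async function runOnce(pixels: Float32Array): Promise<Float32Array> {
  // Create a session using the WebAssembly backend only.
  const session = await InferenceSession.create('cifar10_net.onnx', {
    executionProviders: ['wasm'],
  });
  // Wrap the normalized, planar pixel data as an NCHW tensor.
  const input = new Tensor('float32', pixels, [1, 3, 32, 32]);
  // Run inference; the result maps output names to tensors.
  const result = await session.run({ input });
  // Raw logits; pass them through softmax() to get class probabilities.
  return result.output.data as Float32Array;
}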