Skip to content

Commit

Permalink
Save transfer learning implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
JulienVig committed May 29, 2024
1 parent 4a049e4 commit a334219
Showing 1 changed file with 77 additions and 45 deletions.
122 changes: 77 additions & 45 deletions discojs/src/default_tasks/skin_mnist.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@ import * as tf from '@tensorflow/tfjs'

import type { Model, Task, TaskProvider } from '../index.js'
import { data, models } from '../index.js'
import baseModel from '../models/mobileNet_v1_025_224.js'
// import baseModel from '../models/mobileNet_v1_025_224.js'

const IMAGE_SIZE = 224
// Using mobilenet requires using image size of 224
const IMAGE_SIZE = 32

export const skinMnist: TaskProvider = {
getTask (): Task {
Expand Down Expand Up @@ -42,22 +43,72 @@ export const skinMnist: TaskProvider = {
const imageChannels = 3
const numOutputClasses = 7

// const model = tf.sequential()

// model.add(
// tf.layers.conv2d({
// inputShape: [IMAGE_SIZE, IMAGE_SIZE, imageChannels],
// filters: 256,
// kernelSize: 3,
// strides: 1,
// kernelInitializer: 'varianceScaling',
// activation: 'relu'
// const mobilenet = await tf.loadLayersModel({
// load: async () => Promise.resolve(baseModel),
// })

// // Get the mobilenet layers up until the last pooling layer (i.e. before the classification layers)
// const x = mobilenet.getLayer('global_average_pooling2d_1')

// // Add dropout
// // const dropout = tf.layers.dropout({ rate: 0.3 }).apply(x.output) as tf.SymbolicTensor

// // Redefine the output layer for the skin mnist categories
// const predictions = tf.layers
// .dense({
// units: numOutputClasses,
// activation: 'softmax',
// kernelInitializer: 'varianceScaling'
// })
// )

// model.add(tf.layers.maxPooling2d({ poolSize: [2, 2], strides: [2, 2] }))
// .apply(x.output) as tf.SymbolicTensor

// // Put everything together
// const model = tf.model({ inputs: mobilenet.input, outputs: predictions })

// // Freeze most of the pre-trained layers (84 total layers)
// // Leaves 3 convolution blocks left to retrain
// // You can find the mobilenet architecture at
// // https://storage.googleapis.com/tfjs-models/tfjs/mobilenet_v1_0.25_224/model.json
// for (let i = 0; i < model.layers.length; ++i) {
// const layer = model.layers[i]
// if (i < 79) {
// model.layers[i].trainable = false;
// } else if (layer.getClassName() == 'BatchNormalization') {
// model.layers[i].trainable = false; // Freeze all batch normalization layers
// }
// }

const model = tf.sequential()

model.add(
tf.layers.conv2d({
inputShape: [IMAGE_SIZE, IMAGE_SIZE, imageChannels],
filters: 8,
kernelSize: 3,
strides: 1,
kernelInitializer: 'varianceScaling',
activation: 'relu'
})
)
model.add(tf.layers.maxPooling2d({ poolSize: [2, 2], strides: [2, 2] }))
// model.add(tf.layers.dropout({ rate: 0.2 }))

const convFilters = [16]
for (const filters of convFilters) {
console.log(filters)
model.add(
tf.layers.conv2d({
filters: filters,
kernelSize: 3,
strides: 1,
kernelInitializer: 'varianceScaling',
activation: 'relu'
})
)

model.add(tf.layers.maxPooling2d({ poolSize: [2, 2], strides: [2, 2] }))
// model.add(tf.layers.dropout({ rate: 0.2 }))
}
// model.add(
// tf.layers.conv2d({
// filters: 128,
Expand All @@ -84,38 +135,19 @@ export const skinMnist: TaskProvider = {
// model.add(tf.layers.maxPooling2d({ poolSize: [2, 2], strides: [2, 2] }))
// model.add(tf.layers.dropout({ rate: 0.2 }))

// model.add(tf.layers.flatten())

// model.add(tf.layers.dense({
// units: numOutputClasses,
// kernelInitializer: 'varianceScaling',
// activation: 'softmax'
// }))

const mobilenet = await tf.loadLayersModel({
load: async () => Promise.resolve(baseModel),
})

const x = mobilenet.getLayer('global_average_pooling2d_1')
const predictions = tf.layers
.dense({ units: numOutputClasses, activation: 'softmax', kernelInitializer: 'varianceScaling', name: 'denseModified' })
.apply(x.output) as tf.SymbolicTensor
model.add(tf.layers.flatten())
model.add(tf.layers.dense({
units: 64,
kernelInitializer: 'varianceScaling',
activation: 'relu',
}))

const model = tf.model({
inputs: mobilenet.input,
outputs: predictions,
name: 'mobileNetTransferSkinMNIST'
})

console.log(model.layers.length)

// Freeze most of the pre-trained layers
// Leaves 3 convolution blocks left to retrain
// You can find the mobilenet architecture at
// https://storage.googleapis.com/tfjs-models/tfjs/mobilenet_v1_0.25_224/model.json
for (let i = 0; i < 73; ++i) {
model.layers[i].trainable = false;
}
model.add(tf.layers.dense({
units: numOutputClasses,
kernelInitializer: 'varianceScaling',
activation: 'softmax'
}))

model.compile({
optimizer: tf.train.adam(),
Expand Down

0 comments on commit a334219

Please sign in to comment.