Skip to content

Commit

Permalink
✨ feat(layer_seq): ValueCausalSeq (#126)
Browse files Browse the repository at this point in the history
  • Loading branch information
jean-francoisreboud committed Jul 1, 2024
1 parent 6dd84dd commit 8ab07d5
Show file tree
Hide file tree
Showing 12 changed files with 1,609 additions and 51 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ All notable changes to this project will be documented in this file.

## [unreleased]

**layer_seq:** ValueCausalSeq ([126](https://github.com/owkin/GrAIdient/pull/126))\
**layer_seq:** QueryCausalSeq ([125](https://github.com/owkin/GrAIdient/pull/125))\
**layer_seq:** RoPESeq ([124](https://github.com/owkin/GrAIdient/pull/124))\
**layer_seq:** RMSNormSeq ([123](https://github.com/owkin/GrAIdient/pull/123))\
Expand Down
35 changes: 15 additions & 20 deletions Sources/GrAIdient/LayerSeq/QuerySeq.swift
Original file line number Diff line number Diff line change
Expand Up @@ -1236,20 +1236,20 @@ public class QueryCausalSeq: LayerMergeSeq

let query = (_layersPrev[0] as! LayerSeq).neurons!
let key = (_layersPrev[1] as! LayerSeq).neurons!

let size = (_layersPrev[0] as! LayerSeq).nbNeurons / _nbHeadsQuery
let nbBlocksHead = _nbHeadsQuery / _nbHeadsKey

for batch in 0..<batchSize {
for headQuery in 0..<_nbHeadsQuery {
let headKey = headQuery / nbBlocksHead
for seqQ in 0..<sequence {
for seqK in 0..<sequence {
for elem in 0..<nbSameElems
{
if seqK <= seqQ
{
let headKey = _nbHeadsQuery == _nbHeadsKey ?
headQuery : headQuery / _nbHeadsKey
var sum = 0.0

for j in 0..<size
{
let depthPrevKey = j + headKey * size
Expand Down Expand Up @@ -1279,6 +1279,7 @@ public class QueryCausalSeq: LayerMergeSeq

for batch in 0..<batchSize {
for headQuery in 0..<_nbHeadsQuery {
let headKey = headQuery / nbBlocksHead
for seqQ in 0..<sequence {
for seqK in 0..<sequence {
var offset = nbSameElems
Expand All @@ -1289,10 +1290,7 @@ public class QueryCausalSeq: LayerMergeSeq
{
if seqK <= seqQ
{
let headKey = _nbHeadsQuery == _nbHeadsKey ?
headQuery : headQuery / _nbHeadsKey
var sum = 0.0

for j in 0..<size
{
let depthPrevKey = j + headKey * size
Expand Down Expand Up @@ -1361,22 +1359,23 @@ public class QueryCausalSeq: LayerMergeSeq

let query = (_layersPrev[0] as! LayerSeq).neurons!
let key = (_layersPrev[1] as! LayerSeq).neurons!

let nbNeuronsPrevQuery = (_layersPrev[0] as! LayerSeq).nbNeurons
let nbNeuronsPrevKey = (_layersPrev[1] as! LayerSeq).nbNeurons

let size = (_layersPrev[0] as! LayerSeq).nbNeurons / _nbHeadsQuery
let nbBlocksHead = _nbHeadsQuery / _nbHeadsKey

for batch in 0..<batchSize {
for headQuery in 0..<_nbHeadsQuery {
let headKey = headQuery / nbBlocksHead
for seqQ in 0..<sequence {
for seqK in 0..<sequence {
for elem in 0..<nbSameElems
{
if seqK <= seqQ
{
let headKey = _nbHeadsQuery == _nbHeadsKey ?
headQuery : headQuery / _nbHeadsKey
var sum = 0.0

for j in 0..<size
{
let depthPrevKey = j + headKey * size
Expand Down Expand Up @@ -1409,6 +1408,7 @@ public class QueryCausalSeq: LayerMergeSeq

for batch in 0..<batchSize {
for headQuery in 0..<_nbHeadsQuery {
let headKey = headQuery / nbBlocksHead
for seqQ in 0..<sequence {
for seqK in 0..<sequence {
var offset = nbSameElems
Expand All @@ -1419,10 +1419,7 @@ public class QueryCausalSeq: LayerMergeSeq
{
if seqK <= seqQ
{
let headKey = _nbHeadsQuery == _nbHeadsKey ?
headQuery : headQuery / _nbHeadsKey
var sum = 0.0

for j in 0..<size
{
let depthPrevKey = j + headKey * size
Expand Down Expand Up @@ -1487,17 +1484,17 @@ public class QueryCausalSeq: LayerMergeSeq

let query = (_layersPrev[0] as! LayerSeq).neurons!
let key = (_layersPrev[1] as! LayerSeq).neurons!

let size = (_layersPrev[0] as! LayerSeq).nbNeurons / _nbHeadsQuery
let nbBlocksHead = _nbHeadsQuery / _nbHeadsKey

for elem in 0..<batchSize {
for headQuery in 0..<_nbHeadsQuery {
let headKey = headQuery / nbBlocksHead
for seqQ in 0..<sequence {
for seqK in 0...seqQ
{
let headKey = _nbHeadsQuery == _nbHeadsKey ?
headQuery : headQuery / _nbHeadsKey
var sum = 0.0

for j in 0..<size
{
let depthPrevKey = j + headKey * size
Expand Down Expand Up @@ -1569,14 +1566,15 @@ public class QueryCausalSeq: LayerMergeSeq

let query = (_layersPrev[0] as! LayerSeq).neurons!
let key = (_layersPrev[1] as! LayerSeq).neurons!

let size = (_layersPrev[0] as! LayerSeq).nbNeurons / _nbHeadsQuery
let nbBlocksHead = _nbHeadsQuery / _nbHeadsKey

if _layersPrev[0].computeDelta
{
for elem in 0..<batchSize {
for headQuery in 0..<_nbHeadsQuery {
let headKey = _nbHeadsQuery == _nbHeadsKey ?
headQuery : headQuery / _nbHeadsKey
let headKey = headQuery / nbBlocksHead
for seqQ in 0..<sequence {
for j in 0..<size
{
Expand Down Expand Up @@ -1607,9 +1605,6 @@ public class QueryCausalSeq: LayerMergeSeq
}
if _layersPrev[1].computeDelta
{
let nbBlocksHead = _nbHeadsQuery == _nbHeadsKey ?
1 : _nbHeadsQuery / _nbHeadsKey

for elem in 0..<batchSize {
for headKey in 0..<_nbHeadsKey {
for seqK in 0..<sequence {
Expand Down
Loading

0 comments on commit 8ab07d5

Please sign in to comment.